diff --git a/packages/interface-mocks/src/connection.ts b/packages/interface-mocks/src/connection.ts index ad9ab3a9f..b5db61645 100644 --- a/packages/interface-mocks/src/connection.ts +++ b/packages/interface-mocks/src/connection.ts @@ -82,19 +82,13 @@ class MockConnection implements Connection { const stream = await this.muxer.newStream(id) const result = await mss.select(stream, protocols, options) - const streamWithProtocol: Stream = { - ...stream, - ...result.stream, - stat: { - ...stream.stat, - direction: 'outbound', - protocol: result.protocol - } - } + stream.sink = result.stream.sink + stream.source = result.stream.source + stream.stat.protocol = result.protocol - this.streams.push(streamWithProtocol) + this.streams.push(stream) - return streamWithProtocol + return stream } addStream (stream: Stream): void { @@ -136,7 +130,9 @@ export function mockConnection (maConn: MultiaddrConnection, opts: MockConnectio mss.handle(muxedStream, registrar.getProtocols()) .then(({ stream, protocol }) => { log('%s: incoming stream opened on %s', direction, protocol) - muxedStream = { ...muxedStream, ...stream } + + muxedStream.sink = stream.sink + muxedStream.source = stream.source muxedStream.stat.protocol = protocol connection.addStream(muxedStream) diff --git a/packages/interface-mocks/src/muxer.ts b/packages/interface-mocks/src/muxer.ts index 00f4370d7..d1299dc1c 100644 --- a/packages/interface-mocks/src/muxer.ts +++ b/packages/interface-mocks/src/muxer.ts @@ -1,7 +1,5 @@ -import { CodeError } from '@libp2p/interfaces/errors' import { type Logger, logger } from '@libp2p/logger' import { abortableSource } from 'abortable-iterator' -import { anySignal } from 'any-signal' import map from 'it-map' import * as ndjson from 'it-ndjson' import { pipe } from 'it-pipe' @@ -9,6 +7,7 @@ import { type Pushable, pushable } from 'it-pushable' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { toString as uint8ArrayToString } from 'uint8arrays/to-string' +import { AbstractStream } from '@libp2p/interface-stream-muxer/stream' import type { Stream } from '@libp2p/interface-connection' import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface-stream-muxer' import type { Source } from 'it-stream-types' @@ -44,221 +43,91 @@ interface CreateMessage { type StreamMessage = DataMessage | ResetMessage | CloseMessage | CreateMessage -class MuxedStream { - public id: string - public input: Pushable - public stream: Stream - public type: 'initiator' | 'recipient' - - private sinkEnded: boolean - private sourceEnded: boolean - private readonly abortController: AbortController - private readonly resetController: AbortController - private readonly closeController: AbortController - private readonly log: Logger +class MuxedStream extends AbstractStream { + public readonly type: 'initiator' | 'recipient' + public readonly pushable: Pushable constructor (init: { id: string, type: 'initiator' | 'recipient', push: Pushable, onEnd: (err?: Error) => void }) { const { id, type, push, onEnd } = init - this.log = logger(`libp2p:mock-muxer:stream:${id}:${type}`) + super({ + id, + direction: type === 'initiator' ? 'outbound' : 'inbound', + maxDataSize: MAX_MESSAGE_SIZE, + onEnd + }) - this.id = id this.type = type - this.abortController = new AbortController() - this.resetController = new AbortController() - this.closeController = new AbortController() - - this.sourceEnded = false - this.sinkEnded = false - - let endErr: Error | undefined - - const onSourceEnd = (err?: Error): void => { - if (this.sourceEnded) { - return - } - - this.log('onSourceEnd sink ended? %s', this.sinkEnded) - - this.sourceEnded = true - - if (err != null && endErr == null) { - endErr = err - } - - if (this.sinkEnded) { - this.stream.stat.timeline.close = Date.now() + this.pushable = push + } - if (onEnd != null) { - onEnd(endErr) - } - } + /** + * Send a message to the remote muxer informing them a new stream is being + * opened + */ + sendNewStream (): void | Promise { + console.info('initiator send create stream') + const createMsg: CreateMessage = { + id: this.id, + type: 'create', + direction: 'initiator' } + this.pushable.push(createMsg) + } - const onSinkEnd = (err?: Error): void => { - if (this.sinkEnded) { - return - } - - this.log('onSinkEnd source ended? %s', this.sourceEnded) - - this.sinkEnded = true - - if (err != null && endErr == null) { - endErr = err - } - - if (this.sourceEnded) { - this.stream.stat.timeline.close = Date.now() - - if (onEnd != null) { - onEnd(endErr) - } - } + /** + * Send a data message to the remote muxer + */ + sendData (buf: Uint8ArrayList): void | Promise { + console.info(this.type, 'send data') + const dataMsg: DataMessage = { + id: this.id, + type: 'data', + chunk: uint8ArrayToString(buf.subarray(), 'base64pad'), + direction: this.type } + this.pushable.push(dataMsg) + } - this.input = pushable({ - onEnd: onSourceEnd - }) - - this.stream = { - id, - sink: async (source) => { - if (this.sinkEnded) { - throw new CodeError('stream closed for writing', 'ERR_SINK_ENDED') - } - - const signal = anySignal([ - this.abortController.signal, - this.resetController.signal, - this.closeController.signal - ]) - - source = abortableSource(source, signal) - - try { - if (this.type === 'initiator') { - // If initiator, open a new stream - const createMsg: CreateMessage = { - id: this.id, - type: 'create', - direction: this.type - } - push.push(createMsg) - } - - const list = new Uint8ArrayList() - - for await (const chunk of source) { - list.append(chunk) - - while (list.length > 0) { - const available = Math.min(list.length, MAX_MESSAGE_SIZE) - const dataMsg: DataMessage = { - id, - type: 'data', - chunk: uint8ArrayToString(list.subarray(0, available), 'base64pad'), - direction: this.type - } - - push.push(dataMsg) - list.consume(available) - } - } - } catch (err: any) { - if (err.type === 'aborted' && err.message === 'The operation was aborted') { - if (this.closeController.signal.aborted) { - return - } - - if (this.resetController.signal.aborted) { - err.message = 'stream reset' - err.code = 'ERR_STREAM_RESET' - } - - if (this.abortController.signal.aborted) { - err.message = 'stream aborted' - err.code = 'ERR_STREAM_ABORT' - } - } - - // Send no more data if this stream was remotely reset - if (err.code !== 'ERR_STREAM_RESET') { - const resetMsg: ResetMessage = { - id, - type: 'reset', - direction: this.type - } - push.push(resetMsg) - } - - this.log('sink erred', err) - - this.input.end(err) - onSinkEnd(err) - return - } finally { - signal.clear() - } - - this.log('sink ended') + /** + * Send a reset message to the remote muxer + */ + sendReset (): void | Promise { + console.info(this.type, 'send reset') + const resetMsg: ResetMessage = { + id: this.id, + type: 'reset', + direction: this.type + } + this.pushable.push(resetMsg) + } - onSinkEnd() + /** + * Send a message to the remote muxer, informing them no more data messages + * will be sent by this end of the stream + */ + sendCloseWrite (): void | Promise { + console.info(this.type, 'send close write') + const closeMsg: CloseMessage = { + id: this.id, + type: 'close', + direction: this.type + } + this.pushable.push(closeMsg) + } - const closeMsg: CloseMessage = { - id, - type: 'close', - direction: this.type - } - push.push(closeMsg) - }, - source: this.input, - - // Close for reading - close: () => { - this.stream.closeRead() - this.stream.closeWrite() - }, - - closeRead: () => { - this.input.end() - }, - - closeWrite: () => { - this.closeController.abort() - - const closeMsg: CloseMessage = { - id, - type: 'close', - direction: this.type - } - push.push(closeMsg) - onSinkEnd() - }, - - // Close for reading and writing (local error) - abort: (err: Error) => { - // End the source with the passed error - this.input.end(err) - this.abortController.abort() - onSinkEnd(err) - }, - - // Close immediately for reading and writing (remote error) - reset: () => { - const err = new CodeError('stream reset', 'ERR_STREAM_RESET') - this.resetController.abort() - this.input.end(err) - onSinkEnd(err) - }, - stat: { - direction: type === 'initiator' ? 'outbound' : 'inbound', - timeline: { - open: Date.now() - } - }, - metadata: {} + /** + * Send a message to the remote muxer, informing them no more data messages + * will be read by this end of the stream + */ + sendCloseRead (): void | Promise { + console.info(this.type, 'send close read') + const closeMsg: CloseMessage = { + id: this.id, + type: 'close', + direction: this.type } + this.pushable.push(closeMsg) } } @@ -270,8 +139,8 @@ class MockMuxer implements StreamMuxer { public protocol: string = '/mock-muxer/1.0.0' private readonly closeController: AbortController - private readonly registryInitiatorStreams: Map - private readonly registryRecipientStreams: Map + private readonly registryInitiatorStreams: Map + private readonly registryRecipientStreams: Map private readonly options: StreamMuxerInit private readonly log: Logger @@ -321,7 +190,7 @@ class MockMuxer implements StreamMuxer { } handleMessage (message: StreamMessage): void { - let muxedStream: MuxedStream | undefined + let muxedStream: AbstractStream | undefined const registry = message.direction === 'initiator' ? this.registryRecipientStreams : this.registryInitiatorStreams @@ -331,10 +200,10 @@ class MockMuxer implements StreamMuxer { } muxedStream = this.createStream(message.id, 'recipient') - registry.set(muxedStream.stream.id, muxedStream) + registry.set(muxedStream.id, muxedStream) if (this.options.onIncomingStream != null) { - this.options.onIncomingStream(muxedStream.stream) + this.options.onIncomingStream(muxedStream) } } @@ -347,20 +216,19 @@ class MockMuxer implements StreamMuxer { } if (message.type === 'data') { - muxedStream.input.push(new Uint8ArrayList(uint8ArrayFromString(message.chunk, 'base64pad'))) + muxedStream.sourcePush(new Uint8ArrayList(uint8ArrayFromString(message.chunk, 'base64pad'))) } else if (message.type === 'reset') { - this.log('-> reset stream %s %s', muxedStream.type, muxedStream.stream.id) - muxedStream.stream.reset() + this.log('-> reset stream %s %s', muxedStream.stat.direction, muxedStream.id) + muxedStream.reset() } else if (message.type === 'close') { - this.log('-> closing stream %s %s', muxedStream.type, muxedStream.stream.id) - muxedStream.stream.closeRead() + this.log('-> closing stream %s %s', muxedStream.stat.direction, muxedStream.id) + muxedStream.closeRead() } } get streams (): Stream[] { return Array.from(this.registryRecipientStreams.values()) .concat(Array.from(this.registryInitiatorStreams.values())) - .map(({ stream }) => stream) } newStream (name?: string): Stream { @@ -369,9 +237,9 @@ class MockMuxer implements StreamMuxer { } this.log('newStream %s', name) const storedStream = this.createStream(name, 'initiator') - this.registryInitiatorStreams.set(storedStream.stream.id, storedStream) + this.registryInitiatorStreams.set(storedStream.id, storedStream) - return storedStream.stream + return storedStream } createStream (name?: string, type: 'initiator' | 'recipient' = 'initiator'): MuxedStream { @@ -393,7 +261,7 @@ class MockMuxer implements StreamMuxer { } if (this.options.onStreamEnd != null) { - this.options.onStreamEnd(muxedStream.stream) + this.options.onStreamEnd(muxedStream) } } }) diff --git a/packages/interface-stream-muxer-compliance-tests/package.json b/packages/interface-stream-muxer-compliance-tests/package.json index 20a0eb78b..0eae50903 100644 --- a/packages/interface-stream-muxer-compliance-tests/package.json +++ b/packages/interface-stream-muxer-compliance-tests/package.json @@ -128,6 +128,7 @@ "clean": "aegir clean", "lint": "aegir lint", "dep-check": "aegir dep-check", + "generate": "protons src/fixtures/pb/*.proto", "build": "aegir build", "release": "aegir release" }, @@ -142,10 +143,13 @@ "it-drain": "^3.0.1", "it-map": "^3.0.2", "it-pair": "^2.0.2", + "it-pb-stream": "^4.0.1", "it-pipe": "^3.0.1", "it-stream-types": "^2.0.1", "p-defer": "^4.0.0", "p-limit": "^4.0.0", + "protons": "^7.0.2", + "protons-runtime": "^5.0.0", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.2" }, diff --git a/packages/interface-stream-muxer-compliance-tests/src/close-test.ts b/packages/interface-stream-muxer-compliance-tests/src/close-test.ts index e75d047e9..28a0e3b4d 100644 --- a/packages/interface-stream-muxer-compliance-tests/src/close-test.ts +++ b/packages/interface-stream-muxer-compliance-tests/src/close-test.ts @@ -9,8 +9,10 @@ import { pipe } from 'it-pipe' import pDefer from 'p-defer' import { Uint8ArrayList } from 'uint8arraylist' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { pbStream } from 'it-pb-stream' import type { TestSetup } from '@libp2p/interface-compliance-tests' import type { StreamMuxerFactory } from '@libp2p/interface-stream-muxer' +import { Message } from './fixtures/pb/message.js' function randomBuffer (): Uint8Array { return uint8ArrayFromString(Math.random().toString()) @@ -342,5 +344,51 @@ export default (common: TestSetup): void => { stream.closeRead() await deferred.promise }) + + it('can close a stream gracefully', async () => { + const deferred = pDefer() + + const p = duplexPair() + const dialerFactory = await common.setup() + const dialer = dialerFactory.createStreamMuxer({ direction: 'outbound' }) + + const listenerFactory = await common.setup() + const listener = listenerFactory.createStreamMuxer({ + direction: 'inbound', + onIncomingStream: (stream) => { + const pb = pbStream(stream) + console.info('--> pb.read') + void pb.readPB(Message) + .then(message => { + deferred.resolve(message) + console.info('--> read end close stream') + pb.unwrap().close() + }) + .catch(err => { + deferred.reject(err) + }) + } + }) + + void pipe(p[0], dialer, p[0]) + void pipe(p[1], listener, p[1]) + + const message = { + message: 'hello world', + value: 5, + flag: true + } + + const stream = await dialer.newStream() + const pb = pbStream(stream) + + console.info('--> pb.write') + pb.writePB(message, Message) + + console.info('--> write end close stream') + pb.unwrap().close() + + await expect(deferred.promise).to.eventually.deep.equal(message) + }) }) } diff --git a/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.proto b/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.proto new file mode 100644 index 000000000..f734b891e --- /dev/null +++ b/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +message Message { + string message = 1; + uint32 value = 2; + bool flag = 3; +} diff --git a/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.ts b/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.ts new file mode 100644 index 000000000..74bdd8bb6 --- /dev/null +++ b/packages/interface-stream-muxer-compliance-tests/src/fixtures/pb/message.ts @@ -0,0 +1,87 @@ +/* eslint-disable import/export */ +/* eslint-disable complexity */ +/* eslint-disable @typescript-eslint/no-namespace */ +/* eslint-disable @typescript-eslint/no-unnecessary-boolean-literal-compare */ +/* eslint-disable @typescript-eslint/no-empty-interface */ + +import { encodeMessage, decodeMessage, message } from 'protons-runtime' +import type { Codec } from 'protons-runtime' +import type { Uint8ArrayList } from 'uint8arraylist' + +export interface Message { + message: string + value: number + flag: boolean +} + +export namespace Message { + let _codec: Codec + + export const codec = (): Codec => { + if (_codec == null) { + _codec = message((obj, w, opts = {}) => { + if (opts.lengthDelimited !== false) { + w.fork() + } + + if ((obj.message != null && obj.message !== '')) { + w.uint32(10) + w.string(obj.message) + } + + if ((obj.value != null && obj.value !== 0)) { + w.uint32(16) + w.uint32(obj.value) + } + + if ((obj.flag != null && obj.flag !== false)) { + w.uint32(24) + w.bool(obj.flag) + } + + if (opts.lengthDelimited !== false) { + w.ldelim() + } + }, (reader, length) => { + const obj: any = { + message: '', + value: 0, + flag: false + } + + const end = length == null ? reader.len : reader.pos + length + + while (reader.pos < end) { + const tag = reader.uint32() + + switch (tag >>> 3) { + case 1: + obj.message = reader.string() + break + case 2: + obj.value = reader.uint32() + break + case 3: + obj.flag = reader.bool() + break + default: + reader.skipType(tag & 7) + break + } + } + + return obj + }) + } + + return _codec + } + + export const encode = (obj: Partial): Uint8Array => { + return encodeMessage(obj, Message.codec()) + } + + export const decode = (buf: Uint8Array | Uint8ArrayList): Message => { + return decodeMessage(buf, Message.codec()) + } +} diff --git a/packages/interface-stream-muxer/src/stream.ts b/packages/interface-stream-muxer/src/stream.ts index 08d08ed51..6f17a1970 100644 --- a/packages/interface-stream-muxer/src/stream.ts +++ b/packages/interface-stream-muxer/src/stream.ts @@ -163,7 +163,10 @@ export abstract class AbstractStream implements Stream { return } - this.streamSource.end() + // this has to be done after the current macrotask has finished https://github.com/libp2p/js-libp2p/issues/1793 + Promise.resolve().then(() => { + this.streamSource.end() + }) } // Close for writing @@ -174,23 +177,26 @@ export abstract class AbstractStream implements Stream { return } - this.closeController.abort() + // this has to be done after the current macrotask has finished https://github.com/libp2p/js-libp2p/issues/1793 + Promise.resolve().then(() => { + this.closeController.abort() - try { - // need to call this here as the sink method returns in the catch block - // when the close controller is aborted - const res = this.sendCloseWrite() + try { + // need to call this here as the sink method returns in the catch block + // when the close controller is aborted + const res = this.sendCloseWrite() - if (isPromise(res)) { - res.catch(err => { - log.error('error while sending close write', err) - }) + if (isPromise(res)) { + res.catch(err => { + log.error('error while sending close write', err) + }) + } + } catch (err) { + log.trace('%s stream %s error sending close', this.stat.direction, this.id, err) } - } catch (err) { - log.trace('%s stream %s error sending close', this.stat.direction, this.id, err) - } - this.onSinkEnd() + this.onSinkEnd() + }) } // Close for reading and writing (local error) @@ -239,6 +245,7 @@ export abstract class AbstractStream implements Stream { } for await (let data of source) { + console.info('stream sink got data from source') while (data.length > 0) { if (data.length <= this.maxDataSize) { const res = this.sendData(data instanceof Uint8Array ? new Uint8ArrayList(data) : data) @@ -260,6 +267,8 @@ export abstract class AbstractStream implements Stream { } } } catch (err: any) { + console.info('stream sink ended') + if (err.type === 'aborted' && err.message === 'The operation was aborted') { if (this.closeController.signal.aborted) { return