diff --git a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs index f27f4594..ce5c2acc 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs @@ -97,7 +97,7 @@ public class Channel : IDisposableObservable, IDuplexPipe private long? localWindowSize; /// - /// Indicates whether the method has been called. + /// Indicates whether the method has been called. /// private bool isDisposed; @@ -322,95 +322,122 @@ private long RemoteWindowRemaining private bool BackpressureSupportEnabled => this.MultiplexingStream.protocolMajorVersion > 1; /// - /// Closes this channel and releases all resources associated with it. + /// Immediately terminates the channel and shuts down any ongoing communication. /// /// /// Because this method may terminate the channel immediately and thus can cause previously queued content to not actually be received by the remote party, /// consider this method a "break glass" way of terminating a channel. The preferred method is that both sides "complete writing" and let the channel dispose itself. /// - public void Dispose() + public void Dispose() => this.Dispose(null); + + /// + /// Disposes the channel by releasing all resources associated with it. + /// + /// The exception to dispose this channel with. + internal void Dispose(Exception? disposeException) { - if (!this.IsDisposed) + // Ensure that we don't call dispose more than once. + lock (this.SyncObject) { - this.acceptanceSource.TrySetCanceled(); - this.optionsAppliedTaskSource?.TrySetCanceled(); - - PipeWriter? mxStreamIOWriter; - lock (this.SyncObject) + if (this.isDisposed) { - this.isDisposed = true; - mxStreamIOWriter = this.mxStreamIOWriter; + return; } - // Complete writing so that the mxstream cannot write to this channel any more. - // We must also cancel a pending flush since no one is guaranteed to be reading this any more - // and we don't want to deadlock on a full buffer in a disposed channel's pipe. - if (mxStreamIOWriter is not null) - { - mxStreamIOWriter.CancelPendingFlush(); - _ = this.mxStreamIOWriterSemaphore.EnterAsync().ContinueWith( - static (releaser, state) => - { - try - { - Channel self = (Channel)state; + // First call to dispose + this.isDisposed = true; + } - PipeWriter? mxStreamIOWriter; - lock (self.SyncObject) - { - mxStreamIOWriter = self.mxStreamIOWriter; - } + this.acceptanceSource.TrySetCanceled(); + this.optionsAppliedTaskSource?.TrySetCanceled(); - mxStreamIOWriter?.Complete(); - self.mxStreamIOWriterCompleted.Set(); - } - finally + PipeWriter? mxStreamIOWriter; + lock (this.SyncObject) + { + mxStreamIOWriter = this.mxStreamIOWriter; + } + + // Complete writing so that the mxstream cannot write to this channel any more. + // We must also cancel a pending flush since no one is guaranteed to be reading this any more + // and we don't want to deadlock on a full buffer in a disposed channel's pipe. + if (mxStreamIOWriter is not null) + { + mxStreamIOWriter.CancelPendingFlush(); + _ = this.mxStreamIOWriterSemaphore.EnterAsync().ContinueWith( + static (releaser, state) => + { + try + { + Channel self = (Channel)state; + + PipeWriter? mxStreamIOWriter; + lock (self.SyncObject) { - releaser.Result.Dispose(); + mxStreamIOWriter = self.mxStreamIOWriter; } - }, - this, - CancellationToken.None, - TaskContinuationOptions.OnlyOnRanToCompletion, - TaskScheduler.Default); - } - if (this.mxStreamIOReader is not null) - { - // We don't own the user's PipeWriter to complete it (so they can't write anything more to this channel). - // We can't know whether there is or will be more bytes written to the user's PipeWriter, - // but we need to terminate our reader for their writer as part of reclaiming resources. - // Cancel the pending or next read operation so the reader loop will immediately notice and shutdown. - this.mxStreamIOReader.CancelPendingRead(); - - // Only Complete the reader if our async reader doesn't own it to avoid thread-safety bugs. - PipeReader? mxStreamIOReader = null; - lock (this.SyncObject) - { - if (this.mxStreamIOReader is not UnownedPipeReader) + mxStreamIOWriter?.Complete(); + self.mxStreamIOWriterCompleted.Set(); + } + finally { - mxStreamIOReader = this.mxStreamIOReader; - this.mxStreamIOReader = null; + releaser.Result.Dispose(); } - } + }, + this, + CancellationToken.None, + TaskContinuationOptions.OnlyOnRanToCompletion, + TaskScheduler.Default); + } - mxStreamIOReader?.Complete(); + if (this.mxStreamIOReader is not null) + { + // We don't own the user's PipeWriter to complete it (so they can't write anything more to this channel). + // We can't know whether there is or will be more bytes written to the user's PipeWriter, + // but we need to terminate our reader for their writer as part of reclaiming resources. + // Cancel the pending or next read operation so the reader loop will immediately notice and shutdown. + this.mxStreamIOReader.CancelPendingRead(); + + // Only Complete the reader if our async reader doesn't own it to avoid thread-safety bugs. + PipeReader? mxStreamIOReader = null; + lock (this.SyncObject) + { + if (this.mxStreamIOReader is not UnownedPipeReader) + { + mxStreamIOReader = this.mxStreamIOReader; + this.mxStreamIOReader = null; + } } - // Unblock the reader that might be waiting on this. - this.remoteWindowHasCapacity.Set(); + mxStreamIOReader?.Complete(); + } - this.disposalTokenSource.Cancel(); + // Set the completion source based on whether we are disposing due to an error + if (disposeException != null) + { + this.completionSource.TrySetException(disposeException); + } + else + { this.completionSource.TrySetResult(null); - this.MultiplexingStream.OnChannelDisposed(this); } + + // Unblock the reader that might be waiting on this. + this.remoteWindowHasCapacity.Set(); + + this.disposalTokenSource.Cancel(); + this.MultiplexingStream.OnChannelDisposed(this, disposeException); } - internal async Task OnChannelTerminatedAsync() + internal async Task OnChannelTerminatedAsync(Exception? remoteError = null) { - if (this.IsDisposed) + // Don't process the frame if the channel has already been disposed. + lock (this.SyncObject) { - return; + if (this.isDisposed) + { + return; + } } try @@ -424,6 +451,16 @@ internal async Task OnChannelTerminatedAsync() { // We fell victim to a race condition. It's OK to just swallow it because the writer was never created, so it needn't be completed. } + + // Terminate the channel. + this.DisposeSelfOnFailure(Task.Run(async delegate + { + // Ensure that we processed the channel before terminating it. + await this.OptionsApplied.ConfigureAwait(false); + + this.IsRemotelyTerminated = true; + this.Dispose(remoteError); + })); } /// @@ -771,6 +808,7 @@ private async Task ProcessOutboundTransmissionsAsync() // We don't use a CancellationToken on this call because we prefer the exception-free cancellation path used by our Dispose method (CancelPendingRead). ReadResult result = await mxStreamIOReader.ReadAsync().ConfigureAwait(false); + if (result.IsCanceled) { // We've been asked to cancel. Presumably the channel has faulted or been disposed. @@ -832,13 +870,25 @@ private async Task ProcessOutboundTransmissionsAsync() } catch (Exception ex) { - if (ex is OperationCanceledException && this.DisposalToken.IsCancellationRequested) + await mxStreamIOReader!.CompleteAsync(ex).ConfigureAwait(false); + + // Record this as a faulting exception if the channel hasn't been disposed. + lock (this.SyncObject) { - await mxStreamIOReader!.CompleteAsync().ConfigureAwait(false); + if (!this.IsDisposed) + { + this.faultingException ??= ex; + } } - else + + // Add a trace indicating that we caught an exception. + if (this.TraceSource!.Switch.ShouldTrace(TraceEventType.Information)) { - await mxStreamIOReader!.CompleteAsync(ex).ConfigureAwait(false); + this.TraceSource.TraceEvent( + TraceEventType.Error, + 0, + "Rethrowing caught exception in " + nameof(this.ProcessOutboundTransmissionsAsync) + ": {0}", + ex.Message); } throw; @@ -911,23 +961,34 @@ private async Task AutoCloseOnPipesClosureAsync() this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.ChannelAutoClosing, "Channel {0} \"{1}\" self-closing because both reader and writer are complete.", this.QualifiedId, this.Name); } - this.Dispose(); + this.Dispose(null); } private void Fault(Exception exception) { - if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Critical) ?? false) - { - this.TraceSource!.TraceEvent(TraceEventType.Critical, (int)TraceEventId.FatalError, "Channel Closing self due to exception: {0}", exception); - } + this.mxStreamIOReader?.CancelPendingRead(); + // If the channel has already been disposed then only cancel the reader lock (this.SyncObject) { + if (this.isDisposed) + { + return; + } + this.faultingException ??= exception; } - this.mxStreamIOReader?.CancelPendingRead(); - this.Dispose(); + if (this.TraceSource?.Switch.ShouldTrace(TraceEventType.Error) ?? false) + { + this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.ChannelFatalError, "Channel faulted with exception: {0}", this.faultingException); + if (exception != this.faultingException) + { + this.TraceSource.TraceEvent(TraceEventType.Error, (int)TraceEventId.ChannelFatalError, "A subsequent fault exception was reported: {0}", exception); + } + } + + this.Dispose(this.faultingException); } private void DisposeSelfOnFailure(Task task) diff --git a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs index c0550733..732ed4db 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Formatters.cs @@ -410,6 +410,57 @@ internal override void WriteFrame(FrameHeader header, ReadOnlySequence pay } } + /// + /// Creates a payload for a frame. + /// + /// The exception to send to the remote side if there is one. + /// The payload to send when a channel gets terminated. + internal ReadOnlySequence SerializeException(Exception? exception) + { + if (exception == null) + { + return ReadOnlySequence.Empty; + } + + var sequence = new Sequence(); + var writer = new MessagePackWriter(sequence); + + writer.WriteArrayHeader(1); + + // Get the exception to send to the remote side + writer.Write($"{exception.GetType().Name}: {exception.Message}"); + writer.Flush(); + + return sequence; + } + + /// + /// Gets the error message in the payload if there is one. + /// + /// The payload that could contain an error message. + /// The error message in this payload if there is one, null otherwise. + internal Exception? DeserializeException(ReadOnlySequence payload) + { + // An empty payload means the remote side closed the channel without an exception. + if (payload.IsEmpty) + { + return null; + } + + var reader = new MessagePackReader(payload); + int numElements = reader.ReadArrayHeader(); + + // We received an empty payload. + if (numElements == 0) + { + return null; + } + + // Get the exception message and return it as an exception. + string remoteErrorMsg = reader.ReadString(); + return new MultiplexingProtocolException($"Received error from remote side: {remoteErrorMsg}"); + } + internal override ReadOnlySequence SerializeContentProcessed(long bytesProcessed) { var sequence = new Sequence(); diff --git a/src/Nerdbank.Streams/MultiplexingStream.cs b/src/Nerdbank.Streams/MultiplexingStream.cs index ba0fdbcc..37b9473b 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.cs @@ -173,6 +173,10 @@ private enum TraceEventId { HandshakeSuccessful = 1, HandshakeFailed, + + /// + /// A fatal error occurred at the overall multiplexing stream level, taking down the whole connection. + /// FatalError, UnexpectedChannelAccept, ChannelAutoClosing, @@ -209,6 +213,11 @@ private enum TraceEventId /// Raised when the protocol handshake is starting, to annouce the major version being used. /// HandshakeStarted, + + /// + /// A fatal exception occurred that took down one channel. + /// + ChannelFatalError, } /// @@ -680,7 +689,7 @@ public async ValueTask DisposeAsync() { foreach (KeyValuePair entry in this.openChannels) { - entry.Value.Dispose(); + entry.Value.Dispose(new ObjectDisposedException(nameof(MultiplexingStream))); } foreach (KeyValuePair>> entry in this.acceptingChannels) @@ -835,7 +844,7 @@ private async Task ReadStreamAsync() this.OnContentWritingCompleted(header.RequiredChannelId); break; case ControlCode.ChannelTerminated: - await this.OnChannelTerminatedAsync(header.RequiredChannelId).ConfigureAwait(false); + await this.OnChannelTerminatedAsync(header.RequiredChannelId, frame.Value.Payload).ConfigureAwait(false); break; default: break; @@ -892,7 +901,8 @@ private async Task ReadStreamAsync() /// Occurs when the remote party has terminated a channel (including canceling an offer). /// /// The ID of the terminated channel. - private async Task OnChannelTerminatedAsync(QualifiedChannelId channelId) + /// The payload sent from the remote side alongside the channel terminated frame. + private async Task OnChannelTerminatedAsync(QualifiedChannelId channelId, ReadOnlySequence payload) { Channel? channel; lock (this.syncObject) @@ -913,9 +923,21 @@ private async Task OnChannelTerminatedAsync(QualifiedChannelId channelId) if (channel is Channel) { - await channel.OnChannelTerminatedAsync().ConfigureAwait(false); - channel.IsRemotelyTerminated = true; - channel.Dispose(); + // Try to get the exception sent from the remote side if there was one sent. + Exception? remoteException = (this.formatter as V2Formatter)?.DeserializeException(payload); + + if (remoteException != null && this.TraceSource.Switch.ShouldTrace(TraceEventType.Error)) + { + this.TraceSource.TraceEvent( + TraceEventType.Error, + (int)TraceEventId.ChannelFatalError, + "Received {2} for channel {0} with exception: {1}", + channelId, + remoteException.Message, + ControlCode.ChannelTerminated); + } + + await channel.OnChannelTerminatedAsync(remoteException).ConfigureAwait(false); } } @@ -1095,21 +1117,33 @@ private void AcceptChannelOrThrow(Channel channel, ChannelOptions options) } /// - /// Raised when is called and any local transmission is completed. + /// Raised when is called and any local transmission is completed. /// /// The channel that is closing down. - private void OnChannelDisposed(Channel channel) + /// The exception to send to the remote side alongside the disposal. + private void OnChannelDisposed(Channel channel, Exception? exception = null) { Requires.NotNull(channel, nameof(channel)); if (!this.Completion.IsCompleted && !this.DisposalToken.IsCancellationRequested) { + // Determine the header to send alongside the error payload + var header = new FrameHeader + { + Code = ControlCode.ChannelTerminated, + ChannelId = channel.QualifiedId, + }; + + // If there is an error and we support sending errors then + // serialize the exception and store it in the payload + ReadOnlySequence payload = (this.formatter as V2Formatter)?.SerializeException(exception) ?? default; + if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information)) { this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.ChannelDisposed, "Local channel {0} \"{1}\" stream disposed.", channel.QualifiedId, channel.Name); } - this.SendFrame(ControlCode.ChannelTerminated, channel.QualifiedId); + this.SendFrame(header, payload, this.DisposalToken); } } diff --git a/src/nerdbank-streams/src/Channel.ts b/src/nerdbank-streams/src/Channel.ts index 79171273..1dad2ebe 100644 --- a/src/nerdbank-streams/src/Channel.ts +++ b/src/nerdbank-streams/src/Channel.ts @@ -55,8 +55,9 @@ export abstract class Channel implements IDisposableObservable { /** * Closes this channel. + * @param error An optional error to send to the remote side, if this multiplexing stream is using protocol versions >= 2. */ - public dispose(): void { + public dispose(error?: Error | null): void { // The interesting stuff is in the derived class. this._isDisposed = true; } @@ -247,7 +248,7 @@ export class ChannelClass extends Channel { } } - public dispose(): void { + public dispose(error?: Error | null): void { if (!this.isDisposed) { super.dispose(); @@ -261,10 +262,14 @@ export class ChannelClass extends Channel { this._duplex.end(); this._duplex.push(null); - this._completion.resolve(); + if (error) { + this._completion.reject(error); + } else { + this._completion.resolve(); + } // Send the notification, but we can't await the result of this. - caught(this._multiplexingStream.onChannelDisposed(this)); + caught(this._multiplexingStream.onChannelDisposed(this, error ?? null)); } } diff --git a/src/nerdbank-streams/src/MultiplexingStream.ts b/src/nerdbank-streams/src/MultiplexingStream.ts index 1f6eafcc..1a1ae7fb 100644 --- a/src/nerdbank-streams/src/MultiplexingStream.ts +++ b/src/nerdbank-streams/src/MultiplexingStream.ts @@ -106,15 +106,15 @@ export abstract class MultiplexingStream implements IDisposableObservable { * @param options Options to customize the behavior of the stream. * @returns The multiplexing stream. */ - public static Create( + public static Create( stream: NodeJS.ReadWriteStream, - options?: MultiplexingStreamOptions) : MultiplexingStream { + options?: MultiplexingStreamOptions): MultiplexingStream { options ??= { protocolMajorVersion: 3 }; options.protocolMajorVersion ??= 3; const formatter: MultiplexingStreamFormatter | undefined = options.protocolMajorVersion === 3 ? new MultiplexingStreamV3Formatter(stream) : - undefined; + undefined; if (!formatter) { throw new Error(`Protocol major version ${options.protocolMajorVersion} is not supported. Try CreateAsync instead.`); } @@ -378,6 +378,7 @@ export abstract class MultiplexingStream implements IDisposableObservable { public dispose() { this.disposalTokenSource.cancel(); this._completionSource.resolve(); + this.formatter.end(); [this.locallyOfferedOpenChannels, this.remotelyOfferedOpenChannels].forEach(cb => { for (const channelId in cb) { @@ -387,6 +388,7 @@ export abstract class MultiplexingStream implements IDisposableObservable { // Acceptance gets rejected when a channel is disposed. // Avoid a node.js crash or test failure for unobserved channels (e.g. offers for channels from the other party that no one cared to receive on this side). caught(channel.acceptance); + channel.dispose(); } } @@ -585,10 +587,15 @@ export class MultiplexingStreamClass extends MultiplexingStream { } } - public async onChannelDisposed(channel: ChannelClass) { + public async onChannelDisposed(channel: ChannelClass, error: Error | null) { if (!this._completionSource.isCompleted) { try { - await this.sendFrame(ControlCode.ChannelTerminated, channel.qualifiedId); + const payload = this.protocolMajorVersion > 1 && error + ? (this.formatter as MultiplexingStreamV2Formatter).serializeException(error) + : Buffer.alloc(0); + + const frameHeader = new FrameHeader(ControlCode.ChannelTerminated, channel.qualifiedId); + await this.sendFrameAsync(frameHeader, payload); } catch (err) { // Swallow exceptions thrown about channel disposal if the whole stream has been taken down. if (this.isDisposed) { @@ -630,7 +637,7 @@ export class MultiplexingStreamClass extends MultiplexingStream { this.onContentWritingCompleted(frame.header.requiredChannel); break; case ControlCode.ChannelTerminated: - this.onChannelTerminated(frame.header.requiredChannel); + this.onChannelTerminated(frame.header.requiredChannel, frame.payload); break; default: break; @@ -729,12 +736,18 @@ export class MultiplexingStreamClass extends MultiplexingStream { * Occurs when the remote party has terminated a channel (including canceling an offer). * @param channelId The ID of the terminated channel. */ - private onChannelTerminated(channelId: QualifiedChannelId) { + private onChannelTerminated(channelId: QualifiedChannelId, payload: Buffer) { const channel = this.getOpenChannel(channelId); if (channel) { this.deleteOpenChannel(channelId); this.removeChannelFromOfferedQueue(channel); - channel.dispose(); + + // Extract the exception that we received from the remote side. + const remoteException = this.protocolMajorVersion > 1 + ? (this.formatter as MultiplexingStreamV2Formatter).deserializeException(payload) + : null; + + channel.dispose(remoteException); } } } diff --git a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts index 0e4a3438..e0d5a6bf 100644 --- a/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts +++ b/src/nerdbank-streams/src/MultiplexingStreamFormatters.ts @@ -293,6 +293,35 @@ export class MultiplexingStreamV2Formatter extends MultiplexingStreamFormatter { return msgpack.decode(payload)[0]; } + serializeException(error: Error | null): Buffer { + // If the error doesn't exist then return an empty buffer. + if (!error) { + return Buffer.alloc(0); + } + + const errorMsg: string = `${error.name}: ${error.message}`; + const payload: any[] = [errorMsg]; + return msgpack.encode(payload); + } + + deserializeException(payload: Buffer): Error | null { + // If the payload is empty then return null. + if (payload.length === 0) { + return null; + } + + // Make sure that the message pack object contains a message + const msgpackObject = msgpack.decode(payload); + if (!msgpackObject || msgpackObject.length === 0) { + return null; + } + + // Get error message and return the error to the remote side + let errorMsg: string = msgpack.decode(payload)[0]; + errorMsg = `Received error from remote side: ${errorMsg}`; + return new Error(errorMsg); + } + protected async readMessagePackAsync(cancellationToken: CancellationToken): Promise<{} | [] | null> { const streamEnded = new Deferred(); while (true) { diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts index 09a3348b..fd0444bb 100644 --- a/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts +++ b/src/nerdbank-streams/src/tests/MultiplexingStream.Interop.spec.ts @@ -72,12 +72,51 @@ import { ChannelOptions } from "../ChannelOptions"; expect(recv).toEqual("recv: theclient\n"); }); + it("Can offer error completed channel", async () => { + const errorMsg: string = "Hello world"; + const error: Error = new Error(errorMsg); + + const channelToCompleteWithError = await mx.offerChannelAsync("clientErrorOffer"); + const communicationChannel = await mx.offerChannelAsync("clientErrorOfferComm"); + + channelToCompleteWithError.dispose(error); + channelToCompleteWithError.completion.then(_ => { + throw new Error("Channel disposed with error didn't complete with error"); + }).catch((channelCompleteErr) => { + expect(channelCompleteErr.message).toContain(errorMsg); + }); + + let expectedErrMessage = `Received error from remote side: Error: ${errorMsg}`; + if (protocolMajorVersion <= 1) { + expectedErrMessage = "Completed with no error"; + } + + const response = await readLineAsync(communicationChannel.stream); + expect(response).toContain(expectedErrMessage); + }) + it("Can accept channel", async () => { const channel = await mx.acceptChannelAsync("serverOffer"); const recv = await readLineAsync(channel.stream); await writeAsync(channel.stream, `recv: ${recv}`); }); + it("Can accept error completed channel", async () => { + const channelCompletedWithError = await mx.acceptChannelAsync("serverErrorOffer"); + const errorExpectedMessage: string = "Received error from remote side: Exception: Hello World"; + const channelCompleted = new Deferred(); + + channelCompletedWithError.completion.then(async _ => { + expect(protocolMajorVersion).toEqual(1); + channelCompleted.resolve(); + }).catch(async error => { + expect(error.message).toContain(errorExpectedMessage); + channelCompleted.resolve(); + }) + + await channelCompleted.promise; + }) + it("Exchange lots of data", async () => { const channel = await mx.offerChannelAsync("clientOffer", { channelReceivingWindowSize: 16 }); const bigdata = 'ABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEFABCDEF\n'; diff --git a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts index 5f2cc5dd..1bfd251c 100644 --- a/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts +++ b/src/nerdbank-streams/src/tests/MultiplexingStream.spec.ts @@ -249,6 +249,49 @@ import { nextTick } from "process"; await channels[1].completion; }); + it("channel terminated with error", async () => { + // Determine the error to complete the local channel with. + const errorMsg: string = "Hello world"; + const error: Error = new Error(errorMsg); + + // Get the channels to send/receive data over + const channels = await Promise.all([ + mx1.offerChannelAsync("test"), + mx2.acceptChannelAsync("test"), + ]); + const localChannel = channels[0]; + const remoteChannel = channels[1]; + + const localChannelCompleted = new Deferred(); + const remoteChannelCompleted = new Deferred(); + + // Dispose the local channel + localChannel.dispose(error); + + // Ensure that the local channel is always completed with the expected error + localChannel.completion.then(response => { + localChannelCompleted.reject(); + throw new Error("Channel disposed with error didn't complete with error"); + }).catch(localChannelErr => { + localChannelCompleted.resolve(); + expect(localChannelErr.message).toContain(errorMsg); + }); + + // Ensure that the remote channel only throws an error for protocol version > 1 + remoteChannel.completion.then(response => { + remoteChannelCompleted.resolve(); + expect(protocolMajorVersion).toEqual(1); + }).catch(remoteChannelErr => { + remoteChannelCompleted.resolve(); + expect(protocolMajorVersion).toBeGreaterThan(1); + expect(remoteChannelErr.message).toContain(errorMsg); + }); + + // Ensure that we don't call multiplexing dispose too soon + await localChannelCompleted.promise; + await remoteChannelCompleted.promise; + }) + it("channels complete when mxstream is disposed", async () => { const channels = await Promise.all([ mx1.offerChannelAsync("test"), diff --git a/test/Nerdbank.Streams.Interop.Tests/Program.cs b/test/Nerdbank.Streams.Interop.Tests/Program.cs index d5b5a46e..657bc5c1 100644 --- a/test/Nerdbank.Streams.Interop.Tests/Program.cs +++ b/test/Nerdbank.Streams.Interop.Tests/Program.cs @@ -62,6 +62,8 @@ private async Task RunAsync(int protocolMajorVersion) { this.ClientOfferAsync().Forget(); this.ServerOfferAsync().Forget(); + this.ClientOffersErrorCompletedChannel().Forget(); + this.ServerOffersErrorCompletedChannel().Forget(); if (protocolMajorVersion >= 3) { @@ -71,6 +73,25 @@ private async Task RunAsync(int protocolMajorVersion) await this.mx.Completion; } + private async Task ClientOffersErrorCompletedChannel() + { + MultiplexingStream.Channel? expectedErrorChannel = await this.mx.AcceptChannelAsync("clientErrorOffer"); + MultiplexingStream.Channel? communicationChannel = await this.mx.AcceptChannelAsync("clientErrorOfferComm"); + (StreamReader _, StreamWriter writer) = CreateStreamIO(communicationChannel); + + string responseMessage = "Completed with no error"; + try + { + await expectedErrorChannel.Completion; + } + catch (Exception e) + { + responseMessage = e.Message; + } + + await writer.WriteLineAsync(responseMessage); + } + private async Task ClientOfferAsync() { MultiplexingStream.Channel? channel = await this.mx.AcceptChannelAsync("clientOffer"); @@ -79,6 +100,13 @@ private async Task ClientOfferAsync() await w.WriteLineAsync($"recv: {line}"); } + private async Task ServerOffersErrorCompletedChannel() + { + MultiplexingStream.Channel? expectedErrorChannel = await this.mx.OfferChannelAsync("serverErrorOffer"); + string errorMessage = "Hello World"; + await expectedErrorChannel.Output.CompleteAsync(new Exception(errorMessage)); + } + private async Task ServerOfferAsync() { MultiplexingStream.Channel? channel = await this.mx.OfferChannelAsync("serverOffer"); diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs index 83e6717c..101fc06b 100644 --- a/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs +++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamTests.cs @@ -196,6 +196,74 @@ public async Task OfferWriteOnlyPipe() await Task.WhenAll(ch1.Completion, ch2.Completion).WithCancellation(this.TimeoutToken); } + [Fact] + public async Task OfferErrorCompletedPipe() + { + // Prepare a readonly pipe that is completed with an error + string localErrMsg = "Hello World"; + string remoteErrMsg = $"Received error from remote side: {nameof(ApplicationException)}: {localErrMsg}"; + var pipe = new Pipe(); + await pipe.Writer.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken); + pipe.Writer.Complete(new ApplicationException(localErrMsg)); + + // Offer the error pipe to the remote side and get the remote side to try to accept the channel + MultiplexingStream.Channel localChannel = this.mx1.CreateChannel(new MultiplexingStream.ChannelOptions { ExistingPipe = new DuplexPipe(pipe.Reader) }); + await this.WaitForEphemeralChannelOfferToPropagateAsync(); + MultiplexingStream.Channel remoteChannel = this.mx2.AcceptChannel(localChannel.QualifiedId.Id); + + // The local channel should always complete with the error + await VerifyChannelCompleted(localChannel, localErrMsg); + + // The remote side should only receive the remote exception for protocol versions > 1 + await VerifyChannelCompleted(remoteChannel, this.ProtocolMajorVersion > 1 ? remoteErrMsg : null); + } + + [Fact] + public async Task OfferEmptyErrorCompletedPipe() + { + string localErrMsg = string.Empty; + string remoteErrMsg = $"Received error from remote side: {nameof(IndexOutOfRangeException)}: {localErrMsg}"; + + // Prepare a readonly pipe that is completed with an error + var pipe = new Pipe(); + await pipe.Writer.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken); + pipe.Writer.Complete(new IndexOutOfRangeException(string.Empty)); + + // Offer the error pipe to the remote side and get the remote side to try to accept the channel + MultiplexingStream.Channel localChannel = this.mx1.CreateChannel(new MultiplexingStream.ChannelOptions { ExistingPipe = new DuplexPipe(pipe.Reader) }); + await this.WaitForEphemeralChannelOfferToPropagateAsync(); + MultiplexingStream.Channel remoteChannel = this.mx2.AcceptChannel(localChannel.QualifiedId.Id); + + // The local channel should always complete with the error + await VerifyChannelCompleted(localChannel, localErrMsg); + + // The remote side should only receive the remote exception for protocol versions > 1 + await VerifyChannelCompleted(remoteChannel, this.ProtocolMajorVersion > 1 ? remoteErrMsg : null); + } + + [Fact] + public async Task OfferNullErrorCompletedPipe() + { + string localErrMsg = "Exception of type 'System.NullReferenceException' was thrown."; + string remoteErrMsg = $"Received error from remote side: {nameof(NullReferenceException)}: {localErrMsg}"; + + // Prepare a readonly pipe that is completed with an error + var pipe = new Pipe(); + await pipe.Writer.WriteAsync(new byte[] { 1, 2, 3 }, this.TimeoutToken); + pipe.Writer.Complete(new NullReferenceException(null)); + + // Offer the error pipe to the remote side and get the remote side to try to accept the channel + MultiplexingStream.Channel localChannel = this.mx1.CreateChannel(new MultiplexingStream.ChannelOptions { ExistingPipe = new DuplexPipe(pipe.Reader) }); + await this.WaitForEphemeralChannelOfferToPropagateAsync(); + MultiplexingStream.Channel remoteChannel = this.mx2.AcceptChannel(localChannel.QualifiedId.Id); + + // The local channel should always complete with the error + await VerifyChannelCompleted(localChannel, localErrMsg); + + // The remote side should only receive the remote exception for protocol versions > 1 + await VerifyChannelCompleted(remoteChannel, this.ProtocolMajorVersion > 1 ? remoteErrMsg : null); + } + [Fact] public async Task Dispose_CancelsOutstandingOperations() { @@ -207,6 +275,26 @@ public async Task Dispose_CancelsOutstandingOperations() Assert.True(accept.IsCanceled); } + [Fact] + public async Task Dispose_CompleteWithErrorAfterwards() + { + // Create the local and remote channels using channel names + Task? localChannelTask = this.mx1.OfferChannelAsync("completeAfterwards", this.TimeoutToken); + Task? remoteChannelTask = this.mx2.AcceptChannelAsync("completeAfterwards", this.TimeoutToken); + MultiplexingStream.Channel remoteChannel = await remoteChannelTask; + MultiplexingStream.Channel localChannel = await localChannelTask; + + // Dispose the local channel and then complete the writer that *we* own later with an error. + localChannel.Dispose(); + await localChannel.Output.CompleteAsync(new InvalidOperationException("Complete after dispose")); + + // Ensure that the local channel completed without error (because we disposed before faulting the PipeWriter). + await VerifyChannelCompleted(localChannel, null); + + // Ensure that the remote channel similarly did not receive notice of any fault. + await VerifyChannelCompleted(remoteChannel, null); + } + [Fact] public async Task Disposal_DisposesTransportStream() { @@ -220,7 +308,8 @@ public async Task Dispose_DisposesChannels() (MultiplexingStream.Channel channel1, MultiplexingStream.Channel channel2) = await this.EstablishChannelsAsync("A"); await this.mx1.DisposeAsync(); Assert.True(channel1.IsDisposed); - await channel1.Completion.WithCancellation(this.TimeoutToken); + await VerifyChannelCompleted(channel1, new ObjectDisposedException(nameof(MultiplexingStream)).Message); + #pragma warning disable CS0618 // Type or member is obsolete await channel1.Input.WaitForWriterCompletionAsync().WithCancellation(this.TimeoutToken); await channel1.Output.WaitForReaderCompletionAsync().WithCancellation(this.TimeoutToken); @@ -1179,6 +1268,19 @@ public async Task FaultingChannelReader() await this.ReadAtLeastAsync(mx2Baseline.Input, 3); } + protected static async Task VerifyChannelCompleted(MultiplexingStream.Channel channel, string? expectedErrMsg) + { + if (expectedErrMsg != null) + { + Exception completionException = await Assert.ThrowsAnyAsync(() => channel.Completion); + Assert.Equal(expectedErrMsg, completionException.Message); + } + else + { + await channel.Completion; + } + } + protected static Task CompleteChannelsAsync(params MultiplexingStream.Channel[] channels) { foreach (MultiplexingStream.Channel? channel in channels)