From 7260c7955d2c6c5a64770581d4fdb2494a3c4c2c Mon Sep 17 00:00:00 2001 From: Jose Date: Mon, 27 Nov 2023 23:29:37 +0100 Subject: [PATCH] Add tests for streaming messages on background (#3797) --- src/IceRpc.Protobuf/InvokerExtensions.cs | 9 +- tests/IceRpc.Protobuf.Tests/OperationTests.cs | 4 +- tests/IceRpc.Protobuf.Tests/StreamTests.cs | 419 ++++++++++++++++++ 3 files changed, 424 insertions(+), 8 deletions(-) diff --git a/src/IceRpc.Protobuf/InvokerExtensions.cs b/src/IceRpc.Protobuf/InvokerExtensions.cs index 87d6a0d8d..19771d7e3 100644 --- a/src/IceRpc.Protobuf/InvokerExtensions.cs +++ b/src/IceRpc.Protobuf/InvokerExtensions.cs @@ -162,7 +162,7 @@ public static Task> InvokeServerStreamingAsyncSends a request to a service and decodes the response. This method is for Protobuf bidi-streaming @@ -213,7 +213,7 @@ public static Task> InvokeBidiStreamingAsync ReceiveResponseAsync( @@ -251,8 +251,7 @@ private static async Task ReceiveResponseAsync( private static async Task> ReceiveStreamingResponseAsync( MessageParser messageParser, Task responseTask, - OutgoingRequest request, - CancellationToken cancellationToken) where TOutput : IMessage + OutgoingRequest request) where TOutput : IMessage { try { @@ -264,7 +263,7 @@ private static async Task> ReceiveStreamingResponseAsy return payload.ToAsyncEnumerable( messageParser, protobufFeature.MaxMessageLength, - cancellationToken); + CancellationToken.None); } else { diff --git a/tests/IceRpc.Protobuf.Tests/OperationTests.cs b/tests/IceRpc.Protobuf.Tests/OperationTests.cs index 4e54c9a17..4ad3e0181 100644 --- a/tests/IceRpc.Protobuf.Tests/OperationTests.cs +++ b/tests/IceRpc.Protobuf.Tests/OperationTests.cs @@ -4,8 +4,6 @@ using IceRpc.Features; using IceRpc.Tests.Common; using NUnit.Framework; -using System.IO.Pipelines; -using static System.Net.Mime.MediaTypeNames; namespace IceRpc.Protobuf.Tests; @@ -436,7 +434,7 @@ public void Server_streaming_rpc_completes_request_and_response_payloads_upon_fa var client = new MyOperationsClient(pipeline); - // Act + // Act/Assert Assert.ThrowsAsync(async () => await client.ServerStreamingOpAsync(new Empty())); Assert.That(() => requestPayload!.Completed, Throws.Nothing); Assert.That(() => responsePayload!.Completed, Throws.Nothing); diff --git a/tests/IceRpc.Protobuf.Tests/StreamTests.cs b/tests/IceRpc.Protobuf.Tests/StreamTests.cs index 22ef11fdb..814aac1d7 100644 --- a/tests/IceRpc.Protobuf.Tests/StreamTests.cs +++ b/tests/IceRpc.Protobuf.Tests/StreamTests.cs @@ -1,5 +1,6 @@ // Copyright (c) ZeroC, Inc. +using Google.Protobuf.WellKnownTypes; using IceRpc.Protobuf.Internal; using IceRpc.Tests.Common; using NUnit.Framework; @@ -389,6 +390,424 @@ static async Task EncodeDataAsync(PipeWriter writer) } } + /// Ensures that the async enumerable stream provided to DispatchClientStreamingAsync can be consumed + /// after the dispatch returns and the cancellation token provided to dispatch has been canceled. + [Test] + public async Task Dispatch_client_streaming_rpc_continues_in_the_background() + { + // Arrange + using var request = new IncomingRequest(Protocol.IceRpc, FakeConnectionContext.Instance) + { + Payload = GetDataAsync().ToPipeReader() + }; + + using var cancellationTokenSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + var messages = new List(); + Task? streamTask = null; + + // Act + await request.DispatchClientStreamingAsync( + InputMessage.Parser, + this, + (self, stream, features, cancellationToken) => + { + streamTask = Task.Run( + async () => + { + await completionSource.Task; + await foreach (var message in stream) + { + messages.Add(message); + } + }, + CancellationToken.None); + return new ValueTask(new Empty()); + }, + cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + Assert.That(messages, Is.Empty); + Assert.That(streamTask!.IsCompleted, Is.False); + completionSource.SetResult(); + await streamTask; + Assert.That(messages, Has.Count.EqualTo(10)); + + static async IAsyncEnumerable GetDataAsync() + { + await Task.Yield(); + for (int i = 0; i < 10; i++) + { + await Task.Yield(); + yield return new InputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + } + + /// Ensures that the async enumerable stream returned by DispatchServerStreamingAsync can be consumed + /// after the dispatch returns and the cancellation token provided to dispatch has been canceled. + [Test] + public async Task Dispatch_server_streaming_continues_in_the_background() + { + // Arrange + using var request = new IncomingRequest(Protocol.IceRpc, FakeConnectionContext.Instance) + { + Payload = new Empty().EncodeAsLengthPrefixedMessage(new PipeOptions()) + }; + + using var cancellationTokenSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + var messages = new List(); + Task? streamTask = null; + + // Act + var response = await request.DispatchServerStreamingAsync( + Empty.Parser, + this, + (self, empty, features, cancellationToken) => + new ValueTask>(GetDataAsync()), + cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + Assert.That(response.PayloadContinuation, Is.Not.Null); + streamTask = ConsumeDataAsync(response.PayloadContinuation); + Assert.That(messages, Is.Empty); + Assert.That(streamTask!.IsCompleted, Is.False); + completionSource.SetResult(); + await streamTask; + Assert.That(messages, Has.Count.EqualTo(10)); + + async IAsyncEnumerable GetDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new OutputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + + async Task ConsumeDataAsync(PipeReader payload) + { + IAsyncEnumerable stream = payload.ToAsyncEnumerable( + OutputMessage.Parser, + ProtobufFeature.Default.MaxMessageLength, + default); + + await foreach (var message in stream) + { + messages.Add(message); + } + } + } + + /// Ensures that the async enumerable streams provided to and returned by DispatchBidiStreamingAsync can be + /// consumed after the dispatch returns and the cancellation token provided to dispatch has been canceled. + [Test] + public async Task Dispatch_bidi_streaming_continues_in_the_background() + { + // Arrange + var completionSource = new TaskCompletionSource(); + using var request = new IncomingRequest(Protocol.IceRpc, FakeConnectionContext.Instance) + { + Payload = GetInputDataAsync().ToPipeReader() + }; + + using var cancellationTokenSource = new CancellationTokenSource(); + var inputMessages = new List(); + var outputMessages = new List(); + Task? clientStreamTask = null; + Task? serverStreamTask = null; + + // Act + var response = await request.DispatchBidiStreamingAsync( + InputMessage.Parser, + this, + (self, stream, features, cancellationToken) => + { + clientStreamTask = Task.Run( + async () => + { + await completionSource.Task; + await foreach (var message in stream) + { + inputMessages.Add(message); + } + }, + CancellationToken.None); + return new ValueTask>(GetOutputDataAsync()); + }, + cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + Assert.That(response.PayloadContinuation, Is.Not.Null); + serverStreamTask = ConsumeDataAsync(response.PayloadContinuation); + Assert.That(inputMessages, Is.Empty); + Assert.That(outputMessages, Is.Empty); + Assert.That(clientStreamTask!.IsCompleted, Is.False); + Assert.That(serverStreamTask!.IsCompleted, Is.False); + completionSource.SetResult(); + await clientStreamTask; + Assert.That(inputMessages, Has.Count.EqualTo(10)); + await serverStreamTask; + Assert.That(outputMessages, Has.Count.EqualTo(10)); + + async IAsyncEnumerable GetInputDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new OutputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + + async IAsyncEnumerable GetOutputDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new OutputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + + async Task ConsumeDataAsync(PipeReader payload) + { + IAsyncEnumerable stream = payload.ToAsyncEnumerable( + OutputMessage.Parser, + ProtobufFeature.Default.MaxMessageLength, + default); + + await foreach (var message in stream) + { + outputMessages.Add(message); + } + } + } + + /// Ensures that the async enumerable stream provided to InvokeClientStreamingAsync can be consumed after + /// the invocation returns and the cancellation token provided to the invocation has been canceled. + [Test] + public async Task Invoke_client_streaming_rpc_continues_in_the_background() + { + // Arrange + PipeReader? payloadContinuation = null; + using var cancellationTokenSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + + // Act + _ = await InvokerExtensions.InvokeClientStreamingAsync( + new InlineInvoker((request, cancellationToken) => + { + var response = new IncomingResponse( + new OutgoingRequest(request.ServiceAddress), + FakeConnectionContext.Instance) + { + Payload = new Empty().EncodeAsLengthPrefixedMessage(new PipeOptions()) + }; + payloadContinuation = request.PayloadContinuation; + request.PayloadContinuation = null; + return Task.FromResult(response); + }), + new ServiceAddress(Protocol.IceRpc), + "Op", + GetDataAsync(), + Empty.Parser, + cancellationToken: cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + completionSource.SetResult(); // Let streaming start + Assert.That(payloadContinuation, Is.Not.Null); + var messages = await ConsumeDataAsync(payloadContinuation); + Assert.That(messages, Has.Count.EqualTo(10)); + + static async Task> ConsumeDataAsync(PipeReader payload) + { + IAsyncEnumerable stream = payload.ToAsyncEnumerable( + InputMessage.Parser, + ProtobufFeature.Default.MaxMessageLength, + CancellationToken.None); + + var messages = new List(); + await foreach (var message in stream) + { + messages.Add(message); + } + return messages; + } + + async IAsyncEnumerable GetDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new InputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + } + + /// Ensures that the async enumerable stream returned by InvokeServerStreamingAsync can be consumed after + /// the invocation returns and the cancellation token provided to the invocation has been canceled. + [Test] + public async Task Invoke_server_streaming_rpc_continues_in_the_background() + { + // Arrange + PipeReader? payloadContinuation = null; + using var cancellationTokenSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + + // Act + IAsyncEnumerable stream = await InvokerExtensions.InvokeServerStreamingAsync( + new InlineInvoker((request, cancellationToken) => + { + var response = new IncomingResponse( + new OutgoingRequest(request.ServiceAddress), + FakeConnectionContext.Instance) + { + Payload = GetDataAsync().ToPipeReader() + }; + payloadContinuation = request.PayloadContinuation; + request.PayloadContinuation = null; + return Task.FromResult(response); + }), + new ServiceAddress(Protocol.IceRpc), + "Op", + new Empty(), + OutputMessage.Parser, + cancellationToken: cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + completionSource.SetResult(); // Let streaming start + var messages = await ConsumeDataAsync(stream); + Assert.That(messages, Has.Count.EqualTo(10)); + + static async Task> ConsumeDataAsync(IAsyncEnumerable stream) + { + var messages = new List(); + await foreach (var message in stream) + { + messages.Add(message); + } + return messages; + } + + async IAsyncEnumerable GetDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new OutputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + } + + /// Ensures that the async enumerable streams provided to and returned by InvokeBidiStreamingAsync can be + /// consumed after the invocation returns and the cancellation token provided to the invocation has been canceled. + /// + [Test] + public async Task Invoke_bidi_streaming_rpc_continues_in_the_background() + { + // Arrange + PipeReader? payloadContinuation = null; + using var cancellationTokenSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + + // Act + IAsyncEnumerable stream = await InvokerExtensions.InvokeBidiStreamingAsync( + new InlineInvoker((request, cancellationToken) => + { + var response = new IncomingResponse( + new OutgoingRequest(request.ServiceAddress), + FakeConnectionContext.Instance) + { + Payload = GetOutputDataAsync().ToPipeReader() + }; + payloadContinuation = request.PayloadContinuation; + request.PayloadContinuation = null; + return Task.FromResult(response); + }), + new ServiceAddress(Protocol.IceRpc), + "Op", + GetInputDataAsync(), + OutputMessage.Parser, + cancellationToken: cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + // Assert + completionSource.SetResult(); // Let streaming start + var outputMessages = await ConsumeDataAsync(stream); + Assert.That(outputMessages, Has.Count.EqualTo(10)); + Assert.That(payloadContinuation, Is.Not.Null); + var inputMessages = await ConsumeDataAsync( + payloadContinuation.ToAsyncEnumerable( + InputMessage.Parser, + ProtobufFeature.Default.MaxMessageLength)); + Assert.That(inputMessages, Has.Count.EqualTo(10)); + + static async Task> ConsumeDataAsync(IAsyncEnumerable stream) + { + var messages = new List(); + await foreach (var message in stream) + { + messages.Add(message); + } + return messages; + } + + async IAsyncEnumerable GetOutputDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new OutputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + + async IAsyncEnumerable GetInputDataAsync() + { + await completionSource.Task; + for (int i = 0; i < 10; i++) + { + yield return new InputMessage() + { + P1 = $"P{i}", + P2 = i, + }; + } + } + } + #pragma warning disable CA1001 // _listener is disposed by Listen caller. private sealed class TestAsyncEnumerable : IAsyncEnumerable #pragma warning restore CA1001