diff --git a/sdk/core/Azure.Core.TestFramework/src/MockJsonModel.cs b/sdk/core/Azure.Core.TestFramework/src/MockJsonModel.cs index 2b2774b4b8efc..d683d195f0337 100644 --- a/sdk/core/Azure.Core.TestFramework/src/MockJsonModel.cs +++ b/sdk/core/Azure.Core.TestFramework/src/MockJsonModel.cs @@ -11,7 +11,7 @@ namespace Azure.Core.TestFramework { public class MockJsonModel : IJsonModel { - internal MockJsonModel() + public MockJsonModel() { } diff --git a/sdk/core/System.ClientModel/CHANGELOG.md b/sdk/core/System.ClientModel/CHANGELOG.md index 5e464475e950b..622aa9d4d41b9 100644 --- a/sdk/core/System.ClientModel/CHANGELOG.md +++ b/sdk/core/System.ClientModel/CHANGELOG.md @@ -5,9 +5,13 @@ ### Features Added - Added `BufferResponse` property to `RequestOptions` so protocol method callers can turn off response buffering if desired. +- Added `AsyncResultCollection` and `ResultCollection` for clients to return from service methods where the service response contains a collection of values. +- Added `SetRawResponse` method to `ClientResult` to allow the response held by the result to be changed, for example by derived types that obtain multiple responses from polling the service. ### Breaking Changes +- `ClientResult.GetRawResponse` will now throw `InvalidOperationException` if called before the result's raw response is set, for example by collection result types that delay sending a request to the service until the collection is enumerated. + ### Bugs Fixed ### Other Changes diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 28d11b73ce27f..767ef0f2dfae5 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -7,6 +7,12 @@ public ApiKeyCredential(string key) { } public static implicit operator System.ClientModel.ApiKeyCredential (string key) { throw null; } public void Update(string key) { } } + public abstract partial class AsyncResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable + { + protected internal AsyncResultCollection() { } + protected internal AsyncResultCollection(System.ClientModel.Primitives.PipelineResponse response) { } + public abstract System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + } public abstract partial class BinaryContent : System.IDisposable { protected BinaryContent() { } @@ -20,11 +26,13 @@ protected BinaryContent() { } } public partial class ClientResult { + protected ClientResult() { } protected ClientResult(System.ClientModel.Primitives.PipelineResponse response) { } public static System.ClientModel.ClientResult FromOptionalValue(T? value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromResponse(System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromValue(T value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } + protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } } public partial class ClientResultException : System.Exception { @@ -36,10 +44,17 @@ public ClientResultException(string message, System.ClientModel.Primitives.Pipel } public partial class ClientResult : System.ClientModel.ClientResult { - protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) : base (default(System.ClientModel.Primitives.PipelineResponse)) { } + protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) { } public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } + public abstract partial class ResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable + { + protected internal ResultCollection() { } + protected internal ResultCollection(System.ClientModel.Primitives.PipelineResponse response) { } + public abstract System.Collections.Generic.IEnumerator GetEnumerator(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } + } } namespace System.ClientModel.Primitives { diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index 4270fc544902b..c3303e8be4d7e 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -7,6 +7,12 @@ public ApiKeyCredential(string key) { } public static implicit operator System.ClientModel.ApiKeyCredential (string key) { throw null; } public void Update(string key) { } } + public abstract partial class AsyncResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IAsyncEnumerable + { + protected internal AsyncResultCollection() { } + protected internal AsyncResultCollection(System.ClientModel.Primitives.PipelineResponse response) { } + public abstract System.Collections.Generic.IAsyncEnumerator GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)); + } public abstract partial class BinaryContent : System.IDisposable { protected BinaryContent() { } @@ -20,11 +26,13 @@ protected BinaryContent() { } } public partial class ClientResult { + protected ClientResult() { } protected ClientResult(System.ClientModel.Primitives.PipelineResponse response) { } public static System.ClientModel.ClientResult FromOptionalValue(T? value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromResponse(System.ClientModel.Primitives.PipelineResponse response) { throw null; } public static System.ClientModel.ClientResult FromValue(T value, System.ClientModel.Primitives.PipelineResponse response) { throw null; } public System.ClientModel.Primitives.PipelineResponse GetRawResponse() { throw null; } + protected void SetRawResponse(System.ClientModel.Primitives.PipelineResponse response) { } } public partial class ClientResultException : System.Exception { @@ -36,10 +44,17 @@ public ClientResultException(string message, System.ClientModel.Primitives.Pipel } public partial class ClientResult : System.ClientModel.ClientResult { - protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) : base (default(System.ClientModel.Primitives.PipelineResponse)) { } + protected internal ClientResult(T value, System.ClientModel.Primitives.PipelineResponse response) { } public virtual T Value { get { throw null; } } public static implicit operator T (System.ClientModel.ClientResult result) { throw null; } } + public abstract partial class ResultCollection : System.ClientModel.ClientResult, System.Collections.Generic.IEnumerable, System.Collections.IEnumerable + { + protected internal ResultCollection() { } + protected internal ResultCollection(System.ClientModel.Primitives.PipelineResponse response) { } + public abstract System.Collections.Generic.IEnumerator GetEnumerator(); + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } + } } namespace System.ClientModel.Primitives { diff --git a/sdk/core/System.ClientModel/src/Convenience/AsyncResultCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/AsyncResultCollectionOfT.cs new file mode 100644 index 0000000000000..da0f411799f0e --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/AsyncResultCollectionOfT.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; + +namespace System.ClientModel; + +/// +/// Represents a collection of results returned from a cloud service operation. +/// +public abstract class AsyncResultCollection : ClientResult, IAsyncEnumerable +{ + /// + /// Create a new instance of . + /// + /// If no is provided when the + /// instance is created, it is expected that + /// a derived type will call + /// prior to a user calling . + /// This constructor is indended for use by collection implementations that + /// postpone sending a request until + /// is called. Such implementations will typically be returned from client + /// convenience methods so that callers of the methods don't need to + /// dispose the return value. + protected internal AsyncResultCollection() : base() + { + } + + /// + /// Create a new instance of . + /// + /// The holding the + /// items in the collection, or the first set of the items in the collection. + /// + protected internal AsyncResultCollection(PipelineResponse response) : base(response) + { + } + + /// + public abstract IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default); +} diff --git a/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs b/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs index 5f60f040fa3f6..7205c9165fdbe 100644 --- a/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs +++ b/sdk/core/System.ClientModel/src/Convenience/ClientResult.cs @@ -11,7 +11,18 @@ namespace System.ClientModel; /// public class ClientResult { - private readonly PipelineResponse _response; + private PipelineResponse? _response; + + /// + /// Create a new instance of . + /// + /// If no is provided when the + /// instance is created, it is expected that + /// a derived type will call + /// prior to a user calling . + protected ClientResult() + { + } /// /// Create a new instance of from a service @@ -31,7 +42,39 @@ protected ClientResult(PipelineResponse response) /// /// the received from the service. /// - public PipelineResponse GetRawResponse() => _response; + /// No + /// value is currently available for this + /// instance. This can happen when the instance + /// is a collection type like + /// that has not yet been enumerated. + public PipelineResponse GetRawResponse() + { + if (_response is null) + { + throw new InvalidOperationException("No response is associated " + + "with this result. If the result is a collection result " + + "type, this may be because no request has been sent to the " + + "server yet."); + } + + return _response; + } + + /// + /// Update the value returned from . + /// + /// This method may be called from types derived from + /// that poll the service for status updates + /// or to retrieve additional collection values to update the raw response + /// to the response most recently returned from the service. + /// The to return + /// from . + protected void SetRawResponse(PipelineResponse response) + { + Argument.AssertNotNull(response, nameof(response)); + + _response = response; + } #region Factory methods for ClientResult and subtypes @@ -44,7 +87,11 @@ protected ClientResult(PipelineResponse response) /// provided . /// public static ClientResult FromResponse(PipelineResponse response) - => new ClientResult(response); + { + Argument.AssertNotNull(response, nameof(response)); + + return new ClientResult(response); + } /// /// Creates a new instance of that holds the @@ -60,6 +107,8 @@ public static ClientResult FromResponse(PipelineResponse response) /// public static ClientResult FromValue(T value, PipelineResponse response) { + Argument.AssertNotNull(response, nameof(response)); + if (value is null) { string message = "ClientResult contract guarantees that ClientResult.Value is non-null. " + @@ -90,7 +139,11 @@ public static ClientResult FromValue(T value, PipelineResponse response) /// provided and . /// public static ClientResult FromOptionalValue(T? value, PipelineResponse response) - => new ClientResult(value, response); + { + Argument.AssertNotNull(response, nameof(response)); + + return new ClientResult(value, response); + } #endregion } diff --git a/sdk/core/System.ClientModel/src/Convenience/ResultCollectionOfT.cs b/sdk/core/System.ClientModel/src/Convenience/ResultCollectionOfT.cs new file mode 100644 index 0000000000000..5943cc8438f95 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Convenience/ResultCollectionOfT.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; + +namespace System.ClientModel; + +/// +/// Represents a collection of results returned from a cloud service operation. +/// +public abstract class ResultCollection : ClientResult, IEnumerable +{ + /// + /// Create a new instance of . + /// + /// If no is provided when the + /// instance is created, it is expected that + /// a derived type will call + /// prior to a user calling . + /// This constructor is indended for use by collection implementations that + /// postpone sending a request until + /// is called. Such implementations will typically be returned from client + /// convenience methods so that callers of the methods don't need to + /// dispose the return value. + protected internal ResultCollection() : base() + { + } + + /// + /// Create a new instance of . + /// + /// The holding the + /// items in the collection, or the first set of the items in the collection. + /// + protected internal ResultCollection(PipelineResponse response) : base(response) + { + } + + /// + public abstract IEnumerator GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs b/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs new file mode 100644 index 0000000000000..e7efca1805b1d --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/AsyncServerSentEventEnumerable.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.ClientModel.Internal; + +/// +/// Represents a collection of SSE events that can be enumerated as a C# async stream. +/// +internal class AsyncServerSentEventEnumerable : IAsyncEnumerable +{ + private readonly Stream _contentStream; + + public AsyncServerSentEventEnumerable(Stream contentStream) + { + Argument.AssertNotNull(contentStream, nameof(contentStream)); + + _contentStream = contentStream; + + LastEventId = string.Empty; + ReconnectionInterval = Timeout.InfiniteTimeSpan; + } + + public string LastEventId { get; private set; } + + public TimeSpan ReconnectionInterval { get; private set; } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new AsyncServerSentEventEnumerator(_contentStream, this, cancellationToken); + } + + private sealed class AsyncServerSentEventEnumerator : IAsyncEnumerator + { + private readonly ServerSentEventReader _reader; + private readonly AsyncServerSentEventEnumerable _enumerable; + private readonly CancellationToken _cancellationToken; + + public ServerSentEvent Current { get; private set; } + + public AsyncServerSentEventEnumerator(Stream contentStream, + AsyncServerSentEventEnumerable enumerable, + CancellationToken cancellationToken = default) + { + _reader = new(contentStream); + _enumerable = enumerable; + _cancellationToken = cancellationToken; + } + + public async ValueTask MoveNextAsync() + { + ServerSentEvent? nextEvent = await _reader.TryGetNextEventAsync(_cancellationToken).ConfigureAwait(false); + _enumerable.LastEventId = _reader.LastEventId; + _enumerable.ReconnectionInterval = _reader.ReconnectionInterval; + + if (nextEvent.HasValue) + { + Current = nextEvent.Value; + return true; + } + + Current = default; + return false; + } + + public ValueTask DisposeAsync() + { + // The creator of the enumerable has responsibility for disposing + // the content stream passed to the enumerable constructor. + +#if NET6_0_OR_GREATER + return ValueTask.CompletedTask; +#else + return new ValueTask(); +#endif + } + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs new file mode 100644 index 0000000000000..f962dd2bac4b8 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEvent.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace System.ClientModel.Internal; + +/// +/// Represents an SSE event. +/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html +/// +internal readonly struct ServerSentEvent +{ + // Gets the value of the SSE "event type" buffer, used to distinguish + // between event kinds. + public string EventType { get; } + + // Gets the value of the SSE "data" buffer, which holds the payload of the + // server-sent event. + public string Data { get; } + + public ServerSentEvent(string type, string data) + { + EventType = type; + Data = data; + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs new file mode 100644 index 0000000000000..8c0ebca656813 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventEnumerable.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Threading; + +namespace System.ClientModel.Internal; + +/// +/// Represents a collection of SSE events that can be enumerated as a C# collection. +/// +internal class ServerSentEventEnumerable : IEnumerable +{ + private readonly Stream _contentStream; + + public ServerSentEventEnumerable(Stream contentStream) + { + Argument.AssertNotNull(contentStream, nameof(contentStream)); + + _contentStream = contentStream; + + LastEventId = string.Empty; + ReconnectionInterval = Timeout.InfiniteTimeSpan; + } + + public string LastEventId { get; private set; } + + public TimeSpan ReconnectionInterval { get; private set; } + + public IEnumerator GetEnumerator() + { + return new ServerSentEventEnumerator(_contentStream, this); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + private sealed class ServerSentEventEnumerator : IEnumerator + { + private readonly ServerSentEventReader _reader; + private readonly ServerSentEventEnumerable _enumerable; + + public ServerSentEventEnumerator(Stream contentStream, ServerSentEventEnumerable enumerable) + { + _reader = new(contentStream); + _enumerable = enumerable; + } + + public ServerSentEvent Current { get; private set; } + + object IEnumerator.Current => Current; + + public bool MoveNext() + { + ServerSentEvent? nextEvent = _reader.TryGetNextEvent(); + _enumerable.LastEventId = _reader.LastEventId; + _enumerable.ReconnectionInterval= _reader.ReconnectionInterval; + + if (nextEvent.HasValue) + { + Current = nextEvent.Value; + return true; + } + + Current = default; + return false; + } + + public void Reset() + { + throw new NotSupportedException("Cannot seek back in an SSE stream."); + } + + public void Dispose() + { + // The creator of the enumerable has responsibility for disposing + // the content stream passed to the enumerable constructor. + } + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs new file mode 100644 index 0000000000000..eaf72fb5121a4 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventField.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace System.ClientModel.Internal; + +/// +/// Represents a field that can be composed into an SSE event. +/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html +/// +internal readonly struct ServerSentEventField +{ + private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); + private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); + private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); + private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); + + public ServerSentEventFieldKind FieldType { get; } + + // Note: we don't plan to expose UTF16 publicly + public ReadOnlyMemory Value { get; } + + internal ServerSentEventField(string line) + { + int colonIndex = line.AsSpan().IndexOf(':'); + + ReadOnlyMemory fieldName = colonIndex < 0 ? + line.AsMemory() : + line.AsMemory(0, colonIndex); + + FieldType = fieldName.Span switch + { + var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, + _ => ServerSentEventFieldKind.Ignore, + }; + + if (colonIndex < 0) + { + Value = ReadOnlyMemory.Empty; + } + else + { + Value = line.AsMemory(colonIndex + 1); + + // Per spec, remove a leading space if present. + if (Value.Length > 0 && Value.Span[0] == ' ') + { + Value = Value.Slice(1); + } + } + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs new file mode 100644 index 0000000000000..3ddc00aff270c --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventFieldKind.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace System.ClientModel.Internal; + +/// +/// The kind of line or field received over an SSE stream. +/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html +/// +internal enum ServerSentEventFieldKind +{ + Ignore, + Event, + Data, + Id, + Retry, +} diff --git a/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs new file mode 100644 index 0000000000000..e1e881743533c --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/SSE/ServerSentEventReader.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.ClientModel.Internal; + +/// +/// An SSE event reader that reads lines from an SSE stream and composes them +/// into SSE events. +/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html +/// +internal sealed class ServerSentEventReader +{ + private readonly StreamReader _reader; + + public ServerSentEventReader(Stream stream) + { + Argument.AssertNotNull(stream, nameof(stream)); + + // The creator of the reader has responsibility for disposing the + // stream passed to the reader's constructor. + _reader = new StreamReader(stream); + + LastEventId = string.Empty; + ReconnectionInterval = Timeout.InfiniteTimeSpan; + } + + public string LastEventId { get; private set; } + + public TimeSpan ReconnectionInterval { get; private set; } + + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + { + PendingEvent pending = default; + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Note: would be nice to have polyfill that takes cancellation token, + // but may become moot if we shift to all UTF-8. + string? line = _reader.ReadLine(); + + if (line is null) + { + // A null line indicates end of input + return null; + } + + ProcessLine(line, ref pending, out bool dispatch); + + if (dispatch) + { + return pending.ToEvent(); + } + } + } + + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + PendingEvent pending = default; + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Note: would be nice to have polyfill that takes cancellation token, + // but may become moot if we shift to all UTF-8. + string? line = await _reader.ReadLineAsync().ConfigureAwait(false); + + if (line is null) + { + // A null line indicates end of input + return null; + } + + ProcessLine(line, ref pending, out bool dispatch); + + if (dispatch) + { + return pending.ToEvent(); + } + } + } + + private void ProcessLine(string line, ref PendingEvent pending, out bool dispatch) + { + dispatch = false; + + if (line.Length == 0) + { + if (pending.DataLength == 0) + { + // Per spec, if there's no data, don't dispatch an event. + pending = default; + } + else + { + dispatch = true; + } + } + else if (line[0] != ':') + { + // Per spec, ignore comment lines (i.e. that begin with ':'). + // If we got this far, process the field + value and accumulate + // it for the next dispatched event. + ServerSentEventField field = new(line); + switch (field.FieldType) + { + case ServerSentEventFieldKind.Event: + pending.EventTypeField = field; + break; + case ServerSentEventFieldKind.Data: + // Per spec, we'll append \n when we concatenate the data lines. + pending.DataLength += field.Value.Length + 1; + pending.DataFields.Add(field); + break; + case ServerSentEventFieldKind.Id: + LastEventId = field.Value.ToString(); + break; + case ServerSentEventFieldKind.Retry: + if (field.Value.Length > 0 && int.TryParse(field.Value.ToString(), out int retry)) + { + ReconnectionInterval = TimeSpan.FromMilliseconds(retry); + } + break; + default: + // Ignore + break; + } + } + } + + private struct PendingEvent + { + private const char LF = '\n'; + + private List? _dataFields; + + public int DataLength { get; set; } + public List DataFields => _dataFields ??= new(); + public ServerSentEventField? EventTypeField { get; set; } + + public ServerSentEvent ToEvent() + { + Debug.Assert(DataLength > 0); + + // Per spec, if event type buffer is empty, set event.type to "message". + string type = EventTypeField.HasValue ? + EventTypeField.Value.Value.ToString() : + "message"; + + Memory buffer = new(new char[DataLength]); + + int curr = 0; + foreach (ServerSentEventField field in DataFields) + { + Debug.Assert(field.FieldType == ServerSentEventFieldKind.Data); + + field.Value.Span.CopyTo(buffer.Span.Slice(curr)); + + // Per spec, append trailing LF to each data field value. + buffer.Span[curr + field.Value.Length] = LF; + curr += field.Value.Length + 1; + } + + // Per spec, remove trailing LF from concatenated data fields. + string data = buffer.Slice(0, buffer.Length - 1).ToString(); + + return new ServerSentEvent(type, data); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs b/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs index ccad66170e579..629f1ec27f2b1 100644 --- a/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs +++ b/sdk/core/System.ClientModel/tests/Convenience/ClientResultTests.cs @@ -15,7 +15,6 @@ public class PipelineResponseTests [Test] public void CannotCreateClientResultFromNullResponse() { - Assert.Throws(() => new MockClientResult(null!)); Assert.Throws(() => { ClientResult result = ClientResult.FromResponse(null!); @@ -98,7 +97,6 @@ public void CannotCreateClientResultOfTFromNullResponse() { object value = new(); - Assert.Throws(() => new MockClientResult(value, null!)); Assert.Throws(() => { ClientResult result = ClientResult.FromValue(value, null!); diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs new file mode 100644 index 0000000000000..fe511057338c4 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/AsyncServerSentEventEnumerableTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using ClientModel.Tests.Internal.Mocks; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Convenience; + +public class AsyncServerSentEventEnumerableTests +{ + [Test] + public async Task EnumeratesEvents() + { + using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); + AsyncServerSentEventEnumerable enumerable = new(contentStream); + + List events = new(); + + await foreach (ServerSentEvent sse in enumerable) + { + events.Add(sse); + } + + Assert.AreEqual(4, events.Count); + + for (int i = 0; i < 3; i++) + { + Assert.AreEqual($"event.{i}", events[i].EventType); + Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", events[i].Data); + } + } + + [Test] + public void ThrowsIfCancelled() + { + CancellationToken token = new(true); + + using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); + AsyncServerSentEventEnumerable enumerable = new(contentStream); + IAsyncEnumerator enumerator = enumerable.GetAsyncEnumerator(token); + + Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs new file mode 100644 index 0000000000000..8e2fb8d3095e3 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ClientResultCollectionTests.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Internal.Mocks; +using NUnit.Framework; +using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; + +namespace System.ClientModel.Tests.Convenience; + +public class ClientResultCollectionTests : SyncAsyncTestBase +{ + public ClientResultCollectionTests(bool isAsync) : base(isAsync) + { + } + + [Test] + public async Task EnumeratesModelValues() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync(); + + int i = 0; + await foreach (MockJsonModel model in models) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + Assert.AreEqual(i, 3); + } + + [Test] + public async Task ModelCollectionDelaysSendingRequest() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync(); + + Assert.IsFalse(client.ProtocolMethodCalled); + + int i = 0; + await foreach (MockJsonModel model in models) + { + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + Assert.AreEqual(3, i); + Assert.IsTrue(client.ProtocolMethodCalled); + } + + [Test] + public void ModelCollectionThrowsIfCancelled() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync(); + + // Set it to `cancelled: true` to validate functionality. + CancellationToken token = new(true); + + Assert.ThrowsAsync(async () => + { + await foreach (MockJsonModel model in models.WithCancellation(token)) + { + } + }); + } + + [Test] + public async Task ModelCollectionDisposesStream() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync(); + + await foreach (MockJsonModel model in models) + { + } + + PipelineResponse response = models.GetRawResponse(); + Assert.Throws(() => { var p = response.ContentStream!.Position; }); + } + + [Test] + public void ModelCollectionGetRawResponseThrowsBeforeEnumerated() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync(); + Assert.Throws(() => { PipelineResponse response = models.GetRawResponse(); }); + } + + [Test] + public async Task StopsOnStringBasedTerminalEvent() + { + MockSseClient client = new(); + AsyncResultCollection models = client.GetModelsStreamingAsync("[DONE]"); + + bool empty = true; + await foreach (MockJsonModel model in models) + { + empty = false; + } + + Assert.IsNotNull(models); + Assert.AreEqual("[DONE]", models.GetRawResponse().Content.ToString()); + Assert.IsTrue(empty); + } + + [Test] + public async Task EnumeratesDataValues() + { + MockSseClient client = new(); + ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); + + int i = 0; + await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents()) + { + MockJsonModel model = data.ToObjectFromJson(); + + Assert.AreEqual(i, model.IntValue); + Assert.AreEqual(i.ToString(), model.StringValue); + + i++; + } + + Assert.AreEqual(3, i); + } + + [Test] + public void DataCollectionThrowsIfCancelled() + { + MockSseClient client = new(); + ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); + + // Set it to `cancelled: true` to validate functionality. + CancellationToken token = new(true); + + Assert.ThrowsAsync(async () => + { + await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents().WithCancellation(token)) + { + } + }); + } + + [Test] + public async Task DataCollectionDoesNotDisposeStream() + { + MockSseClient client = new(); + ClientResult result = client.GetModelsStreamingAsync(MockSseClient.DefaultMockContent, new RequestOptions()); + + await foreach (BinaryData data in result.GetRawResponse().EnumerateDataEvents()) + { + } + + Assert.DoesNotThrow(() => { var p = result.GetRawResponse().ContentStream!.Position; }); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs new file mode 100644 index 0000000000000..d32835fa7486e --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventEnumerableTests.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.IO; +using ClientModel.Tests.Internal.Mocks; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Convenience; + +public class ServerSentEventEnumerableTests +{ + [Test] + public void EnumeratesEvents() + { + using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); + ServerSentEventEnumerable enumerable = new(contentStream); + + List events = new(); + + foreach (ServerSentEvent sse in enumerable) + { + events.Add(sse); + } + + Assert.AreEqual(4, events.Count); + + for (int i = 0; i < 3; i++) + { + Assert.AreEqual($"event.{i}", events[i].EventType); + Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", events[i].Data); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs new file mode 100644 index 0000000000000..8807aaa805b31 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventFieldTests.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Convenience; + +public class ServerSentEventFieldTests +{ + [Test] + public void ParsesEventField() + { + string line = "event: event.name"; + ServerSentEventField field = new(line); + + Assert.AreEqual(ServerSentEventFieldKind.Event, field.FieldType); + Assert.IsTrue("event.name".AsSpan().SequenceEqual(field.Value.Span)); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs new file mode 100644 index 0000000000000..67365117e576e --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Convenience/SSE/ServerSentEventReaderTests.cs @@ -0,0 +1,290 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using ClientModel.Tests.Internal.Mocks; +using NUnit.Framework; +using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; + +namespace System.ClientModel.Tests.Convenience; + +public class ServerSentEventReaderTests : SyncAsyncTestBase +{ + public ServerSentEventReaderTests(bool isAsync) : base(isAsync) + { + } + + [Test] + public async Task GetsEventsFromStream() + { + Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); + ServerSentEventReader reader = new(contentStream); + + List events = new(); + ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); + while (ssEvent is not null) + { + events.Add(ssEvent.Value); + ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); + } + + Assert.AreEqual(events.Count, 4); + + for (int i = 0; i < 3; i++) + { + ServerSentEvent sse = events[i]; + Assert.AreEqual($"event.{i}", sse.EventType); + Assert.AreEqual($"{{ \"IntValue\": {i}, \"StringValue\": \"{i}\" }}", sse.Data); + } + + Assert.AreEqual("done", events[3].EventType); + Assert.AreEqual("[DONE]", events[3].Data); + } + + [Test] + public async Task HandlesNullLine() + { + Stream contentStream = BinaryData.FromString(string.Empty).ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); + Assert.IsNull(ssEvent); + } + + [Test] + public async Task DiscardsCommentLine() + { + Stream contentStream = BinaryData.FromString(": comment").ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? ssEvent = await reader.TryGetNextEventSyncOrAsync(IsAsync); + Assert.IsNull(ssEvent); + } + + [Test] + public async Task HandlesIgnoreLine() + { + Stream contentStream = BinaryData.FromString(""" + ignore: noop + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + + Assert.IsNull(sse); + } + + [Test] + public async Task HandlesDoneEvent() + { + Stream contentStream = BinaryData.FromString("event: stop\ndata: ~stop~\n\n").ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + + Assert.IsNotNull(sse); + + Assert.AreEqual("stop", sse.Value.EventType); + Assert.AreEqual("~stop~", sse.Value.Data); + + Assert.AreEqual(string.Empty, reader.LastEventId); + Assert.AreEqual(Timeout.InfiniteTimeSpan, reader.ReconnectionInterval); + } + + [Test] + public async Task ConcatenatesDataLines() + { + Stream contentStream = BinaryData.FromString(""" + data: YHOO + data: +2 + data: 10 + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + + Assert.IsNotNull(sse); + + Assert.AreEqual("YHOO\n+2\n10", sse.Value.Data); + + Assert.AreEqual(string.Empty, reader.LastEventId); + Assert.AreEqual(Timeout.InfiniteTimeSpan, reader.ReconnectionInterval); + } + + [Test] + public async Task DefaultsEventTypeToMessage() + { + Stream contentStream = BinaryData.FromString(""" + data: data + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + + Assert.IsNotNull(sse); + + Assert.AreEqual("message", sse.Value.EventType); + Assert.AreEqual("data", sse.Value.Data); + } + + [Test] + public async Task SecondTestCaseFromSpec() + { + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + Stream contentStream = BinaryData.FromString(""" + : test stream + + data: first event + id: 1 + + data:second event + id + + data: third event + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + List events = new(); + List ids = new(); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + while (sse is not null) + { + events.Add(sse.Value); + ids.Add(reader.LastEventId.ToString()); + + sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + } + + Assert.AreEqual(3, events.Count); + + Assert.AreEqual("first event", events[0].Data); + Assert.AreEqual("1", ids[0]); + + Assert.AreEqual("second event", events[1].Data); + Assert.AreEqual(string.Empty, ids[1]); + + Assert.AreEqual(" third event", events[2].Data); + Assert.AreEqual(string.Empty, ids[2]); + } + + [Test] + public async Task ThirdSpecTestCase() + { + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + Stream contentStream = BinaryData.FromString(""" + data + + data + data + + data: + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + List events = new(); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + while (sse is not null) + { + events.Add(sse.Value); + sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + } + + Assert.AreEqual(2, events.Count); + Assert.AreEqual(0, events[0].Data.Length); + Assert.AreEqual("\n", events[1].Data); + } + + [Test] + public async Task FourthSpecTestCase() + { + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + Stream contentStream = BinaryData.FromString(""" + data:test + + data: test + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + List events = new(); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + while (sse is not null) + { + events.Add(sse.Value); + sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + } + + Assert.AreEqual(2, events.Count); + Assert.AreEqual(events[0].Data, events[1].Data); + } + + [Test] + public async Task SetsReconnectionInterval() + { + Stream contentStream = BinaryData.FromString(""" + data: test + + data: test + retry: 2500 + + data: test + retry: + + + """).ToStream(); + ServerSentEventReader reader = new(contentStream); + + List events = new(); + List retryValues = new(); + + ServerSentEvent? sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + while (sse is not null) + { + events.Add(sse.Value); + retryValues.Add(reader.ReconnectionInterval); + + sse = await reader.TryGetNextEventSyncOrAsync(IsAsync); + } + + Assert.AreEqual(3, events.Count); + + // Defaults to infinite timespan + Assert.AreEqual("test", events[0].Data); + Assert.AreEqual(Timeout.InfiniteTimeSpan, retryValues[0]); + + Assert.AreEqual("test", events[1].Data); + Assert.AreEqual(new TimeSpan(0, 0, 0, 2, 500), retryValues[1]); + + // Ignores invalid values + Assert.AreEqual("test", events[2].Data); + Assert.AreEqual(new TimeSpan(0, 0, 0, 2, 500), retryValues[2]); + } + + [Test] + public void ThrowsIfCancelled() + { + CancellationToken token = new(true); + + using Stream contentStream = BinaryData.FromString(MockSseClient.DefaultMockContent).ToStream(); + ServerSentEventReader reader = new(contentStream); + + Assert.ThrowsAsync(async () + => await reader.TryGetNextEventAsync(token)); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs new file mode 100644 index 0000000000000..2113d2acb0474 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClient.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Mocks; + +namespace ClientModel.Tests.Internal.Mocks; + +// Note: keeping this mock client used to illustrate SSE usage patterns in +// Tests.Internal for now as it needs access to internal types. Once we are +// able to port to a solution that uses the public BCL SseParser type, this +// will no longer be needed. +public class MockSseClient +{ + // Note: raw string literal removes \n from final line. + internal const string DefaultMockContent = """ + event: event.0 + data: { "IntValue": 0, "StringValue": "0" } + + event: event.1 + data: { "IntValue": 1, "StringValue": "1" } + + event: event.2 + data: { "IntValue": 2, "StringValue": "2" } + + event: done + data: [DONE] + + + """; + + public bool ProtocolMethodCalled { get; private set; } + + // mock convenience method + public virtual AsyncResultCollection GetModelsStreamingAsync(string content = DefaultMockContent) + { + return new AsyncMockJsonModelCollection(content, GetModelsStreamingAsync); + } + + // mock protocol method + public virtual ClientResult GetModelsStreamingAsync(string content, RequestOptions? options = default) + { + // This mocks sending a request and returns a respose containing + // the passed-in content in the content stream. + + MockPipelineResponse response = new(); + response.SetContent(content); + + ProtocolMethodCalled = true; + + return ClientResult.FromResponse(response); + } + + // Internal client implementation of convenience-layer AsyncResultCollection. + // This currently layers over an internal AsyncResultCollection + // representing the event.data values, but does not strictly have to. + private class AsyncMockJsonModelCollection : AsyncResultCollection + { + private readonly string _content; + private readonly Func _protocolMethod; + + public AsyncMockJsonModelCollection(string content, Func protocolMethod) + { + _content = content; + _protocolMethod = protocolMethod; + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + async Task getResultAsync() + { + await Task.Delay(0, cancellationToken); + return _protocolMethod(_content, /*options:*/ default); + } + + return new AsyncMockJsonModelEnumerator(getResultAsync, this, cancellationToken); + } + + private sealed class AsyncMockJsonModelEnumerator : IAsyncEnumerator + { + private const string _terminalData = "[DONE]"; + + private readonly Func> _getResultAsync; + private readonly AsyncMockJsonModelCollection _enumerable; + private readonly CancellationToken _cancellationToken; + + private IAsyncEnumerator? _events; + private MockJsonModel? _current; + + private bool _started; + + public AsyncMockJsonModelEnumerator(Func> getResultAsync, AsyncMockJsonModelCollection enumerable, CancellationToken cancellationToken) + { + Debug.Assert(getResultAsync is not null); + Debug.Assert(enumerable is not null); + + _getResultAsync = getResultAsync!; + _enumerable = enumerable!; + _cancellationToken = cancellationToken; + } + + MockJsonModel IAsyncEnumerator.Current + => _current!; + + async ValueTask IAsyncEnumerator.MoveNextAsync() + { + if (_events is null && _started) + { + throw new ObjectDisposedException(nameof(AsyncMockJsonModelEnumerator)); + } + + _cancellationToken.ThrowIfCancellationRequested(); + _events ??= await CreateEventEnumeratorAsync().ConfigureAwait(false); + _started = true; + + if (await _events.MoveNextAsync().ConfigureAwait(false)) + { + if (_events.Current.Data == _terminalData) + { + _current = default; + return false; + } + + BinaryData data = BinaryData.FromString(_events.Current.Data); + MockJsonModel model = ModelReaderWriter.Read(data) ?? + throw new JsonException($"Failed to deserialize expected type MockJsonModel from sse data payload '{_events.Current.Data}'."); + + _current = model; + return true; + } + + _current = default; + return false; + } + + private async Task> CreateEventEnumeratorAsync() + { + ClientResult result = await _getResultAsync().ConfigureAwait(false); + PipelineResponse response = result.GetRawResponse(); + _enumerable.SetRawResponse(response); + + if (response.ContentStream is null) + { + throw new ArgumentException("Unable to create result from response with null ContentStream", nameof(response)); + } + + AsyncServerSentEventEnumerable enumerable = new(response.ContentStream); + return enumerable.GetAsyncEnumerator(_cancellationToken); + } + + public async ValueTask DisposeAsync() + { + await DisposeAsyncCore().ConfigureAwait(false); + + GC.SuppressFinalize(this); + } + + private async ValueTask DisposeAsyncCore() + { + if (_events is not null) + { + // Disposing the sse enumerator should be a no-op. + await _events.DisposeAsync().ConfigureAwait(false); + _events = null; + + // But we also need to dispose the response content stream + // so we don't leave the unbuffered network stream open. + PipelineResponse response = _enumerable.GetRawResponse(); + + if (response.ContentStream is IAsyncDisposable asyncDisposable) + { + await asyncDisposable.DisposeAsync().ConfigureAwait(false); + } + else if (response.ContentStream is IDisposable disposable) + { + disposable.Dispose(); + } + } + } + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs new file mode 100644 index 0000000000000..00ebadbe6e60f --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSseClientExtensions.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel; +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Internal.Mocks; + +public static class MockSseClientExtensions +{ + public static AsyncResultCollection EnumerateDataEvents(this PipelineResponse response) + { + if (response.ContentStream is null) + { + throw new ArgumentException("Unable to create result collection from PipelineResponse with null ContentStream", nameof(response)); + } + + return new AsyncSseDataEventCollection(response, "[DONE]"); + } + + private class AsyncSseDataEventCollection : AsyncResultCollection + { + private readonly string _terminalData; + + public AsyncSseDataEventCollection(PipelineResponse response, string terminalData) : base(response) + { + Argument.AssertNotNull(response, nameof(response)); + + _terminalData = terminalData; + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + PipelineResponse response = GetRawResponse(); + + // We validate that response.ContentStream is non-null in outer extension method. + Debug.Assert(response.ContentStream is not null); + + return new AsyncSseDataEventEnumerator(response.ContentStream!, _terminalData, cancellationToken); + } + + private sealed class AsyncSseDataEventEnumerator : IAsyncEnumerator + { + private readonly string _terminalData; + + private IAsyncEnumerator? _events; + private BinaryData? _current; + + public BinaryData Current { get => _current!; } + + public AsyncSseDataEventEnumerator(Stream contentStream, string terminalData, CancellationToken cancellationToken) + { + Debug.Assert(contentStream is not null); + + AsyncServerSentEventEnumerable enumerable = new(contentStream!); + _events = enumerable.GetAsyncEnumerator(cancellationToken); + + _terminalData = terminalData; + } + + public async ValueTask MoveNextAsync() + { + if (_events is null) + { + throw new ObjectDisposedException(nameof(AsyncSseDataEventEnumerator)); + } + + if (await _events.MoveNextAsync().ConfigureAwait(false)) + { + if (_events.Current.Data == _terminalData) + { + _current = default; + return false; + } + + _current = BinaryData.FromString(_events.Current.Data); + return true; + } + + _current = default; + return false; + } + + public async ValueTask DisposeAsync() + { + await DisposeAsyncCore().ConfigureAwait(false); + + GC.SuppressFinalize(this); + } + + private async ValueTask DisposeAsyncCore() + { + if (_events is not null) + { + await _events.DisposeAsync().ConfigureAwait(false); + _events = null; + } + } + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs new file mode 100644 index 0000000000000..8ed090f78a226 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/TestFramework/Mocks/MockSyncAsyncInternalExtensions.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Threading.Tasks; + +namespace ClientModel.Tests.Internal.Mocks; + +internal static class MockSyncAsyncInternalExtensions +{ + public static async Task TryGetNextEventSyncOrAsync(this ServerSentEventReader reader, bool isAsync) + { + if (isAsync) + { + return await reader.TryGetNextEventAsync().ConfigureAwait(false); + } + else + { + return reader.TryGetNextEvent(); + } + } +}