Skip to content

Commit

Permalink
Return SCM collection abstractions from Assistant streaming convenien…
Browse files Browse the repository at this point in the history
…ce methods (#104)

* Initial revisions; WIP

* add test that sketches out protocol SSE usage.

* updates

* initial rework

* tidy

* more reshuffle

* add SSE extension method stub to AOAI

* tidy

* WIP - intial impl of Async chat update coll

* extract response content if not buffered

* implement sync collection and don't dispose response if not buffered

* tidy

* fix

* Add AOAI streaming chat test

* add streaming chat test to AOAI

* add streaming chat test to AOAI

* add streaming chat test to AOAI

* rm AOAI bits

* updates from SCM

* remove temp workaround

* add mock test for tool call chat updates

* comment out strong-naming and internals test

* update to use alpha SCM package with collection types

* nits

* post merge conflict resolution

* WIP

* make it build

* remove temp methods

* revert chat bits from this PR

* revert more unneeded support files

* nits

* nits
  • Loading branch information
annelo-msft authored May 17, 2024
1 parent de80b8d commit aa7dee8
Show file tree
Hide file tree
Showing 14 changed files with 797 additions and 87 deletions.
12 changes: 6 additions & 6 deletions .dotnet/src/Custom/Assistants/AssistantClient.Convenience.cs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public virtual ClientResult<ThreadRun> CreateRun(AssistantThread thread, Assista
/// <param name="thread"> The thread that the run should evaluate. </param>
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateRunStreamingAsync(
AssistantThread thread,
Assistant assistant,
RunCreationOptions options = null)
Expand All @@ -249,7 +249,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunSt
/// <param name="thread"> The thread that the run should evaluate. </param>
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateRunStreaming(
AssistantThread thread,
Assistant assistant,
RunCreationOptions options = null)
Expand Down Expand Up @@ -287,7 +287,7 @@ public virtual ClientResult<ThreadRun> CreateThreadAndRun(
/// <param name="assistant"> The assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThreadAndRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateThreadAndRunStreamingAsync(
Assistant assistant,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
Expand All @@ -299,7 +299,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThrea
/// <param name="assistant"> The assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateThreadAndRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateThreadAndRunStreaming(
Assistant assistant,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
Expand Down Expand Up @@ -390,7 +390,7 @@ public virtual ClientResult<ThreadRun> SubmitToolOutputsToRun(
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolOutputsToRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreamingAsync(
ThreadRun run,
IEnumerable<ToolOutput> toolOutputs)
=> SubmitToolOutputsToRunStreamingAsync(run?.ThreadId, run?.Id, toolOutputs);
Expand All @@ -402,7 +402,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolO
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputsToRunStreaming(
public virtual ResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreaming(
ThreadRun run,
IEnumerable<ToolOutput> toolOutputs)
=> SubmitToolOutputsToRunStreaming(run?.ThreadId, run?.Id, toolOutputs);
Expand Down
53 changes: 33 additions & 20 deletions .dotnet/src/Custom/Assistants/AssistantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -491,21 +491,23 @@ public virtual ClientResult<ThreadRun> CreateRun(string threadId, string assista
/// <param name="threadId"> The ID of the thread that the run should evaluate. </param>
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateRunStreamingAsync(
string threadId,
string assistantId,
RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

options ??= new();
options.AssistantId = assistantId;
options.Stream = true;

ClientResult protocolResult = await CreateRunAsync(threadId, options.ToBinaryContent(), StreamRequestOptions)
async Task<ClientResult> getResultAsync() =>
await CreateRunAsync(threadId, options.ToBinaryContent(), StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -515,20 +517,21 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Creat
/// <param name="threadId"> The ID of the thread that the run should evaluate. </param>
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateRunStreaming(
string threadId,
string assistantId,
RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

options ??= new();
options.AssistantId = assistantId;
options.Stream = true;

ClientResult protocolResult = CreateRun(threadId, options.ToBinaryContent(), StreamRequestOptions);
ClientResult getResult() => CreateRun(threadId, options.ToBinaryContent(), StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -575,18 +578,22 @@ public virtual ClientResult<ThreadRun> CreateThreadAndRun(
/// <param name="assistantId"> The ID of the assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThreadAndRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateThreadAndRunStreamingAsync(
string assistantId,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

runOptions ??= new();
runOptions.Stream = true;
BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions);
ClientResult protocolResult = await CreateThreadAndRunAsync(protocolContent, StreamRequestOptions)

async Task<ClientResult> getResultAsync() =>
await CreateThreadAndRunAsync(protocolContent, StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -595,17 +602,20 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Creat
/// <param name="assistantId"> The ID of the assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateThreadAndRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateThreadAndRunStreaming(
string assistantId,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

runOptions ??= new();
runOptions.Stream = true;
BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions);
ClientResult protocolResult = CreateThreadAndRun(protocolContent, StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
ClientResult getResult() => CreateThreadAndRun(protocolContent, StreamRequestOptions);

return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -729,7 +739,7 @@ public virtual ClientResult<ThreadRun> SubmitToolOutputsToRun(
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolOutputsToRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreamingAsync(
string threadId,
string runId,
IEnumerable<ToolOutput> toolOutputs)
Expand All @@ -739,10 +749,12 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Submi

BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null)
.ToBinaryContent();
ClientResult protocolResult = await SubmitToolOutputsToRunAsync(threadId, runId, content, StreamRequestOptions)

async Task<ClientResult> getResultAsync() =>
await SubmitToolOutputsToRunAsync(threadId, runId, content, StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -753,7 +765,7 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Submi
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputsToRunStreaming(
public virtual ResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreaming(
string threadId,
string runId,
IEnumerable<ToolOutput> toolOutputs)
Expand All @@ -763,9 +775,10 @@ public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputs

BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null)
.ToBinaryContent();
ClientResult protocolResult = SubmitToolOutputsToRun(threadId, runId, content, StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
ClientResult getResult() => SubmitToolOutputsToRun(threadId, runId, content, StreamRequestOptions);

return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -903,6 +916,6 @@ private static ClientResult<T> CreateResultFromProtocol<T>(ClientResult protocol
return ClientResult.FromValue(deserializedResultValue, pipelineResponse);
}

private RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private RequestOptions _streamRequestOptions;
private static RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private static RequestOptions _streamRequestOptions;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace OpenAI.Assistants;

/// <summary>
/// Implementation of collection abstraction over streaming assistant updates.
/// </summary>
internal class AsyncStreamingUpdateCollection : AsyncResultCollection<StreamingUpdate>
{
private readonly Func<Task<ClientResult>> _getResultAsync;

public AsyncStreamingUpdateCollection(Func<Task<ClientResult>> getResultAsync) : base()
{
Argument.AssertNotNull(getResultAsync, nameof(getResultAsync));

_getResultAsync = getResultAsync;
}

public override IAsyncEnumerator<StreamingUpdate> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new AsyncStreamingUpdateEnumerator(_getResultAsync, this, cancellationToken);
}

private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingUpdateCollection _enumerable;
private readonly CancellationToken _cancellationToken;

// These enumerators represent what is effectively a doubly-nested
// loop over the outer event collection and the inner update collection,
// i.e.:
// foreach (var sse in _events) {
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
private bool _started;

public AsyncStreamingUpdateEnumerator(Func<Task<ClientResult>> getResultAsync,
AsyncStreamingUpdateCollection enumerable,
CancellationToken cancellationToken)
{
Debug.Assert(getResultAsync is not null);
Debug.Assert(enumerable is not null);

_getResultAsync = getResultAsync!;
_enumerable = enumerable!;
_cancellationToken = cancellationToken;
}

StreamingUpdate IAsyncEnumerator<StreamingUpdate>.Current
=> _current!;

async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()
{
if (_events is null && _started)
{
throw new ObjectDisposedException(nameof(AsyncStreamingUpdateEnumerator));
}

_cancellationToken.ThrowIfCancellationRequested();
_events ??= await CreateEventEnumeratorAsync().ConfigureAwait(false);
_started = true;

if (_updates is not null && _updates.MoveNext())
{
_current = _updates.Current;
return true;
}

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
{
_current = default;
return false;
}

var updates = StreamingUpdate.FromEvent(_events.Current);
_updates = updates.GetEnumerator();

if (_updates.MoveNext())
{
_current = _updates.Current;
return true;
}
}

_current = default;
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
_enumerable.SetRawResponse(response);

if (response.ContentStream is null)
{
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

public async ValueTask DisposeAsync()
{
await DisposeAsyncCore().ConfigureAwait(false);

GC.SuppressFinalize(this);
}

private async ValueTask DisposeAsyncCore()
{
if (_events is not null)
{
await _events.DisposeAsync().ConfigureAwait(false);
_events = null;

// Dispose the response so we don't leave the unbuffered
// network stream open.
PipelineResponse response = _enumerable.GetRawResponse();
response.Dispose();
}
}
}
}
Loading

0 comments on commit aa7dee8

Please sign in to comment.