Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return SCM collection abstractions from Assistant streaming convenience methods #104

Merged
merged 35 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a42df2f
Initial revisions; WIP
annelo-msft May 3, 2024
700aaab
add test that sketches out protocol SSE usage.
annelo-msft May 3, 2024
acaa719
updates
annelo-msft May 3, 2024
b56683e
Merge remote-tracking branch 'joseharriaga/main' into oai-sse-protoco…
annelo-msft May 10, 2024
d4788d8
initial rework
annelo-msft May 10, 2024
608abbc
tidy
annelo-msft May 10, 2024
0dd7b44
more reshuffle
annelo-msft May 10, 2024
01ebf69
add SSE extension method stub to AOAI
annelo-msft May 10, 2024
6b52cb7
tidy
annelo-msft May 10, 2024
abd7002
WIP - intial impl of Async chat update coll
annelo-msft May 10, 2024
99a8db2
extract response content if not buffered
annelo-msft May 10, 2024
bbcf6d6
implement sync collection and don't dispose response if not buffered
annelo-msft May 10, 2024
cb3660e
tidy
annelo-msft May 10, 2024
86cf24c
fix
annelo-msft May 10, 2024
d1cf5d5
Add AOAI streaming chat test
annelo-msft May 11, 2024
f464e61
add streaming chat test to AOAI
annelo-msft May 11, 2024
98b8b01
add streaming chat test to AOAI
annelo-msft May 11, 2024
4fa995f
add streaming chat test to AOAI
annelo-msft May 11, 2024
f57aede
Merge remote-tracking branch 'joseharriaga/main' into oai-sse-protoco…
annelo-msft May 14, 2024
3c988fc
rm AOAI bits
annelo-msft May 14, 2024
6e7c795
updates from SCM
annelo-msft May 14, 2024
9f3516d
remove temp workaround
annelo-msft May 14, 2024
5c1ce17
add mock test for tool call chat updates
annelo-msft May 14, 2024
ef89eb3
comment out strong-naming and internals test
annelo-msft May 14, 2024
9144349
update to use alpha SCM package with collection types
annelo-msft May 14, 2024
febc14c
nits
annelo-msft May 14, 2024
5f40dea
Merge remote-tracking branch 'joseharriaga/main' into oai-sse-protoco…
annelo-msft May 17, 2024
887492d
post merge conflict resolution
annelo-msft May 17, 2024
a82fd43
WIP
annelo-msft May 17, 2024
cab2332
make it build
annelo-msft May 17, 2024
82a0fab
remove temp methods
annelo-msft May 17, 2024
6a406f8
revert chat bits from this PR
annelo-msft May 17, 2024
4f19434
revert more unneeded support files
annelo-msft May 17, 2024
72d70d1
nits
annelo-msft May 17, 2024
7439b70
nits
annelo-msft May 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .dotnet/src/Custom/Assistants/AssistantClient.Convenience.cs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public virtual ClientResult<ThreadRun> CreateRun(AssistantThread thread, Assista
/// <param name="thread"> The thread that the run should evaluate. </param>
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateRunStreamingAsync(
AssistantThread thread,
Assistant assistant,
RunCreationOptions options = null)
Expand All @@ -249,7 +249,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunSt
/// <param name="thread"> The thread that the run should evaluate. </param>
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateRunStreaming(
AssistantThread thread,
Assistant assistant,
RunCreationOptions options = null)
Expand Down Expand Up @@ -287,7 +287,7 @@ public virtual ClientResult<ThreadRun> CreateThreadAndRun(
/// <param name="assistant"> The assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThreadAndRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateThreadAndRunStreamingAsync(
Assistant assistant,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
Expand All @@ -299,7 +299,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThrea
/// <param name="assistant"> The assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateThreadAndRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateThreadAndRunStreaming(
Assistant assistant,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
Expand Down Expand Up @@ -390,7 +390,7 @@ public virtual ClientResult<ThreadRun> SubmitToolOutputsToRun(
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolOutputsToRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreamingAsync(
ThreadRun run,
IEnumerable<ToolOutput> toolOutputs)
=> SubmitToolOutputsToRunStreamingAsync(run?.ThreadId, run?.Id, toolOutputs);
Expand All @@ -402,7 +402,7 @@ public virtual Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolO
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputsToRunStreaming(
public virtual ResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreaming(
ThreadRun run,
IEnumerable<ToolOutput> toolOutputs)
=> SubmitToolOutputsToRunStreaming(run?.ThreadId, run?.Id, toolOutputs);
Expand Down
53 changes: 33 additions & 20 deletions .dotnet/src/Custom/Assistants/AssistantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -491,21 +491,23 @@ public virtual ClientResult<ThreadRun> CreateRun(string threadId, string assista
/// <param name="threadId"> The ID of the thread that the run should evaluate. </param>
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateRunStreamingAsync(
string threadId,
string assistantId,
RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

options ??= new();
options.AssistantId = assistantId;
options.Stream = true;

ClientResult protocolResult = await CreateRunAsync(threadId, options.ToBinaryContent(), StreamRequestOptions)
async Task<ClientResult> getResultAsync() =>
await CreateRunAsync(threadId, options.ToBinaryContent(), StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -515,20 +517,21 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Creat
/// <param name="threadId"> The ID of the thread that the run should evaluate. </param>
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateRunStreaming(
string threadId,
string assistantId,
RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

options ??= new();
options.AssistantId = assistantId;
options.Stream = true;

ClientResult protocolResult = CreateRun(threadId, options.ToBinaryContent(), StreamRequestOptions);
ClientResult getResult() => CreateRun(threadId, options.ToBinaryContent(), StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -575,18 +578,22 @@ public virtual ClientResult<ThreadRun> CreateThreadAndRun(
/// <param name="assistantId"> The ID of the assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> CreateThreadAndRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> CreateThreadAndRunStreamingAsync(
string assistantId,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

runOptions ??= new();
runOptions.Stream = true;
BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions);
ClientResult protocolResult = await CreateThreadAndRunAsync(protocolContent, StreamRequestOptions)

async Task<ClientResult> getResultAsync() =>
await CreateThreadAndRunAsync(protocolContent, StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -595,17 +602,20 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Creat
/// <param name="assistantId"> The ID of the assistant that the new run should use. </param>
/// <param name="threadOptions"> Options for the new thread that will be created. </param>
/// <param name="runOptions"> Additional options to apply to the run that will begin. </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> CreateThreadAndRunStreaming(
public virtual ResultCollection<StreamingUpdate> CreateThreadAndRunStreaming(
string assistantId,
ThreadCreationOptions threadOptions = null,
RunCreationOptions runOptions = null)
{
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));

runOptions ??= new();
runOptions.Stream = true;
BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions);
ClientResult protocolResult = CreateThreadAndRun(protocolContent, StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
ClientResult getResult() => CreateThreadAndRun(protocolContent, StreamRequestOptions);

return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -729,7 +739,7 @@ public virtual ClientResult<ThreadRun> SubmitToolOutputsToRun(
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> SubmitToolOutputsToRunStreamingAsync(
public virtual AsyncResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreamingAsync(
string threadId,
string runId,
IEnumerable<ToolOutput> toolOutputs)
Expand All @@ -739,10 +749,12 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Submi

BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null)
.ToBinaryContent();
ClientResult protocolResult = await SubmitToolOutputsToRunAsync(threadId, runId, content, StreamRequestOptions)

async Task<ClientResult> getResultAsync() =>
await SubmitToolOutputsToRunAsync(threadId, runId, content, StreamRequestOptions)
.ConfigureAwait(false);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
return new AsyncStreamingUpdateCollection(getResultAsync);
}

/// <summary>
Expand All @@ -753,7 +765,7 @@ public virtual async Task<ClientResult<IAsyncEnumerable<StreamingUpdate>>> Submi
/// <param name="toolOutputs">
/// The tool outputs, corresponding to <see cref="InternalRequiredToolCall"/> instances from the run.
/// </param>
public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputsToRunStreaming(
public virtual ResultCollection<StreamingUpdate> SubmitToolOutputsToRunStreaming(
string threadId,
string runId,
IEnumerable<ToolOutput> toolOutputs)
Expand All @@ -763,9 +775,10 @@ public virtual ClientResult<IAsyncEnumerable<StreamingUpdate>> SubmitToolOutputs

BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null)
.ToBinaryContent();
ClientResult protocolResult = SubmitToolOutputsToRun(threadId, runId, content, StreamRequestOptions);

return StreamingUpdate.CreateTemporaryResult(protocolResult);
ClientResult getResult() => SubmitToolOutputsToRun(threadId, runId, content, StreamRequestOptions);

return new StreamingUpdateCollection(getResult);
}

/// <summary>
Expand Down Expand Up @@ -903,6 +916,6 @@ private static ClientResult<T> CreateResultFromProtocol<T>(ClientResult protocol
return ClientResult.FromValue(deserializedResultValue, pipelineResponse);
}

private RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private RequestOptions _streamRequestOptions;
private static RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private static RequestOptions _streamRequestOptions;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace OpenAI.Assistants;

/// <summary>
/// Implementation of collection abstraction over streaming assistant updates.
/// </summary>
internal class AsyncStreamingUpdateCollection : AsyncResultCollection<StreamingUpdate>
{
private readonly Func<Task<ClientResult>> _getResultAsync;

public AsyncStreamingUpdateCollection(Func<Task<ClientResult>> getResultAsync) : base()
{
Argument.AssertNotNull(getResultAsync, nameof(getResultAsync));

_getResultAsync = getResultAsync;
}

public override IAsyncEnumerator<StreamingUpdate> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new AsyncStreamingUpdateEnumerator(_getResultAsync, this, cancellationToken);
}

private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingUpdateCollection _enumerable;
private readonly CancellationToken _cancellationToken;

// These enumerators represent what is effectively a doubly-nested
// loop over the outer event collection and the inner update collection,
// i.e.:
// foreach (var sse in _events) {
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
private bool _started;

public AsyncStreamingUpdateEnumerator(Func<Task<ClientResult>> getResultAsync,
AsyncStreamingUpdateCollection enumerable,
CancellationToken cancellationToken)
{
Debug.Assert(getResultAsync is not null);
Debug.Assert(enumerable is not null);

_getResultAsync = getResultAsync!;
_enumerable = enumerable!;
_cancellationToken = cancellationToken;
}

StreamingUpdate IAsyncEnumerator<StreamingUpdate>.Current
=> _current!;

async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()
{
if (_events is null && _started)
{
throw new ObjectDisposedException(nameof(AsyncStreamingUpdateEnumerator));
}

_cancellationToken.ThrowIfCancellationRequested();
_events ??= await CreateEventEnumeratorAsync().ConfigureAwait(false);
_started = true;

if (_updates is not null && _updates.MoveNext())
{
_current = _updates.Current;
return true;
}

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
{
_current = default;
return false;
}

var updates = StreamingUpdate.FromEvent(_events.Current);
_updates = updates.GetEnumerator();

if (_updates.MoveNext())
{
_current = _updates.Current;
return true;
}
}

_current = default;
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
_enumerable.SetRawResponse(response);

if (response.ContentStream is null)
{
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

public async ValueTask DisposeAsync()
{
await DisposeAsyncCore().ConfigureAwait(false);

GC.SuppressFinalize(this);
}

private async ValueTask DisposeAsyncCore()
{
if (_events is not null)
{
await _events.DisposeAsync().ConfigureAwait(false);
_events = null;

// Dispose the response so we don't leave the unbuffered
// network stream open.
PipelineResponse response = _enumerable.GetRawResponse();
response.Dispose();
}
}
}
}
Loading
Loading