Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Narrower update to StreamingClientResult/SSE #68

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions .dotnet/src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,12 @@ public virtual StreamingClientResult<StreamingChatUpdate> CompleteChatStreaming(
PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options);
requestMessage.BufferResponse = false;
Shim.Pipeline.Send(requestMessage);
PipelineResponse response = requestMessage.ExtractResponse();

if (response.IsError)
if (requestMessage.Response.IsError)
{
throw new ClientResultException(response);
throw new ClientResultException(requestMessage.Response);
}

ClientResult genericResult = ClientResult.FromResponse(response);
return StreamingClientResult<StreamingChatUpdate>.CreateFromResponse(
genericResult,
(responseForEnumeration) => SseAsyncEnumerator<StreamingChatUpdate>.EnumerateFromSseStream(
responseForEnumeration.GetRawResponse().ContentStream,
e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e)));
return StreamingClientResult<StreamingChatUpdate>
.Create<StreamingChatUpdate, StreamingChatUpdateCollection>(requestMessage.Response);
}

/// <summary>
Expand All @@ -238,19 +231,12 @@ public virtual async Task<StreamingClientResult<StreamingChatUpdate>> CompleteCh
PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options);
requestMessage.BufferResponse = false;
await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false);
PipelineResponse response = requestMessage.ExtractResponse();

if (response.IsError)
if (requestMessage.Response.IsError)
{
throw new ClientResultException(response);
throw new ClientResultException(requestMessage.Response);
}

ClientResult genericResult = ClientResult.FromResponse(response);
return StreamingClientResult<StreamingChatUpdate>.CreateFromResponse(
genericResult,
(responseForEnumeration) => SseAsyncEnumerator<StreamingChatUpdate>.EnumerateFromSseStream(
responseForEnumeration.GetRawResponse().ContentStream,
e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e)));
return StreamingClientResult<StreamingChatUpdate>
.Create<StreamingChatUpdate, StreamingChatUpdateCollection>(requestMessage.Response);
}

private Internal.Models.CreateChatCompletionRequest CreateInternalRequest(
Expand Down
152 changes: 0 additions & 152 deletions .dotnet/src/Custom/Chat/StreamingChatUpdate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -181,156 +181,4 @@ internal StreamingChatUpdate(
ToolCallUpdate = toolCallUpdate;
LogProbabilities = logProbabilities;
}

internal static List<StreamingChatUpdate> DeserializeStreamingChatUpdates(JsonElement element)
{
List<StreamingChatUpdate> results = [];
if (element.ValueKind == JsonValueKind.Null)
{
return results;
}
string id = default;
DateTimeOffset created = default;
string systemFingerprint = null;
foreach (JsonProperty property in element.EnumerateObject())
{
if (property.NameEquals("id"u8))
{
id = property.Value.GetString();
continue;
}
if (property.NameEquals("created"u8))
{
created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());
continue;
}
if (property.NameEquals("system_fingerprint"))
{
systemFingerprint = property.Value.GetString();
continue;
}
if (property.NameEquals("choices"u8))
{
foreach (JsonElement choiceElement in property.Value.EnumerateArray())
{
ChatRole? role = null;
string contentUpdate = null;
string functionName = null;
string functionArgumentsUpdate = null;
int choiceIndex = 0;
ChatFinishReason? finishReason = null;
List<StreamingToolCallUpdate> toolCallUpdates = [];
ChatLogProbabilityCollection logProbabilities = new([]);

foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject())
{
if (choiceProperty.NameEquals("index"u8))
{
choiceIndex = choiceProperty.Value.GetInt32();
continue;
}
if (choiceProperty.NameEquals("finish_reason"u8))
{
if (choiceProperty.Value.ValueKind == JsonValueKind.Null)
{
finishReason = null;
continue;
}
finishReason = choiceProperty.Value.GetString() switch
{
"stop" => ChatFinishReason.Stopped,
"length" => ChatFinishReason.Length,
"tool_calls" => ChatFinishReason.ToolCalls,
"function_call" => ChatFinishReason.FunctionCall,
"content_filter" => ChatFinishReason.ContentFilter,
_ => throw new ArgumentException(nameof(finishReason)),
};
continue;
}
if (choiceProperty.NameEquals("delta"u8))
{
foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject())
{
if (deltaProperty.NameEquals("role"u8))
{
role = deltaProperty.Value.GetString() switch
{
"system" => ChatRole.System,
"user" => ChatRole.User,
"assistant" => ChatRole.Assistant,
"tool" => ChatRole.Tool,
"function" => ChatRole.Function,
_ => throw new ArgumentException(nameof(role)),
};
continue;
}
if (deltaProperty.NameEquals("content"u8))
{
contentUpdate = deltaProperty.Value.GetString();
continue;
}
if (deltaProperty.NameEquals("function_call"u8))
{
foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject())
{
if (functionProperty.NameEquals("name"u8))
{
functionName = functionProperty.Value.GetString();
continue;
}
if (functionProperty.NameEquals("arguments"u8))
{
functionArgumentsUpdate = functionProperty.Value.GetString();
}
}
}
if (deltaProperty.NameEquals("tool_calls"))
{
foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray())
{
toolCallUpdates.Add(
StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement));
}
}
}
}
if (choiceProperty.NameEquals("logprobs"u8))
{
Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs
= Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs(
choiceProperty.Value);
logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs);
}
}
// In the unlikely event that more than one tool call arrives on a single chunk, we'll generate
// separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop.
if (toolCallUpdates.Count == 0)
{
toolCallUpdates.Add(null);
}
foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates)
{
results.Add(new StreamingChatUpdate(
id,
created,
systemFingerprint,
choiceIndex,
role,
contentUpdate,
finishReason,
functionName,
functionArgumentsUpdate,
toolCallUpdate,
logProbabilities));
}
}
continue;
}
}
if (results.Count == 0)
{
results.Add(new StreamingChatUpdate(id, created, systemFingerprint));
}
return results;
}
}
181 changes: 181 additions & 0 deletions .dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
using System;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Text.Json;

namespace OpenAI.Chat;

internal partial class StreamingChatUpdateCollection : IJsonModel<StreamingChatUpdateCollection>
{
StreamingChatUpdateCollection IJsonModel<StreamingChatUpdateCollection>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options)
=> ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, ref reader, options);

StreamingChatUpdateCollection IPersistableModel<StreamingChatUpdateCollection>.Create(BinaryData data, ModelReaderWriterOptions options)
=> ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, data, options);

void IJsonModel<StreamingChatUpdateCollection>.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options)
=> ModelSerializationHelpers.SerializeInstance<StreamingChatUpdateCollection, StreamingChatUpdateCollection>(this, SerializeStreamingChatUpdateCollections, writer, options);

BinaryData IPersistableModel<StreamingChatUpdateCollection>.Write(ModelReaderWriterOptions options)
=> ModelSerializationHelpers.SerializeInstance<StreamingChatUpdateCollection, StreamingChatUpdateCollection>(this, options);

string IPersistableModel<StreamingChatUpdateCollection>.GetFormatFromOptions(ModelReaderWriterOptions options) => "J";

internal static StreamingChatUpdateCollection DeserializeStreamingChatUpdateCollection(
JsonElement sseDataJson,
ModelReaderWriterOptions options = default)
{
List<StreamingChatUpdate> results = [];
if (sseDataJson.ValueKind == JsonValueKind.Null)
{
return new(results);
}
string id = default;
DateTimeOffset created = default;
string systemFingerprint = null;
foreach (JsonProperty property in sseDataJson.EnumerateObject())
{
if (property.NameEquals("id"u8))
{
id = property.Value.GetString();
continue;
}
if (property.NameEquals("created"u8))
{
created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());
continue;
}
if (property.NameEquals("system_fingerprint"))
{
systemFingerprint = property.Value.GetString();
continue;
}
if (property.NameEquals("choices"u8))
{
foreach (JsonElement choiceElement in property.Value.EnumerateArray())
{
ChatRole? role = null;
string contentUpdate = null;
string functionName = null;
string functionArgumentsUpdate = null;
int choiceIndex = 0;
ChatFinishReason? finishReason = null;
List<StreamingToolCallUpdate> toolCallUpdates = [];
ChatLogProbabilityCollection logProbabilities = new([]);

foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject())
{
if (choiceProperty.NameEquals("index"u8))
{
choiceIndex = choiceProperty.Value.GetInt32();
continue;
}
if (choiceProperty.NameEquals("finish_reason"u8))
{
if (choiceProperty.Value.ValueKind == JsonValueKind.Null)
{
finishReason = null;
continue;
}
finishReason = choiceProperty.Value.GetString() switch
{
"stop" => ChatFinishReason.Stopped,
"length" => ChatFinishReason.Length,
"tool_calls" => ChatFinishReason.ToolCalls,
"function_call" => ChatFinishReason.FunctionCall,
"content_filter" => ChatFinishReason.ContentFilter,
_ => throw new ArgumentException(nameof(finishReason)),
};
continue;
}
if (choiceProperty.NameEquals("delta"u8))
{
foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject())
{
if (deltaProperty.NameEquals("role"u8))
{
role = deltaProperty.Value.GetString() switch
{
"system" => ChatRole.System,
"user" => ChatRole.User,
"assistant" => ChatRole.Assistant,
"tool" => ChatRole.Tool,
"function" => ChatRole.Function,
_ => throw new ArgumentException(nameof(role)),
};
continue;
}
if (deltaProperty.NameEquals("content"u8))
{
contentUpdate = deltaProperty.Value.GetString();
continue;
}
if (deltaProperty.NameEquals("function_call"u8))
{
foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject())
{
if (functionProperty.NameEquals("name"u8))
{
functionName = functionProperty.Value.GetString();
continue;
}
if (functionProperty.NameEquals("arguments"u8))
{
functionArgumentsUpdate = functionProperty.Value.GetString();
}
}
}
if (deltaProperty.NameEquals("tool_calls"))
{
foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray())
{
toolCallUpdates.Add(
StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement));
}
}
}
}
if (choiceProperty.NameEquals("logprobs"u8))
{
Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs
= Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs(
choiceProperty.Value);
logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs);
}
}
// In the unlikely event that more than one tool call arrives on a single chunk, we'll generate
// separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop.
if (toolCallUpdates.Count == 0)
{
toolCallUpdates.Add(null);
}
foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates)
{
results.Add(new(
id,
created,
systemFingerprint,
choiceIndex,
role,
contentUpdate,
finishReason,
functionName,
functionArgumentsUpdate,
toolCallUpdate));
}
}
continue;
}
}
if (results.Count == 0)
{
results.Add(new(id, created, systemFingerprint));
}
return new(results);
}

internal static void SerializeStreamingChatUpdateCollections(StreamingChatUpdateCollection StreamingChatUpdateCollection, Utf8JsonWriter writer, ModelReaderWriterOptions options)
{
throw new NotSupportedException(nameof(StreamingChatUpdateCollection));
}
}
12 changes: 12 additions & 0 deletions .dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace OpenAI.Chat;

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Text.Json;

internal partial class StreamingChatUpdateCollection : ReadOnlyCollection<StreamingChatUpdate>
{
internal StreamingChatUpdateCollection() : this([]) { }
internal StreamingChatUpdateCollection(IList<StreamingChatUpdate> list) : base(list) { }
}
Loading
Loading