# [OpenAI] Re-implement SSEs (Azure#26704)
### Packages impacted by this PR
@azure/openai

### Issues associated with this PR
Azure#26376

### Describe the problem that is addressed by this PR
The old SSE implementation didn't correctly parse events that span multiple stream chunks. It converted every chunk to a string independently, which is not valid: a multi-byte UTF-8 character, or an event boundary, can straddle two chunks.
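As a minimal illustration (not code from this PR), decoding each chunk independently corrupts any multi-byte character that happens to be split across a chunk boundary, while a stateful decoder carries the partial byte sequence across chunks:

```ts
// Hypothetical illustration: "é" is two bytes in UTF-8, and a network
// chunk boundary can land between them.
const bytes = new TextEncoder().encode("data: café\n\n");
const chunk1 = bytes.slice(0, 10); // ends with the first byte of "é"
const chunk2 = bytes.slice(10);

// Decoding each chunk to a string independently corrupts the character:
console.log(new TextDecoder().decode(chunk1) + new TextDecoder().decode(chunk2));
// => "data: caf��\n\n" (replacement characters)

// A single stateful decoder handles the split correctly:
const decoder = new TextDecoder();
console.log(decoder.decode(chunk1, { stream: true }) + decoder.decode(chunk2));
// => "data: café\n\n"
```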

### What are the possible designs available to address the problem? If there is more than one possible design, why was the one in this PR chosen?
This implementation closely follows the event-stream interpretation algorithm in
https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
and is largely based on the parser in
[`@microsoft/fetch-event-source`](https://www.npmjs.com/package/@microsoft/fetch-event-source).
I think we should consider moving the implementation to core so it can be used
to interpret any response with a content type of `text/event-stream`;
I can file an issue after merging this PR.
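
For reference, here is a minimal sketch of the spec's line-interpretation rules (the `EventMessage` shape and helper name are illustrative, not this PR's exact code); it assumes the stream has already been decoded and split into lines:

```ts
interface EventMessage {
  event: string;
  data: string;
  id: string;
}

function makeLineInterpreter(onEvent: (e: EventMessage) => void) {
  let data = "";
  let eventType = "";
  let lastEventId = "";
  return function interpretLine(line: string): void {
    if (line === "") {
      // Blank line: dispatch the buffered event, unless data is empty.
      if (data !== "") {
        onEvent({ event: eventType, data: data.slice(0, -1), id: lastEventId });
      }
      data = "";
      eventType = "";
      return;
    }
    if (line.startsWith(":")) return; // comment line, ignored per the spec
    const colon = line.indexOf(":");
    const field = colon === -1 ? line : line.slice(0, colon);
    let value = colon === -1 ? "" : line.slice(colon + 1);
    if (value.startsWith(" ")) value = value.slice(1); // strip one leading space
    switch (field) {
      case "data":
        data += value + "\n"; // multiple data lines join with newlines
        break;
      case "event":
        eventType = value;
        break;
      case "id":
        if (!value.includes("\0")) lastEventId = value;
        break;
      // the spec also defines a numeric "retry" field, omitted here
    }
  };
}

// Usage sketch:
// const interpret = makeLineInterpreter(console.log);
// ["event: completion", 'data: {"text":"Hello"}', ""].forEach(interpret);
// => logs { event: "completion", data: '{"text":"Hello"}', id: "" }
```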

### Are there test cases added in this PR? _(If not, why?)_
Yes!

### Provide a list of related PRs _(if any)_
N/A

### Command used to generate this PR: _(Applicable only to SDK release request PRs)_

### Checklists
- [x] Added impacted package name to the issue description
- [ ] Does this PR need any fixes in the SDK Generator? _(If so, create an Issue in the [Autorest/typescript](https://github.com/Azure/autorest.typescript) repository and link it here)_
- [x] Added a changelog (if necessary)

---------

Co-authored-by: Jeff Fisher <[email protected]>
deyaaeldeen and xirzec authored Aug 8, 2023
1 parent adf8bd0 commit a84e331
Showing 30 changed files with 1,008 additions and 183 deletions.
2 changes: 2 additions & 0 deletions sdk/openai/openai/CHANGELOG.md
@@ -8,6 +8,8 @@

### Bugs Fixed

- Fix a bug where server-sent events were not being parsed correctly.

### Other Changes

## 1.0.0-beta.3 (2023-07-13)
8 changes: 3 additions & 5 deletions sdk/openai/openai/package.json
@@ -6,7 +6,7 @@
"main": "dist/index.cjs",
"module": "dist-esm/src/index.js",
"browser": {
"./dist-esm/src/api/getStream.js": "./dist-esm/src/api/getStream.browser.js"
"./dist-esm/src/api/getSSEs.js": "./dist-esm/src/api/getSSEs.browser.js"
},
"type": "module",
"exports": {
@@ -46,7 +46,7 @@
"extract-api": "tsc -p . && api-extractor run --local",
"format": "prettier --write --config ../../../.prettierrc.json --ignore-path ../../../.prettierignore \"sources/customizations/**/*.ts\" \"src/**/*.ts\" \"test/**/*.ts\" \"samples-dev/**/*.ts\" \"*.{js,json}\"",
"integration-test:browser": "npm run unit-test:browser",
"integration-test:node": "dev-tool run test:node-js-input -- --timeout 5000000 \"dist-esm/test/**/*.spec.js\"",
"integration-test:node": "dev-tool run test:node-js-input -- --timeout 5000000 \"dist-esm/test/public/{,!(browser)/**/}/*.spec.js\"",
"integration-test": "npm run integration-test:node && npm run integration-test:browser",
"lint:fix": "eslint README.md package.json api-extractor.json src test --ext .ts,.javascript,.js --fix --fix-type [problem,suggestion]",
"lint": "eslint README.md package.json api-extractor.json src test --ext .ts,.javascript,.js",
@@ -55,7 +55,7 @@
"test:node": "npm run clean && tsc -p . && npm run integration-test:node",
"test": "npm run clean && tsc -p . && npm run unit-test:node && dev-tool run bundle && npm run unit-test:browser && npm run integration-test",
"unit-test:browser": "dev-tool run test:browser -- karma.conf.cjs",
"unit-test:node": "dev-tool run test:node-ts-input -- \"test/internal/unit/{,!(browser)/**/}*.spec.ts\" \"test/public/{,!(browser)/**/}*.spec.ts\"",
"unit-test:node": "dev-tool run test:node-ts-input -- \"test/internal/{,!(browser)/**/}*.spec.ts\" \"test/public/{,!(browser)/**/}*.spec.ts\"",
"unit-test": "npm run unit-test:node && npm run unit-test:browser"
},
"files": [
@@ -94,14 +94,12 @@
"@azure-tools/test-credential": "^1.0.0",
"@azure/test-utils": "^1.0.0",
"@microsoft/api-extractor": "^7.31.1",
"@types/fs-extra": "^9.0.0",
"@types/mocha": "^7.0.2",
"@types/node": "^14.0.0",
"cross-env": "^7.0.3",
"dotenv": "^16.0.0",
"eslint": "^8.16.0",
"esm": "^3.2.25",
"fs-extra": "^10.0.0",
"karma": "^6.4.0",
"karma-chrome-launcher": "^3.1.1",
"karma-coverage": "^2.2.0",
4 changes: 2 additions & 2 deletions sdk/openai/openai/review/openai.api.md
@@ -201,8 +201,8 @@ export class OpenAIClient {
getCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): Promise<Completions>;
getEmbeddings(deploymentName: string, input: string[], options?: GetEmbeddingsOptions): Promise<Embeddings>;
getImages(prompt: string, options?: ImageGenerationOptions): Promise<ImageGenerationResponse>;
-listChatCompletions(deploymentName: string, messages: ChatMessage[], options?: GetChatCompletionsOptions): Promise<AsyncIterable<Omit<ChatCompletions, "usage">>>;
-listCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): Promise<AsyncIterable<Omit<Completions, "usage">>>;
+listChatCompletions(deploymentName: string, messages: ChatMessage[], options?: GetChatCompletionsOptions): AsyncIterable<Omit<ChatCompletions, "usage">>;
+listCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): AsyncIterable<Omit<Completions, "usage">>;
}

// @public (undocumented)
50 changes: 25 additions & 25 deletions sdk/openai/openai/sources/customizations/OpenAIClient.ts
@@ -16,7 +16,7 @@ import {
OpenAIClientOptions,
} from "../generated/api/index.js";
import { getChatCompletionsResult, getCompletionsResult } from "./api/operations.js";
-import { getSSEs } from "./api/sse.js";
+import { getOaiSSEs } from "./api/oaiSse.js";
import { ChatCompletions, Completions, Embeddings } from "../generated/api/models.js";
import { _getChatCompletionsSend, _getCompletionsSend } from "../generated/api/operations.js";
import { ImageGenerationOptions } from "./api/operations.js";
@@ -154,90 +154,90 @@ export class OpenAIClient {

/**
* Returns textual completions as configured for a given prompt.
-* @param deploymentOrModelName - Specifies either the model deployment name (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+* @param deploymentName - Specifies either the model deployment name (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
* @param prompt - The prompt to use for this request.
* @param options - The options for this completions request.
* @returns The completions for the given prompt.
*/
getCompletions(
-deploymentOrModelName: string,
+deploymentName: string,
prompt: string[],
options: GetCompletionsOptions = { requestOptions: {} }
): Promise<Completions> {
-this.setModel(deploymentOrModelName, options);
-return getCompletions(this._client, prompt, deploymentOrModelName, options);
+this.setModel(deploymentName, options);
+return getCompletions(this._client, prompt, deploymentName, options);
}

/**
* Lists the completions tokens as they become available for a given prompt.
-* @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+* @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
* @param prompt - The prompt to use for this request.
* @param options - The completions options for this completions request.
* @returns An asynchronous iterable of completions tokens.
*/
listCompletions(
-deploymentOrModelName: string,
+deploymentName: string,
prompt: string[],
options: GetCompletionsOptions = {}
-): Promise<AsyncIterable<Omit<Completions, "usage">>> {
-this.setModel(deploymentOrModelName, options);
-const response = _getCompletionsSend(this._client, prompt, deploymentOrModelName, {
+): AsyncIterable<Omit<Completions, "usage">> {
+this.setModel(deploymentName, options);
+const response = _getCompletionsSend(this._client, prompt, deploymentName, {
...options,
stream: true,
});
-return getSSEs(response, getCompletionsResult);
+return getOaiSSEs(response, getCompletionsResult);
}

/**
* Return the computed embeddings for a given prompt.
-* @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+* @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
* @param input - The prompt to use for this request.
* @param options - The embeddings options for this embeddings request.
* @returns The embeddings for the given prompt.
*/
getEmbeddings(
-deploymentOrModelName: string,
+deploymentName: string,
input: string[],
options: GetEmbeddingsOptions = { requestOptions: {} }
): Promise<Embeddings> {
-this.setModel(deploymentOrModelName, options);
-return getEmbeddings(this._client, input, deploymentOrModelName, options);
+this.setModel(deploymentName, options);
+return getEmbeddings(this._client, input, deploymentName, options);
}

/**
* Get chat completions for provided chat context messages.
-* @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+* @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
* @param messages - The chat context messages to use for this request.
* @param options - The chat completions options for this completions request.
* @returns The chat completions for the given chat context messages.
*/
getChatCompletions(
-deploymentOrModelName: string,
+deploymentName: string,
messages: ChatMessage[],
options: GetChatCompletionsOptions = { requestOptions: {} }
): Promise<ChatCompletions> {
-this.setModel(deploymentOrModelName, options);
-return getChatCompletions(this._client, messages, deploymentOrModelName, options);
+this.setModel(deploymentName, options);
+return getChatCompletions(this._client, messages, deploymentName, options);
}

/**
* Lists the chat completions tokens as they become available for a chat context.
-* @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+* @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
* @param messages - The chat context messages to use for this request.
* @param options - The chat completions options for this chat completions request.
* @returns An asynchronous iterable of chat completions tokens.
*/
listChatCompletions(
-deploymentOrModelName: string,
+deploymentName: string,
messages: ChatMessage[],
options: GetChatCompletionsOptions = { requestOptions: {} }
-): Promise<AsyncIterable<Omit<ChatCompletions, "usage">>> {
-this.setModel(deploymentOrModelName, options);
-const response = _getChatCompletionsSend(this._client, messages, deploymentOrModelName, {
+): AsyncIterable<Omit<ChatCompletions, "usage">> {
+this.setModel(deploymentName, options);
+const response = _getChatCompletionsSend(this._client, messages, deploymentName, {
...options,
stream: true,
});
-return getSSEs(response, getChatCompletionsResult);
+return getOaiSSEs(response, getChatCompletionsResult);
}

/**
35 changes: 35 additions & 0 deletions sdk/openai/openai/sources/customizations/api/getSSEs.browser.ts
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { StreamableMethod } from "@azure-rest/core-client";
import { EventMessage, toSSE } from "./sse.js";

async function* toAsyncIterable<T>(stream: ReadableStream<T>): AsyncIterable<T> {
const reader = stream.getReader();
try {
while (true) {
const { value, done } = await reader.read();
if (done) {
return;
}
yield value;
}
} finally {
reader.releaseLock();
}
}

async function getStream<TResponse>(
response: StreamableMethod<TResponse>
): Promise<AsyncIterable<Uint8Array>> {
const stream = (await response.asBrowserStream()).body;
if (!stream) throw new Error("No stream found in response. Did you enable the stream option?");
return toAsyncIterable(stream);
}

export async function getSSEs(
response: StreamableMethod<unknown>
): Promise<AsyncIterable<EventMessage>> {
const iter = await getStream(response);
return toSSE(iter);
}
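
Design note: `ReadableStream` isn't reliably async-iterable across browsers, so the `toAsyncIterable` helper above drives the stream through its reader manually and releases the reader's lock once iteration finishes or is abandoned.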
20 changes: 20 additions & 0 deletions sdk/openai/openai/sources/customizations/api/getSSEs.ts
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { StreamableMethod } from "@azure-rest/core-client";
import { EventMessage, toSSE } from "./sse.js";

async function getStream<TResponse>(
response: StreamableMethod<TResponse>
): Promise<AsyncIterable<Uint8Array>> {
const stream = (await response.asNodeStream()).body;
if (!stream) throw new Error("No stream found in response. Did you enable the stream option?");
return stream as AsyncIterable<Uint8Array>;
}

export async function getSSEs(
response: StreamableMethod<unknown>
): Promise<AsyncIterable<EventMessage>> {
const chunkIterator = await getStream(response);
return toSSE(chunkIterator);
}
22 changes: 0 additions & 22 deletions sdk/openai/openai/sources/customizations/api/getStream.browser.ts

This file was deleted.

14 changes: 0 additions & 14 deletions sdk/openai/openai/sources/customizations/api/getStream.ts

This file was deleted.

30 changes: 30 additions & 0 deletions sdk/openai/openai/sources/customizations/api/oaiSse.ts
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { StreamableMethod } from "@azure-rest/core-client";
import { getSSEs } from "./getSSEs.js";
import { wrapError } from "./util.js";

export async function* getOaiSSEs<TEvent>(
response: StreamableMethod<unknown>,
toEvent: (obj: Record<string, any>) => TEvent
): AsyncIterable<TEvent> {
const stream = await getSSEs(response);
let isDone = false;
for await (const event of stream) {
if (isDone) {
// handle a case where the service sends excess stream
// data after the [DONE] event
continue;
} else if (event.data === "[DONE]") {
isDone = true;
} else {
yield toEvent(
wrapError(
() => JSON.parse(event.data),
"Error parsing an event. See 'cause' for more details"
)
);
}
}
}
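
For context, here is how the streaming API reads after this change (a usage sketch based on the updated surface in `openai.api.md`; the endpoint, key, and deployment name are placeholders):

```ts
import { OpenAIClient, AzureKeyCredential } from "@azure/openai";

const client = new OpenAIClient(
  "https://<resource>.openai.azure.com",
  new AzureKeyCredential("<api-key>")
);

// listChatCompletions now returns an AsyncIterable directly, so no
// `await` is needed before iterating over the parsed events.
const events = client.listChatCompletions("<deployment-name>", [
  { role: "user", content: "Tell me a joke." },
]);
for await (const event of events) {
  for (const choice of event.choices) {
    process.stdout.write(choice.delta?.content ?? "");
  }
}
```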