From 8dcdda2b0d1d86486eea5fd47d24a8d26fde4c19 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Thu, 4 Apr 2024 07:32:06 -0400 Subject: [PATCH] fix(streaming): handle special line characters and fix multi-byte character decoding (#757) --- src/streaming.ts | 120 +++++++++++++++----- tests/streaming.test.ts | 245 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 338 insertions(+), 27 deletions(-) diff --git a/src/streaming.ts b/src/streaming.ts index 6b0f2a345..722a8f69c 100644 --- a/src/streaming.ts +++ b/src/streaming.ts @@ -23,29 +23,6 @@ export class Stream implements AsyncIterable { static fromSSEResponse(response: Response, controller: AbortController) { let consumed = false; - const decoder = new SSEDecoder(); - - async function* iterMessages(): AsyncGenerator { - if (!response.body) { - controller.abort(); - throw new OpenAIError(`Attempted to iterate over a response with no body`); - } - - const lineDecoder = new LineDecoder(); - - const iter = readableStreamAsyncIterable(response.body); - for await (const chunk of iter) { - for (const line of lineDecoder.decode(chunk)) { - const sse = decoder.decode(line); - if (sse) yield sse; - } - } - - for (const line of lineDecoder.flush()) { - const sse = decoder.decode(line); - if (sse) yield sse; - } - } async function* iterator(): AsyncIterator { if (consumed) { @@ -54,7 +31,7 @@ export class Stream implements AsyncIterable { consumed = true; let done = false; try { - for await (const sse of iterMessages()) { + for await (const sse of _iterSSEMessages(response, controller)) { if (done) continue; if (sse.data.startsWith('[DONE]')) { @@ -220,6 +197,97 @@ export class Stream implements AsyncIterable { } } +export async function* _iterSSEMessages( + response: Response, + controller: AbortController, +): AsyncGenerator { + if (!response.body) { + controller.abort(); + throw new OpenAIError(`Attempted to iterate over a response with no body`); + } + + const sseDecoder = new SSEDecoder(); + const lineDecoder = new LineDecoder(); + + const iter = readableStreamAsyncIterable(response.body); + for await (const sseChunk of iterSSEChunks(iter)) { + for (const line of lineDecoder.decode(sseChunk)) { + const sse = sseDecoder.decode(line); + if (sse) yield sse; + } + } + + for (const line of lineDecoder.flush()) { + const sse = sseDecoder.decode(line); + if (sse) yield sse; + } +} + +/** + * Given an async iterable iterator, iterates over it and yields full + * SSE chunks, i.e. yields when a double new-line is encountered. + */ +async function* iterSSEChunks(iterator: AsyncIterableIterator): AsyncGenerator { + let data = new Uint8Array(); + + for await (const chunk of iterator) { + if (chunk == null) { + continue; + } + + const binaryChunk = + chunk instanceof ArrayBuffer ? new Uint8Array(chunk) + : typeof chunk === 'string' ? new TextEncoder().encode(chunk) + : chunk; + + let newData = new Uint8Array(data.length + binaryChunk.length); + newData.set(data); + newData.set(binaryChunk, data.length); + data = newData; + + let patternIndex; + while ((patternIndex = findDoubleNewlineIndex(data)) !== -1) { + yield data.slice(0, patternIndex); + data = data.slice(patternIndex); + } + } + + if (data.length > 0) { + yield data; + } +} + +function findDoubleNewlineIndex(buffer: Uint8Array): number { + // This function searches the buffer for the end patterns (\r\r, \n\n, \r\n\r\n) + // and returns the index right after the first occurrence of any pattern, + // or -1 if none of the patterns are found. + const newline = 0x0a; // \n + const carriage = 0x0d; // \r + + for (let i = 0; i < buffer.length - 2; i++) { + if (buffer[i] === newline && buffer[i + 1] === newline) { + // \n\n + return i + 2; + } + if (buffer[i] === carriage && buffer[i + 1] === carriage) { + // \r\r + return i + 2; + } + if ( + buffer[i] === carriage && + buffer[i + 1] === newline && + i + 3 < buffer.length && + buffer[i + 2] === carriage && + buffer[i + 3] === newline + ) { + // \r\n\r\n + return i + 4; + } + } + + return -1; +} + class SSEDecoder { private data: string[]; private event: string | null; @@ -283,8 +351,8 @@ class SSEDecoder { */ class LineDecoder { // prettier-ignore - static NEWLINE_CHARS = new Set(['\n', '\r', '\x0b', '\x0c', '\x1c', '\x1d', '\x1e', '\x85', '\u2028', '\u2029']); - static NEWLINE_REGEXP = /\r\n|[\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029]/g; + static NEWLINE_CHARS = new Set(['\n', '\r']); + static NEWLINE_REGEXP = /\r\n|[\n\r]/g; buffer: string[]; trailingCR: boolean; diff --git a/tests/streaming.test.ts b/tests/streaming.test.ts index 479b2a341..6fe9a5781 100644 --- a/tests/streaming.test.ts +++ b/tests/streaming.test.ts @@ -1,4 +1,7 @@ -import { _decodeChunks as decodeChunks } from 'openai/streaming'; +import { Response } from 'node-fetch'; +import { PassThrough } from 'stream'; +import assert from 'assert'; +import { _iterSSEMessages, _decodeChunks as decodeChunks } from 'openai/streaming'; describe('line decoder', () => { test('basic', () => { @@ -27,3 +30,243 @@ describe('line decoder', () => { expect(decodeChunks(['foo', ' bar\\r\\nbaz\n'])).toEqual(['foo bar\\r\\nbaz']); }); }); + +describe('streaming decoding', () => { + test('basic', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: completion\n'); + yield Buffer.from('data: {"foo":true}\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(JSON.parse(event.value.data)).toEqual({ foo: true }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('data without event', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('data: {"foo":true}\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toBeNull(); + expect(JSON.parse(event.value.data)).toEqual({ foo: true }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('event without data', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: foo\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('foo'); + expect(event.value.data).toEqual(''); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('multiple events', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: foo\n'); + yield Buffer.from('\n'); + yield Buffer.from('event: ping\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('foo'); + expect(event.value.data).toEqual(''); + + event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('ping'); + expect(event.value.data).toEqual(''); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('multiple events with data', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: foo\n'); + yield Buffer.from('data: {"foo":true}\n'); + yield Buffer.from('\n'); + yield Buffer.from('event: ping\n'); + yield Buffer.from('data: {"bar":false}\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('foo'); + expect(JSON.parse(event.value.data)).toEqual({ foo: true }); + + event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('ping'); + expect(JSON.parse(event.value.data)).toEqual({ bar: false }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('multiple data lines with empty line', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: ping\n'); + yield Buffer.from('data: {\n'); + yield Buffer.from('data: "foo":\n'); + yield Buffer.from('data: \n'); + yield Buffer.from('data:\n'); + yield Buffer.from('data: true}\n'); + yield Buffer.from('\n\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('ping'); + expect(JSON.parse(event.value.data)).toEqual({ foo: true }); + expect(event.value.data).toEqual('{\n"foo":\n\n\ntrue}'); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('data json escaped double new line', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: ping\n'); + yield Buffer.from('data: {"foo": "my long\\n\\ncontent"}'); + yield Buffer.from('\n\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('ping'); + expect(JSON.parse(event.value.data)).toEqual({ foo: 'my long\n\ncontent' }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('special new line characters', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('data: {"content": "culpa "}\n'); + yield Buffer.from('\n'); + yield Buffer.from('data: {"content": "'); + yield Buffer.from([0xe2, 0x80, 0xa8]); + yield Buffer.from('"}\n'); + yield Buffer.from('\n'); + yield Buffer.from('data: {"content": "foo"}\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(JSON.parse(event.value.data)).toEqual({ content: 'culpa ' }); + + event = await stream.next(); + assert(event.value); + expect(JSON.parse(event.value.data)).toEqual({ content: Buffer.from([0xe2, 0x80, 0xa8]).toString() }); + + event = await stream.next(); + assert(event.value); + expect(JSON.parse(event.value.data)).toEqual({ content: 'foo' }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); + + test('multi-byte characters across chunks', async () => { + async function* body(): AsyncGenerator { + yield Buffer.from('event: completion\n'); + yield Buffer.from('data: {"content": "'); + // bytes taken from the string 'известни' and arbitrarily split + // so that some multi-byte characters span multiple chunks + yield Buffer.from([0xd0]); + yield Buffer.from([0xb8, 0xd0, 0xb7, 0xd0]); + yield Buffer.from([0xb2, 0xd0, 0xb5, 0xd1, 0x81, 0xd1, 0x82, 0xd0, 0xbd, 0xd0, 0xb8]); + yield Buffer.from('"}\n'); + yield Buffer.from('\n'); + } + + const stream = _iterSSEMessages(new Response(await iteratorToStream(body())), new AbortController())[ + Symbol.asyncIterator + ](); + + let event = await stream.next(); + assert(event.value); + expect(event.value.event).toEqual('completion'); + expect(JSON.parse(event.value.data)).toEqual({ content: 'известни' }); + + event = await stream.next(); + expect(event.done).toBeTruthy(); + }); +}); + +async function iteratorToStream(iterator: AsyncGenerator): Promise { + const parts: unknown[] = []; + + for await (const chunk of iterator) { + parts.push(chunk); + } + + let index = 0; + + const stream = new PassThrough({ + read() { + const value = parts[index]; + if (value === undefined) { + stream.end(); + } else { + index += 1; + stream.write(value); + } + }, + }); + + return stream; +}