From 2ac59d966e3677f23d8536747497726eb9b0c4a9 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Thu, 11 Apr 2024 16:08:56 +0200 Subject: [PATCH] Add test for StreamTextResult.toAIStream. (#1328) --- .../core/generate-text/stream-text.test.ts | 28 +++++++++++++++++++ .../core/core/generate-text/stream-text.ts | 4 +-- .../test/convert-readable-stream-to-array.ts | 14 ++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 packages/core/core/test/convert-readable-stream-to-array.ts diff --git a/packages/core/core/generate-text/stream-text.test.ts b/packages/core/core/generate-text/stream-text.test.ts index b269482757a1..2ca7b2cb00cc 100644 --- a/packages/core/core/generate-text/stream-text.test.ts +++ b/packages/core/core/generate-text/stream-text.test.ts @@ -2,6 +2,7 @@ import assert from 'node:assert'; import { z } from 'zod'; import { convertArrayToReadableStream } from '../test/convert-array-to-readable-stream'; import { convertAsyncIterableToArray } from '../test/convert-async-iterable-to-array'; +import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array'; import { MockLanguageModelV1 } from '../test/mock-language-model-v1'; import { experimental_streamText } from './stream-text'; @@ -232,6 +233,33 @@ describe('result.fullStream', () => { }); }); +describe('result.toAIStream', () => { + it('should transform textStream through callbacks and data transformers', async () => { + const result = await experimental_streamText({ + model: new MockLanguageModelV1({ + doStream: async ({ prompt, mode }) => { + return { + stream: convertArrayToReadableStream([ + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'world!' }, + ]), + rawCall: { rawPrompt: 'prompt', rawSettings: {} }, + }; + }, + }), + prompt: 'test-input', + }); + + assert.deepStrictEqual( + await convertReadableStreamToArray( + result.toAIStream().pipeThrough(new TextDecoderStream()), + ), + ['0:"Hello"\n', '0:", "\n', '0:"world!"\n'], + ); + }); +}); + describe('result.toTextStreamResponse', () => { it('should create a Response with a text stream', async () => { const result = await experimental_streamText({ diff --git a/packages/core/core/generate-text/stream-text.ts b/packages/core/core/generate-text/stream-text.ts index 2cd70e0ac0e0..bc5c596c86ee 100644 --- a/packages/core/core/generate-text/stream-text.ts +++ b/packages/core/core/generate-text/stream-text.ts @@ -7,7 +7,6 @@ import { AIStreamCallbacksAndOptions, createCallbacksTransformer, createStreamDataTransformer, - readableFromAsyncIterable, } from '../../streams'; import { CallSettings } from '../prompt/call-settings'; import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt'; @@ -214,8 +213,7 @@ Stream callbacks that will be called when the stream emits events. @returns an `AIStream` object. */ toAIStream(callbacks?: AIStreamCallbacksAndOptions) { - // TODO add support for tool calls - return readableFromAsyncIterable(this.textStream) + return this.textStream .pipeThrough(createCallbacksTransformer(callbacks)) .pipeThrough(createStreamDataTransformer()); } diff --git a/packages/core/core/test/convert-readable-stream-to-array.ts b/packages/core/core/test/convert-readable-stream-to-array.ts new file mode 100644 index 000000000000..307e33179660 --- /dev/null +++ b/packages/core/core/test/convert-readable-stream-to-array.ts @@ -0,0 +1,14 @@ +export async function convertReadableStreamToArray( + stream: ReadableStream, +): Promise { + const reader = stream.getReader(); + const result: T[] = []; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + result.push(value); + } + + return result; +}