diff --git a/js/ai/src/formats/array.ts b/js/ai/src/formats/array.ts
index 68981d960..2df3a4777 100644
--- a/js/ai/src/formats/array.ts
+++ b/js/ai/src/formats/array.ts
@@ -18,7 +18,7 @@ import { GenkitError } from '@genkit-ai/core';
 import { extractItems } from '../extract';
 import type { Formatter } from './types';

-export const arrayFormatter: Formatter = {
+export const arrayFormatter: Formatter = {
   name: 'array',
   config: {
     contentType: 'application/json',
@@ -43,16 +43,15 @@ export const arrayFormatter: Formatter = {
     }

     return {
-      parseChunk: (chunk, cursor = 0) => {
-        const { items, cursor: newCursor } = extractItems(
-          chunk.accumulatedText,
-          cursor
-        );
+      parseChunk: (chunk) => {
+        // first, determine the cursor position from the previous chunks
+        const cursor = chunk.previousChunks?.length
+          ? extractItems(chunk.previousText).cursor
+          : 0;
+        // then, extract the items starting at that cursor
+        const { items } = extractItems(chunk.accumulatedText, cursor);

-        return {
-          output: items,
-          cursor: newCursor,
-        };
+        return items;
       },

       parseResponse: (response) => {
diff --git a/js/ai/src/formats/index.ts b/js/ai/src/formats/index.ts
index 57c971c4b..40189e4ba 100644
--- a/js/ai/src/formats/index.ts
+++ b/js/ai/src/formats/index.ts
@@ -48,7 +48,7 @@ export async function resolveFormat(
   return arg as Formatter;
 }

-export const DEFAULT_FORMATS: Formatter[] = [
+export const DEFAULT_FORMATS: Formatter[] = [
   jsonFormatter,
   arrayFormatter,
   textFormatter,
diff --git a/js/ai/src/formats/json.ts b/js/ai/src/formats/json.ts
index 37dc78e88..61fe6e8c1 100644
--- a/js/ai/src/formats/json.ts
+++ b/js/ai/src/formats/json.ts
@@ -17,7 +17,7 @@
 import { extractJson } from '../extract';
 import type { Formatter } from './types';

-export const jsonFormatter: Formatter = {
+export const jsonFormatter: Formatter = {
   name: 'json',
   config: {
     contentType: 'application/json',
@@ -36,11 +36,8 @@ ${JSON.stringify(request.output!.schema!)}
     }

     return {
-      parseChunk: (chunk, cursor = '') => {
-        return {
-          output: extractJson(chunk.accumulatedText),
-          cursor: chunk.accumulatedText,
-        };
+      parseChunk: (chunk) => {
+        return extractJson(chunk.accumulatedText);
       },

       parseResponse: (response) => {
diff --git a/js/ai/src/formats/jsonl.ts b/js/ai/src/formats/jsonl.ts
index a3740a48f..d20898714 100644
--- a/js/ai/src/formats/jsonl.ts
+++ b/js/ai/src/formats/jsonl.ts
@@ -26,7 +26,7 @@ function objectLines(text: string): string[] {
     .filter((line) => line.startsWith('{'));
 }

-export const jsonlFormatter: Formatter = {
+export const jsonlFormatter: Formatter = {
   name: 'jsonl',
   config: {
     contentType: 'application/jsonl',
@@ -54,27 +54,36 @@ ${JSON.stringify(request.output.schema.items)}
     }

     return {
-      parseChunk: (chunk, cursor = 0) => {
-        const jsonLines = objectLines(chunk.accumulatedText);
+      parseChunk: (chunk) => {
         const results: unknown[] = [];
-        let newCursor = cursor;
-        for (let i = cursor; i < jsonLines.length; i++) {
-          try {
-            const result = JSON5.parse(jsonLines[i]);
-            if (result) {
-              results.push(result);
+        const text = chunk.accumulatedText;
+
+        let startIndex = 0;
+        if (chunk.previousChunks?.length) {
+          const lastNewline = chunk.previousText.lastIndexOf('\n');
+          if (lastNewline !== -1) {
+            startIndex = lastNewline + 1;
+          }
+        }
+
+        const lines = text.slice(startIndex).split('\n');
+
+        for (const line of lines) {
+          const trimmed = line.trim();
+          if (trimmed.startsWith('{')) {
+            try {
+              const result = JSON5.parse(trimmed);
+              if (result) {
+                results.push(result);
+              }
+            } catch (e) {
+              break;
             }
-            newCursor = i + 1;
-          } catch (e) {
-            break;
           }
         }

-        return {
-          output: results,
-          cursor: newCursor,
-        };
+        return results;
       },

       parseResponse: (response) => {
diff --git a/js/ai/src/formats/text.ts b/js/ai/src/formats/text.ts
index acdf994cc..985b64a96 100644
--- a/js/ai/src/formats/text.ts
+++ b/js/ai/src/formats/text.ts
@@ -24,9 +24,7 @@ export const textFormatter: Formatter = {
   handler: () => {
     return {
       parseChunk: (chunk) => {
-        return {
-          output: chunk.text,
-        };
+        return chunk.text;
       },

       parseResponse: (response) => {
diff --git a/js/ai/src/formats/types.d.ts b/js/ai/src/formats/types.d.ts
index a4beb3e91..e8847e572 100644
--- a/js/ai/src/formats/types.d.ts
+++ b/js/ai/src/formats/types.d.ts
@@ -17,30 +17,17 @@
 import { GenerateResponse, GenerateResponseChunk } from '../generate.js';
 import { ModelRequest, Part } from '../model.js';

-export interface ParsedChunk {
-  output: CO;
-  /**
-   * The cursor of a parsed chunk response holds context that is relevant to continue parsing.
-   * The returned cursor will be passed into the next iteration of the chunk parser. Cursors
-   * are not exposed to external consumers of the formatter.
-   */
-  cursor?: CC;
-}
-
 type OutputContentTypes =
   | 'application/json'
   | 'text/plain'
   | 'application/jsonl';

-export interface Formatter {
+export interface Formatter {
   name: string;
   config: ModelRequest['output'];
   handler: (req: ModelRequest) => {
     parseResponse(response: GenerateResponse): O;
-    parseChunk?: (
-      chunk: GenerateResponseChunk,
-      cursor?: CC
-    ) => ParsedChunk;
+    parseChunk?: (chunk: GenerateResponseChunk, cursor?: CC) => CO;
     instructions?: string | Part[];
   };
 }
diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts
index 0710aa30c..7a563c701 100644
--- a/js/ai/src/generate/action.ts
+++ b/js/ai/src/generate/action.ts
@@ -142,12 +142,14 @@ async function generate(
     streamingCallback
       ? (chunk: GenerateResponseChunkData) => {
           // Store accumulated chunk data
-          accumulatedChunks.push(chunk);
           if (streamingCallback) {
             streamingCallback!(
-              new GenerateResponseChunk(chunk, accumulatedChunks)
+              new GenerateResponseChunk(chunk, {
+                previousChunks: accumulatedChunks,
+              })
             );
           }
+          accumulatedChunks.push(chunk);
         }
       : undefined,
     async () => {
diff --git a/js/ai/src/generate/chunk.ts b/js/ai/src/generate/chunk.ts
index 21f0ea000..4cf20d67a 100644
--- a/js/ai/src/generate/chunk.ts
+++ b/js/ai/src/generate/chunk.ts
@@ -16,28 +16,50 @@

 import { GenkitError } from '@genkit-ai/core';
 import { extractJson } from '../extract.js';
-import { GenerateResponseChunkData, Part, ToolRequestPart } from '../model.js';
+import {
+  GenerateResponseChunkData,
+  Part,
+  Role,
+  ToolRequestPart,
+} from '../model.js';
+
+export interface ChunkParser<T = unknown> {
+  (chunk: GenerateResponseChunk<T>): T;
+}

 export class GenerateResponseChunk<T = unknown>
   implements GenerateResponseChunkData
 {
-  /** The index of the candidate this chunk corresponds to. */
+  /** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */
   index?: number;
+  /** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */
+  role?: Role;
   /** The content generated in this chunk. */
   content: Part[];
   /** Custom model-specific data for this chunk. */
   custom?: unknown;
   /** Accumulated chunks for partial output extraction. */
-  accumulatedChunks?: GenerateResponseChunkData[];
+  previousChunks?: GenerateResponseChunkData[];
+  /** The parser to be used to parse `output` from this chunk. */
+  parser?: ChunkParser<T>;

   constructor(
     data: GenerateResponseChunkData,
-    accumulatedChunks?: GenerateResponseChunkData[]
+    options?: {
+      previousChunks?: GenerateResponseChunkData[];
+      role?: Role;
+      index?: number;
+      parser?: ChunkParser<T>;
+    }
   ) {
-    this.index = data.index;
     this.content = data.content || [];
     this.custom = data.custom;
-    this.accumulatedChunks = accumulatedChunks;
+    this.previousChunks = options?.previousChunks
+      ? [...options.previousChunks]
+      : undefined;
+    this.index = options?.index;
+    this.role = options?.role;
+    this.parser = options?.parser;
   }

   /**
@@ -53,13 +75,20 @@
    * @returns A string of all concatenated chunk text content.
    */
   get accumulatedText(): string {
-    if (!this.accumulatedChunks)
+    return this.previousText + this.text;
+  }
+
+  /**
+   * Concatenates all `text` parts of all preceding chunks.
+   */
+  get previousText(): string {
+    if (!this.previousChunks)
       throw new GenkitError({
         status: 'FAILED_PRECONDITION',
-        message: 'Cannot compose accumulated text without accumulated chunks.',
+        message: 'Cannot compose accumulated text without previous chunks.',
       });

-    return this.accumulatedChunks
+    return this.previousChunks
       ?.map((c) => c.content.map((p) => p.text || '').join(''))
       .join('');
   }
@@ -92,18 +121,15 @@
   }

   /**
-   * Attempts to extract the longest valid JSON substring from the accumulated chunks.
-   * @returns The longest valid JSON substring found in the accumulated chunks.
+   * Parses the chunk into the desired output format using the parser associated
+   * with the generate request, or falls back to naive JSON parsing otherwise.
    */
   get output(): T | null {
-    if (!this.accumulatedChunks) return null;
-    const accumulatedText = this.accumulatedChunks
-      .map((chunk) => chunk.content.map((part) => part.text || '').join(''))
-      .join('');
-    return extractJson(accumulatedText, false);
+    if (this.parser) return this.parser(this);
+    return this.data || extractJson(this.accumulatedText);
   }

   toJSON(): GenerateResponseChunkData {
-    return { index: this.index, content: this.content, custom: this.custom };
+    return { content: this.content, custom: this.custom };
   }
 }
diff --git a/js/ai/tests/formats/array_test.ts b/js/ai/tests/formats/array_test.ts
index 3bc8bff14..f174d1cd0 100644
--- a/js/ai/tests/formats/array_test.ts
+++ b/js/ai/tests/formats/array_test.ts
@@ -73,15 +73,13 @@ describe('arrayFormat', () => {
         const newChunk: GenerateResponseChunkData = {
           content: [{ text: chunk.text }],
         };
-        chunks.push(newChunk);

         const result = parser.parseChunk!(
-          new GenerateResponseChunk(newChunk, chunks),
-          lastCursor
+          new GenerateResponseChunk(newChunk, { previousChunks: chunks })
         );
+        chunks.push(newChunk);

-        assert.deepStrictEqual(result.output, chunk.want);
-        lastCursor = result.cursor!;
+        assert.deepStrictEqual(result, chunk.want);
       }
     });
   }
diff --git a/js/ai/tests/formats/json_test.ts b/js/ai/tests/formats/json_test.ts
index edb5a8e79..833ca790d 100644
--- a/js/ai/tests/formats/json_test.ts
+++ b/js/ai/tests/formats/json_test.ts
@@ -69,15 +69,14 @@ describe('jsonFormat', () => {
         const newChunk: GenerateResponseChunkData = {
           content: [{ text: chunk.text }],
         };
-        chunks.push(newChunk);

         const result = parser.parseChunk!(
-          new GenerateResponseChunk(newChunk, chunks),
+          new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] }),
           lastCursor
         );
+        chunks.push(newChunk);

-        assert.deepStrictEqual(result.output, chunk.want);
-        lastCursor = result.cursor!;
+        assert.deepStrictEqual(result, chunk.want);
       }
     });
   }
diff --git a/js/ai/tests/formats/jsonl_test.ts b/js/ai/tests/formats/jsonl_test.ts
index f89c6d9a2..6d7fe138d 100644
--- a/js/ai/tests/formats/jsonl_test.ts
+++ b/js/ai/tests/formats/jsonl_test.ts
@@ -76,21 +76,18 @@ describe('jsonlFormat', () => {
     it(st.desc, () => {
       const parser = jsonlFormatter.handler({ messages: [] });
       const chunks: GenerateResponseChunkData[] = [];
-      let lastCursor = 0;

       for (const chunk of st.chunks) {
         const newChunk: GenerateResponseChunkData = {
           content: [{ text: chunk.text }],
         };
-        chunks.push(newChunk);

         const result = parser.parseChunk!(
-          new GenerateResponseChunk(newChunk, chunks),
-          lastCursor
+          new GenerateResponseChunk(newChunk, { previousChunks: chunks })
         );
+        chunks.push(newChunk);

-        assert.deepStrictEqual(result.output, chunk.want);
-        lastCursor = result.cursor!;
+        assert.deepStrictEqual(result, chunk.want);
       }
     });
   }
diff --git a/js/ai/tests/formats/text_test.ts b/js/ai/tests/formats/text_test.ts
index 607ccd220..cee7a719c 100644
--- a/js/ai/tests/formats/text_test.ts
+++ b/js/ai/tests/formats/text_test.ts
@@ -55,13 +55,13 @@ describe('textFormat', () => {
         const newChunk: GenerateResponseChunkData = {
           content: [{ text: chunk.text }],
         };
-        chunks.push(newChunk);

         const result = parser.parseChunk!(
-          new GenerateResponseChunk(newChunk, chunks)
+          new GenerateResponseChunk(newChunk, { previousChunks: chunks })
         );
+        chunks.push(newChunk);

-        assert.strictEqual(result.output, chunk.want);
+        assert.strictEqual(result, chunk.want);
       }
     });
   }
diff --git a/js/ai/tests/generate/chunk_test.ts b/js/ai/tests/generate/chunk_test.ts
index b37c3588f..febc670a1 100644
--- a/js/ai/tests/generate/chunk_test.ts
+++ b/js/ai/tests/generate/chunk_test.ts
@@ -17,77 +17,25 @@
 import assert from 'node:assert';
 import { describe, it } from 'node:test';
 import { GenerateResponseChunk } from '../../src/generate.js';
-import { GenerateResponseChunkData } from '../../src/model.js';

 describe('GenerateResponseChunk', () => {
-  describe('#output()', () => {
-    const testCases = [
-      {
-        should: 'parse ``` correctly',
-        accumulatedChunksTexts: ['```'],
-        correctJson: null,
-      },
-      {
-        should: 'parse valid json correctly',
-        accumulatedChunksTexts: [`{"foo":"bar"}`],
-        correctJson: { foo: 'bar' },
-      },
-      {
-        should: 'if json invalid, return null',
-        accumulatedChunksTexts: [`invalid json`],
-        correctJson: null,
-      },
-      {
-        should: 'handle missing closing brace',
-        accumulatedChunksTexts: [`{"foo":"bar"`],
-        correctJson: { foo: 'bar' },
-      },
-      {
-        should: 'handle missing closing bracket in nested object',
-        accumulatedChunksTexts: [`{"foo": {"bar": "baz"`],
-        correctJson: { foo: { bar: 'baz' } },
-      },
-      {
-        should: 'handle multiple chunks',
-        accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`],
-        correctJson: { foo: { bar: 'baz' } },
-      },
-      {
-        should: 'handle multiple chunks with nested objects',
-        accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`],
-        correctJson: { foo: { bar: { baz: 'qux' } } },
-      },
-      {
-        should: 'handle array nested in object',
-        accumulatedChunksTexts: [`{"foo": ["bar`],
-        correctJson: { foo: ['bar'] },
-      },
-      {
-        should: 'handle array nested in object with multiple chunks',
-        accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`],
-        correctJson: { foo: { bar: ['baz'] } },
-      },
-    ];
-
-    for (const test of testCases) {
-      if (test.should) {
-        it(test.should, () => {
-          const accumulatedChunks: GenerateResponseChunkData[] =
-            test.accumulatedChunksTexts.map((text, index) => ({
-              index,
-              content: [{ text }],
-            }));
-
-          const chunkData = accumulatedChunks[accumulatedChunks.length - 1];
-
-          const responseChunk: GenerateResponseChunk =
-            new GenerateResponseChunk(chunkData, accumulatedChunks);
+  describe('text accumulation', () => {
+    const testChunk = new GenerateResponseChunk(
+      { content: [{ text: 'new' }] },
+      {
+        previousChunks: [
+          { content: [{ text: 'old1' }] },
+          { content: [{ text: 'old2' }] },
+        ],
+      }
+    );

-          const output = responseChunk.output;
+    it('#previousText should concatenate the text of previous parts', () => {
+      assert.strictEqual(testChunk.previousText, 'old1old2');
+    });

-          assert.deepStrictEqual(output, test.correctJson);
-        });
-      }
-    }
+    it('#accumulatedText should concatenate previous with current text', () => {
+      assert.strictEqual(testChunk.accumulatedText, 'old1old2new');
+    });
   });
 });
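
A minimal TypeScript sketch of the new chunk-parsing contract introduced by this change (not part of the patch): instead of receiving and returning a cursor, a parser derives whatever position it needs from the chunk's previousChunks/previousText and returns the parsed output directly. The ChunkLike type and parseJsonlChunk helper below are illustrative stand-ins for the real GenerateResponseChunk and the jsonl formatter above, and plain JSON.parse stands in for the JSON5 parser the formatter actually uses.

// Reduced stand-in for GenerateResponseChunk: only the members a parser reads.
interface ChunkLike {
  previousChunks?: { content: { text?: string }[] }[];
  previousText: string;
  accumulatedText: string;
}

// A jsonl-style parser in the new shape: compute the start of the last
// (possibly incomplete) line from previousText, then parse only what is new.
function parseJsonlChunk(chunk: ChunkLike): unknown[] {
  const results: unknown[] = [];
  let startIndex = 0;
  if (chunk.previousChunks?.length) {
    const lastNewline = chunk.previousText.lastIndexOf('\n');
    if (lastNewline !== -1) startIndex = lastNewline + 1;
  }
  for (const line of chunk.accumulatedText.slice(startIndex).split('\n')) {
    const trimmed = line.trim();
    if (!trimmed.startsWith('{')) continue;
    try {
      results.push(JSON.parse(trimmed));
    } catch {
      break; // the trailing line is still streaming in
    }
  }
  return results;
}

// Streaming usage: accumulate raw chunks and hand each new chunk the ones that
// came before it, mirroring how generate/action.ts now constructs
// new GenerateResponseChunk(chunk, { previousChunks: accumulatedChunks }).
const previous: { content: { text?: string }[] }[] = [];
for (const text of ['{"a": 1}\n{"b"', ': 2}\n']) {
  const previousText = previous
    .map((c) => c.content.map((p) => p.text ?? '').join(''))
    .join('');
  const chunk: ChunkLike = {
    previousChunks: [...previous],
    previousText,
    accumulatedText: previousText + text,
  };
  console.log(parseJsonlChunk(chunk)); // [ { a: 1 } ], then [ { b: 2 } ]
  previous.push({ content: [{ text }] });
}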