Skip to content

Commit

Permalink
[JS] Remove cursor from chunk parsing, do it based on existing chunks.
Browse files Browse the repository at this point in the history
…#708 continued (#1143)
  • Loading branch information
mbleigh authored Oct 30, 2024
1 parent 92f908e commit d7b0753
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 156 deletions.
19 changes: 9 additions & 10 deletions js/ai/src/formats/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { GenkitError } from '@genkit-ai/core';
import { extractItems } from '../extract';
import type { Formatter } from './types';

export const arrayFormatter: Formatter<unknown[], unknown[], number> = {
export const arrayFormatter: Formatter<unknown[], unknown[]> = {
name: 'array',
config: {
contentType: 'application/json',
Expand All @@ -43,16 +43,15 @@ export const arrayFormatter: Formatter<unknown[], unknown[], number> = {
}

return {
parseChunk: (chunk, cursor = 0) => {
const { items, cursor: newCursor } = extractItems(
chunk.accumulatedText,
cursor
);
parseChunk: (chunk) => {
// first, determine the cursor position from the previous chunks
const cursor = chunk.previousChunks?.length
? extractItems(chunk.previousText).cursor
: 0;
// then, extract the items starting at that cursor
const { items } = extractItems(chunk.accumulatedText, cursor);

return {
output: items,
cursor: newCursor,
};
return items;
},

parseResponse: (response) => {
Expand Down
2 changes: 1 addition & 1 deletion js/ai/src/formats/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export async function resolveFormat(
return arg as Formatter;
}

export const DEFAULT_FORMATS: Formatter<any, any, any>[] = [
export const DEFAULT_FORMATS: Formatter<any, any>[] = [
jsonFormatter,
arrayFormatter,
textFormatter,
Expand Down
9 changes: 3 additions & 6 deletions js/ai/src/formats/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import { extractJson } from '../extract';
import type { Formatter } from './types';

export const jsonFormatter: Formatter<unknown, unknown, string> = {
export const jsonFormatter: Formatter<unknown, unknown> = {
name: 'json',
config: {
contentType: 'application/json',
Expand All @@ -36,11 +36,8 @@ ${JSON.stringify(request.output!.schema!)}
}

return {
parseChunk: (chunk, cursor = '') => {
return {
output: extractJson(chunk.accumulatedText),
cursor: chunk.accumulatedText,
};
parseChunk: (chunk) => {
return extractJson(chunk.accumulatedText);
},

parseResponse: (response) => {
Expand Down
41 changes: 25 additions & 16 deletions js/ai/src/formats/jsonl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function objectLines(text: string): string[] {
.filter((line) => line.startsWith('{'));
}

export const jsonlFormatter: Formatter<unknown[], unknown[], number> = {
export const jsonlFormatter: Formatter<unknown[], unknown[]> = {
name: 'jsonl',
config: {
contentType: 'application/jsonl',
Expand Down Expand Up @@ -54,27 +54,36 @@ ${JSON.stringify(request.output.schema.items)}
}

return {
parseChunk: (chunk, cursor = 0) => {
const jsonLines = objectLines(chunk.accumulatedText);
parseChunk: (chunk) => {
const results: unknown[] = [];
let newCursor = cursor;

for (let i = cursor; i < jsonLines.length; i++) {
try {
const result = JSON5.parse(jsonLines[i]);
if (result) {
results.push(result);
const text = chunk.accumulatedText;

let startIndex = 0;
if (chunk.previousChunks?.length) {
const lastNewline = chunk.previousText.lastIndexOf('\n');
if (lastNewline !== -1) {
startIndex = lastNewline + 1;
}
}

const lines = text.slice(startIndex).split('\n');

for (const line of lines) {
const trimmed = line.trim();
if (trimmed.startsWith('{')) {
try {
const result = JSON5.parse(trimmed);
if (result) {
results.push(result);
}
} catch (e) {
break;
}
newCursor = i + 1;
} catch (e) {
break;
}
}

return {
output: results,
cursor: newCursor,
};
return results;
},

parseResponse: (response) => {
Expand Down
4 changes: 1 addition & 3 deletions js/ai/src/formats/text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ export const textFormatter: Formatter<string, string> = {
handler: () => {
return {
parseChunk: (chunk) => {
return {
output: chunk.text,
};
return chunk.text;
},

parseResponse: (response) => {
Expand Down
17 changes: 2 additions & 15 deletions js/ai/src/formats/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,17 @@
import { GenerateResponse, GenerateResponseChunk } from '../generate.js';
import { ModelRequest, Part } from '../model.js';

export interface ParsedChunk<CO = unknown, CC = unknown> {
output: CO;
/**
* The cursor of a parsed chunk response holds context that is relevant to continue parsing.
* The returned cursor will be passed into the next iteration of the chunk parser. Cursors
* are not exposed to external consumers of the formatter.
*/
cursor?: CC;
}

type OutputContentTypes =
| 'application/json'
| 'text/plain'
| 'application/jsonl';

export interface Formatter<O = unknown, CO = unknown, CC = unknown> {
export interface Formatter<O = unknown, CO = unknown> {
name: string;
config: ModelRequest['output'];
handler: (req: ModelRequest) => {
parseResponse(response: GenerateResponse): O;
parseChunk?: (
chunk: GenerateResponseChunk,
cursor?: CC
) => ParsedChunk<CO, CC>;
parseChunk?: (chunk: GenerateResponseChunk, cursor?: CC) => CO;
instructions?: string | Part[];
};
}
6 changes: 4 additions & 2 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ async function generate(
streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
new GenerateResponseChunk(chunk, {
previousChunks: accumulatedChunks,
})
);
}
accumulatedChunks.push(chunk);
}
: undefined,
async () => {
Expand Down
60 changes: 43 additions & 17 deletions js/ai/src/generate/chunk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,50 @@

import { GenkitError } from '@genkit-ai/core';
import { extractJson } from '../extract.js';
import { GenerateResponseChunkData, Part, ToolRequestPart } from '../model.js';
import {
GenerateResponseChunkData,
Part,
Role,
ToolRequestPart,
} from '../model.js';

export interface ChunkParser<T = unknown> {
(chunk: GenerateResponseChunk<T>): T;
}

export class GenerateResponseChunk<T = unknown>
implements GenerateResponseChunkData
{
/** The index of the candidate this chunk corresponds to. */
/** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */
index?: number;
/** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */
role?: Role;
/** The content generated in this chunk. */
content: Part[];
/** Custom model-specific data for this chunk. */
custom?: unknown;
/** Accumulated chunks for partial output extraction. */
accumulatedChunks?: GenerateResponseChunkData[];
previousChunks?: GenerateResponseChunkData[];
/** The parser to be used to parse `output` from this chunk. */
parser?: ChunkParser<T>;

constructor(
data: GenerateResponseChunkData,
accumulatedChunks?: GenerateResponseChunkData[]
options?: {
previousChunks?: GenerateResponseChunkData[];
role?: Role;
index?: number;
parser?: ChunkParser<T>;
}
) {
this.index = data.index;
this.content = data.content || [];
this.custom = data.custom;
this.accumulatedChunks = accumulatedChunks;
this.previousChunks = options?.previousChunks
? [...options.previousChunks]
: undefined;
this.index = options?.index;
this.role = options?.role;
this.parser = options?.parser;
}

/**
Expand All @@ -53,13 +75,20 @@ export class GenerateResponseChunk<T = unknown>
* @returns A string of all concatenated chunk text content.
*/
get accumulatedText(): string {
if (!this.accumulatedChunks)
return this.previousText + this.text;
}

/**
* Concatenates all `text` parts of all preceding chunks.
*/
get previousText(): string {
if (!this.previousChunks)
throw new GenkitError({
status: 'FAILED_PRECONDITION',
message: 'Cannot compose accumulated text without accumulated chunks.',
message: 'Cannot compose accumulated text without previous chunks.',
});

return this.accumulatedChunks
return this.previousChunks
?.map((c) => c.content.map((p) => p.text || '').join(''))
.join('');
}
Expand Down Expand Up @@ -92,18 +121,15 @@ export class GenerateResponseChunk<T = unknown>
}

/**
* Attempts to extract the longest valid JSON substring from the accumulated chunks.
* @returns The longest valid JSON substring found in the accumulated chunks.
* Parses the chunk into the desired output format using the parser associated
* with the generate request, or falls back to naive JSON parsing otherwise.
*/
get output(): T | null {
if (!this.accumulatedChunks) return null;
const accumulatedText = this.accumulatedChunks
.map((chunk) => chunk.content.map((part) => part.text || '').join(''))
.join('');
return extractJson<T>(accumulatedText, false);
if (this.parser) return this.parser(this);
return this.data || extractJson(this.accumulatedText);
}

toJSON(): GenerateResponseChunkData {
return { index: this.index, content: this.content, custom: this.custom };
return { content: this.content, custom: this.custom };
}
}
8 changes: 3 additions & 5 deletions js/ai/tests/formats/array_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,13 @@ describe('arrayFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
lastCursor
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
7 changes: 3 additions & 4 deletions js/ai/tests/formats/json_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,14 @@ describe('jsonFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] }),
lastCursor
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
9 changes: 3 additions & 6 deletions js/ai/tests/formats/jsonl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,18 @@ describe('jsonlFormat', () => {
it(st.desc, () => {
const parser = jsonlFormatter.handler({ messages: [] });
const chunks: GenerateResponseChunkData[] = [];
let lastCursor = 0;

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
lastCursor
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
6 changes: 3 additions & 3 deletions js/ai/tests/formats/text_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ describe('textFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks)
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.strictEqual(result.output, chunk.want);
assert.strictEqual(result, chunk.want);
}
});
}
Expand Down
Loading

0 comments on commit d7b0753

Please sign in to comment.