From da7c7e69f126495649a85b146ef41aa129a7f76a Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 27 Sep 2024 00:36:33 +0200 Subject: [PATCH] feat(agent): improve invalid json parsing --- src/agents/parsers/field.test.ts | 21 +++++++++++ src/agents/parsers/field.ts | 60 ++++++++++++++++++++++++++++---- src/agents/parsers/linePrefix.ts | 26 ++++++++++---- 3 files changed, 94 insertions(+), 13 deletions(-) diff --git a/src/agents/parsers/field.test.ts b/src/agents/parsers/field.test.ts index 4502159b..acd9dd33 100644 --- a/src/agents/parsers/field.test.ts +++ b/src/agents/parsers/field.test.ts @@ -61,6 +61,27 @@ describe("Parser Fields", () => { }); }); + describe("Invalid JSON", () => { + it("Object", async () => { + const field = new JSONParserField({ + schema: z.object({}).passthrough(), + base: {}, + matchPair: ["{", "}"], + }); + + const validPart = `{"a":{"b":{"c":{"d":1}}},"b":2}`; + const invalidPart = `{"a":{"b":{"c":{"d":1}}},"b":2,}`; + + const content = `Here is the object that you were asking for: ${invalidPart} Thank you!`; + for (const chunk of splitString(content, { size: 4, overlap: 0 })) { + field.write(chunk); + } + await field.end(); + expect(field.raw).toBe(invalidPart); + expect(JSON.stringify(field.get())).toBe(validPart); + }); + }); + it("String", async () => { const field = new ZodParserField(z.string()); const content = "Hello world!"; diff --git a/src/agents/parsers/field.ts b/src/agents/parsers/field.ts index 79583f4f..ee6982e4 100644 --- a/src/agents/parsers/field.ts +++ b/src/agents/parsers/field.ts @@ -22,6 +22,8 @@ import { JSONParser } from "@streamparser/json"; import { jsonrepairTransform } from "jsonrepair/stream"; import { Cache, SingletonCacheKeyFn } from "@/cache/decoratorCache.js"; import { shallowCopy } from "@/serializer/utils.js"; +import { parseBrokenJson } from "@/internals/helpers/schema.js"; +import { findFirstPair } from "@/internals/helpers/string.js"; export abstract class ParserField extends Serializable { public raw = ""; @@ -69,9 +71,16 @@ export class ZodParserField extends ParserField { export class JSONParserField extends ParserField> { protected stream!: ReturnType; protected jsonParser!: JSONParser; + protected errored = false; protected ref!: { value: Partial }; - constructor(protected readonly input: { schema: ZodSchema; base: Partial }) { + constructor( + protected readonly input: { + schema: ZodSchema; + base: Partial; + matchPair?: [string, string]; + }, + ) { super(); if (input.base === undefined) { throw new ValueError(`Base must be defined!`); @@ -85,7 +94,15 @@ export class JSONParserField extends ParserField> { this.jsonParser = new JSONParser({ emitPartialTokens: false, emitPartialValues: true }); this.stream = jsonrepairTransform(); this.stream.on("data", (chunk) => { - this.jsonParser.write(chunk.toString()); + if (this.errored) { + return; + } + + try { + this.jsonParser.write(chunk.toString()); + } catch { + this.errored = true; + } }); this.jsonParser.onValue = ({ value, key, stack }) => { const keys = stack @@ -103,12 +120,43 @@ export class JSONParserField extends ParserField> { } write(chunk: string) { + if (this.input.matchPair) { + if (!this.raw) { + const startChar = this.input.matchPair[0]; + const index = chunk.indexOf(startChar); + if (index === -1) { + return; + } + chunk = chunk.substring(index); + } else { + const merged = this.raw.concat(chunk); + const match = findFirstPair(merged, this.input.matchPair); + if (match) { + const end = match[1]; + if (end < this.raw.length) { + return; + } + chunk = merged.substring(this.raw.length, end + 1); + } + } + } + super.write(chunk); - this.stream.push(chunk); + try { + this.stream.push(chunk); + } catch { + this.errored = true; + } } get() { - return this.input.schema.parse(this.ref.value); + const inputToParse = this.errored + ? parseBrokenJson(this.raw, { + pair: this.input.matchPair, + }) + : this.ref.value; + + return this.input.schema.parse(inputToParse); } getPartial() { @@ -116,7 +164,7 @@ export class JSONParserField extends ParserField> { } async end() { - if (this.stream.closed || this.jsonParser.isEnded) { + if (this.stream.closed || this.jsonParser.isEnded || this.errored) { return; } @@ -130,7 +178,7 @@ export class JSONParserField extends ParserField> { } createSnapshot() { - return { ...super.createSnapshot(), input: this.input }; + return { ...super.createSnapshot(), input: this.input, errored: this.errored }; } loadSnapshot({ raw, ...snapshot }: ReturnType) { diff --git a/src/agents/parsers/linePrefix.ts b/src/agents/parsers/linePrefix.ts index ef6f8c1d..aa87f310 100644 --- a/src/agents/parsers/linePrefix.ts +++ b/src/agents/parsers/linePrefix.ts @@ -21,6 +21,7 @@ import { shallowCopy } from "@/serializer/utils.js"; import { Cache } from "@/cache/decoratorCache.js"; import { ParserField } from "@/agents/parsers/field.js"; import { Callback, InferCallbackValue } from "@/emitter/types.js"; +import { ZodError } from "zod"; export interface ParserNode> { prefix: string; @@ -251,13 +252,24 @@ export class LinePrefixParser< if (key in this.finalState) { throw new LinePrefixParserError(`Duplicated key '${key}'`); } - const value = field.get(); - this.finalState[key] = value; - await this.emitter.emit("update", { - key, - field, - value, - }); + + try { + const value = field.get(); + this.finalState[key] = value; + await this.emitter.emit("update", { + key, + field, + value, + }); + } catch (e) { + if (e instanceof ZodError) { + throw new LinePrefixParserError( + `Value for ${key} cannot be retrieved because it's value does not adhere to the appropriate schema.`, + [e], + ); + } + throw e; + } } @Cache()