diff --git a/src/llms/drivers/base.ts b/src/llms/drivers/base.ts index fb11438a..5c7f4ca2 100644 --- a/src/llms/drivers/base.ts +++ b/src/llms/drivers/base.ts @@ -27,13 +27,16 @@ import { Retryable } from "@/internals/helpers/retryable.js"; import { PromptTemplate } from "@/template.js"; import { SchemaObject } from "ajv"; import { z } from "zod"; +import { Serializable } from "@/internals/serializable.js"; export interface GenerateSchemaInput { maxRetries?: number; options?: T; } -export abstract class BaseDriver { +export abstract class BaseDriver< + TGenerateOptions extends GenerateOptions = GenerateOptions, +> extends Serializable { protected abstract template: PromptTemplate.infer<{ schema: string }>; protected errorTemplate = new PromptTemplate({ schema: z.object({ @@ -45,7 +48,9 @@ export abstract class BaseDriver) {} + constructor(protected readonly llm: ChatLLM) { + super(); + } protected abstract parseResponse(textResponse: string): unknown; protected abstract schemaToString(schema: SchemaObject): Promise | string; @@ -123,4 +128,15 @@ Validation Errors: "{{errors}}"`, }, }).get(); } + + createSnapshot() { + return { + template: this.template, + errorTemplate: this.errorTemplate, + }; + } + + loadSnapshot(snapshot: ReturnType) { + Object.assign(this, snapshot); + } } diff --git a/src/llms/drivers/json.ts b/src/llms/drivers/json.ts index 8505c166..34f8d723 100644 --- a/src/llms/drivers/json.ts +++ b/src/llms/drivers/json.ts @@ -38,6 +38,10 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu `, }); + static { + this.register(); + } + protected parseResponse(textResponse: string): unknown { return parseBrokenJson(textResponse); } diff --git a/src/llms/drivers/typescript.ts b/src/llms/drivers/typescript.ts index 7ceedcba..7fd853d8 100644 --- a/src/llms/drivers/typescript.ts +++ b/src/llms/drivers/typescript.ts @@ -39,6 +39,10 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu `, }); + static { + this.register(); + } + protected parseResponse(textResponse: string): unknown { return parseBrokenJson(textResponse); } diff --git a/src/llms/drivers/yaml.ts b/src/llms/drivers/yaml.ts index 9576452e..67c54515 100644 --- a/src/llms/drivers/yaml.ts +++ b/src/llms/drivers/yaml.ts @@ -38,6 +38,10 @@ IMPORTANT: Every message must be a parsable YAML string without additional outpu `, }); + static { + this.register(); + } + protected parseResponse(textResponse: string): unknown { return yaml.load(textResponse); }