From 16d4bfd362eb10cdbd1aec0d822b141076d6c53e Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Mon, 9 Dec 2024 15:31:40 +0100 Subject: [PATCH] feat(llms)!: extend driver response type Return wrapper object instead of the result directly. Signed-off-by: Tomas Dvorak --- examples/agents/custom_agent.ts | 8 +++---- src/internals/helpers/schema.ts | 9 +++++--- src/llms/drivers/base.ts | 37 +++++++++++++++++++-------------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/agents/custom_agent.ts b/examples/agents/custom_agent.ts index dc4469c0..b569aa0f 100644 --- a/examples/agents/custom_agent.ts +++ b/examples/agents/custom_agent.ts @@ -17,7 +17,7 @@ interface RunInput { interface RunOutput { message: BaseMessage; - raw: { + state: { thought: string; final_answer: string; }; @@ -83,13 +83,13 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu const result = BaseMessage.of({ role: Role.ASSISTANT, - text: response.final_answer, + text: response.parsed.final_answer, }); await this.memory.add(result); return { message: result, - raw: response, + state: response.parsed, }; } @@ -122,4 +122,4 @@ const agent = new CustomAgent({ const response = await agent.run({ message: BaseMessage.of({ role: Role.USER, text: "Why is the sky blue?" }), }); -console.info(response.raw); +console.info(response.state); diff --git a/src/internals/helpers/schema.ts b/src/internals/helpers/schema.ts index 2810998b..094432d3 100644 --- a/src/internals/helpers/schema.ts +++ b/src/internals/helpers/schema.ts @@ -15,7 +15,7 @@ */ import { TypeOf, ZodType, ZodEffects, ZodTypeAny, AnyZodObject, input } from "zod"; -import { zodToJsonSchema } from "zod-to-json-schema"; +import { zodToJsonSchema, Options } from "zod-to-json-schema"; import { Ajv, SchemaObject, ValidateFunction, Options as AjvOptions } from "ajv"; import addFormats from "ajv-formats"; import { findFirstPair } from "@/internals/helpers/string.js"; @@ -39,10 +39,13 @@ export function validateSchema( } } -export function toJsonSchema(schema: T): SchemaObject { +export function toJsonSchema( + schema: T, + options?: Partial, +): SchemaObject { validateSchema(schema); if (schema instanceof ZodType) { - return zodToJsonSchema(schema); + return zodToJsonSchema(schema, options); } return schema; } diff --git a/src/llms/drivers/base.ts b/src/llms/drivers/base.ts index 5c7f4ca2..5209dbbc 100644 --- a/src/llms/drivers/base.ts +++ b/src/llms/drivers/base.ts @@ -14,19 +14,14 @@ * limitations under the License. */ -import { - AnySchemaLike, - FromSchemaLike, - createSchemaValidator, - toJsonSchema, -} from "@/internals/helpers/schema.js"; +import { AnySchemaLike, createSchemaValidator, toJsonSchema } from "@/internals/helpers/schema.js"; import { GenerateOptions, LLMError } from "@/llms/base.js"; import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; import { BaseMessage, Role } from "@/llms/primitives/message.js"; import { Retryable } from "@/internals/helpers/retryable.js"; import { PromptTemplate } from "@/template.js"; import { SchemaObject } from "ajv"; -import { z } from "zod"; +import { TypeOf, z, ZodTypeAny } from "zod"; import { Serializable } from "@/internals/serializable.js"; export interface GenerateSchemaInput { @@ -34,6 +29,12 @@ export interface GenerateSchemaInput { options?: T; } +export interface DriverResponse { + raw: ChatLLMOutput; + parsed: T extends ZodTypeAny ? TypeOf : T; + messages: BaseMessage[]; +} + export abstract class BaseDriver< TGenerateOptions extends GenerateOptions = GenerateOptions, > extends Serializable { @@ -60,11 +61,11 @@ Validation Errors: "{{errors}}"`, return undefined; } - async generate( - schema: T, + async generate( + schema: T extends AnySchemaLike ? T : SchemaObject, input: BaseMessage[], { maxRetries = 3, options }: GenerateSchemaInput = {}, - ): Promise> { + ): Promise> { const jsonSchema = toJsonSchema(schema); const validator = createSchemaValidator(jsonSchema); const schemaString = await this.schemaToString(jsonSchema); @@ -79,15 +80,15 @@ Validation Errors: "{{errors}}"`, return new Retryable({ executor: async () => { - const rawResponse = await this.llm.generate(messages, { + const raw = await this.llm.generate(messages, { guided: this.guided(jsonSchema), ...options, } as TGenerateOptions); - const textResponse = rawResponse.getTextContent(); - let parsedResponse: any; + const textResponse = raw.getTextContent(); + let parsed: any; try { - parsedResponse = this.parseResponse(textResponse); + parsed = this.parseResponse(textResponse); } catch (error) { throw new LLMError(`Failed to parse the generated response.`, [], { isFatal: false, @@ -96,7 +97,7 @@ Validation Errors: "{{errors}}"`, }); } - const success = validator(parsedResponse); + const success = validator(parsed); if (!success) { const context = { expected: schemaString, @@ -120,7 +121,11 @@ Validation Errors: "{{errors}}"`, }, ); } - return parsedResponse as FromSchemaLike; + return { + raw: raw, + parsed: parsed, + messages, + }; }, config: { signal: options?.signal,