Skip to content

Commit

Permalink
feat(llms)!: extend driver response type
Browse files Browse the repository at this point in the history
Return wrapper object instead of the result directly.

Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed Dec 9, 2024
1 parent 2778a9a commit 16d4bfd
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 23 deletions.
8 changes: 4 additions & 4 deletions examples/agents/custom_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ interface RunInput {

interface RunOutput {
message: BaseMessage;
raw: {
state: {
thought: string;
final_answer: string;
};
Expand Down Expand Up @@ -83,13 +83,13 @@ IMPORTANT: Every message must be a parsable JSON string without additional outpu

const result = BaseMessage.of({
role: Role.ASSISTANT,
text: response.final_answer,
text: response.parsed.final_answer,
});
await this.memory.add(result);

return {
message: result,
raw: response,
state: response.parsed,
};
}

Expand Down Expand Up @@ -122,4 +122,4 @@ const agent = new CustomAgent({
const response = await agent.run({
message: BaseMessage.of({ role: Role.USER, text: "Why is the sky blue?" }),
});
console.info(response.raw);
console.info(response.state);
9 changes: 6 additions & 3 deletions src/internals/helpers/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { TypeOf, ZodType, ZodEffects, ZodTypeAny, AnyZodObject, input } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { zodToJsonSchema, Options } from "zod-to-json-schema";
import { Ajv, SchemaObject, ValidateFunction, Options as AjvOptions } from "ajv";
import addFormats from "ajv-formats";
import { findFirstPair } from "@/internals/helpers/string.js";
Expand All @@ -39,10 +39,13 @@ export function validateSchema<T extends AnySchemaLike>(
}
}

export function toJsonSchema<T extends AnySchemaLike>(schema: T): SchemaObject {
export function toJsonSchema<T extends AnySchemaLike>(
schema: T,
options?: Partial<Options>,
): SchemaObject {
validateSchema(schema);
if (schema instanceof ZodType) {
return zodToJsonSchema(schema);
return zodToJsonSchema(schema, options);
}
return schema;
}
Expand Down
37 changes: 21 additions & 16 deletions src/llms/drivers/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,27 @@
* limitations under the License.
*/

import {
AnySchemaLike,
FromSchemaLike,
createSchemaValidator,
toJsonSchema,
} from "@/internals/helpers/schema.js";
import { AnySchemaLike, createSchemaValidator, toJsonSchema } from "@/internals/helpers/schema.js";
import { GenerateOptions, LLMError } from "@/llms/base.js";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { PromptTemplate } from "@/template.js";
import { SchemaObject } from "ajv";
import { z } from "zod";
import { TypeOf, z, ZodTypeAny } from "zod";
import { Serializable } from "@/internals/serializable.js";

export interface GenerateSchemaInput<T> {
maxRetries?: number;
options?: T;
}

export interface DriverResponse<T> {
raw: ChatLLMOutput;
parsed: T extends ZodTypeAny ? TypeOf<T> : T;
messages: BaseMessage[];
}

export abstract class BaseDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends Serializable<any> {
Expand All @@ -60,11 +61,11 @@ Validation Errors: "{{errors}}"`,
return undefined;
}

async generate<T extends AnySchemaLike>(
schema: T,
async generate<T = any>(
schema: T extends AnySchemaLike ? T : SchemaObject,
input: BaseMessage[],
{ maxRetries = 3, options }: GenerateSchemaInput<TGenerateOptions> = {},
): Promise<FromSchemaLike<T>> {
): Promise<DriverResponse<T>> {
const jsonSchema = toJsonSchema(schema);
const validator = createSchemaValidator(jsonSchema);
const schemaString = await this.schemaToString(jsonSchema);
Expand All @@ -79,15 +80,15 @@ Validation Errors: "{{errors}}"`,

return new Retryable({
executor: async () => {
const rawResponse = await this.llm.generate(messages, {
const raw = await this.llm.generate(messages, {
guided: this.guided(jsonSchema),
...options,
} as TGenerateOptions);
const textResponse = rawResponse.getTextContent();
let parsedResponse: any;
const textResponse = raw.getTextContent();
let parsed: any;

try {
parsedResponse = this.parseResponse(textResponse);
parsed = this.parseResponse(textResponse);
} catch (error) {
throw new LLMError(`Failed to parse the generated response.`, [], {
isFatal: false,
Expand All @@ -96,7 +97,7 @@ Validation Errors: "{{errors}}"`,
});
}

const success = validator(parsedResponse);
const success = validator(parsed);
if (!success) {
const context = {
expected: schemaString,
Expand All @@ -120,7 +121,11 @@ Validation Errors: "{{errors}}"`,
},
);
}
return parsedResponse as FromSchemaLike<T>;
return {
raw: raw,
parsed: parsed,
messages,
};
},
config: {
signal: options?.signal,
Expand Down

0 comments on commit 16d4bfd

Please sign in to comment.