From 4ebb4dbb66a3762521eedb435bd7e649dbd38c29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Thu, 9 Jan 2025 17:52:10 +0100 Subject: [PATCH] feat(agents): improve template overriding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Radek Ježek --- examples/agents/bee_advanced.ts | 50 +++++++++---------- src/agents/bee/agent.ts | 7 ++- src/agents/bee/runners/base.ts | 33 ++++++++++-- src/agents/bee/runners/default/runner.spec.ts | 31 ++++++------ src/agents/bee/runners/default/runner.ts | 38 ++++++-------- src/agents/bee/runners/granite/prompts.ts | 37 ++++++-------- src/agents/bee/runners/granite/runner.ts | 42 +++++++--------- src/experimental/workflows/agent.ts | 7 ++- src/internals/types.ts | 4 ++ src/template.ts | 22 ++++++-- 10 files changed, 148 insertions(+), 123 deletions(-) diff --git a/examples/agents/bee_advanced.ts b/examples/agents/bee_advanced.ts index 562b090f..e51b5550 100644 --- a/examples/agents/bee_advanced.ts +++ b/examples/agents/bee_advanced.ts @@ -1,3 +1,19 @@ +/** + * Copyright 2025 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + import "dotenv/config.js"; import { BeeAgent } from "bee-agent-framework/agents/bee/agent"; import { createConsoleReader } from "../helpers/io.js"; @@ -8,16 +24,6 @@ import { DuckDuckGoSearchToolSearchType, } from "bee-agent-framework/tools/search/duckDuckGoSearch"; import { OpenMeteoTool } from "bee-agent-framework/tools/weather/openMeteo"; -import { - BeeAssistantPrompt, - BeeSchemaErrorPrompt, - BeeSystemPrompt, - BeeToolErrorPrompt, - BeeToolInputErrorPrompt, - BeeToolNoResultsPrompt, - BeeUserEmptyPrompt, -} from "bee-agent-framework/agents/bee/prompts"; -import { PromptTemplate } from "bee-agent-framework/template"; import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat"; import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; import { z } from "zod"; @@ -31,28 +37,25 @@ const agent = new BeeAgent({ llm, memory: new UnconstrainedMemory(), // You can override internal templates - templates: { - user: new PromptTemplate({ + templatesUpdate: { + user: { schema: z .object({ input: z.string(), }) .passthrough(), template: `User: {{input}}`, - }), - system: BeeSystemPrompt.fork((old) => ({ - ...old, + }, + system: { defaults: { instructions: "You are a helpful assistant that uses tools to answer questions.", }, - })), - toolError: BeeToolErrorPrompt, - toolInputError: BeeToolInputErrorPrompt, - toolNoResultError: BeeToolNoResultsPrompt.fork((old) => ({ + }, + toolNoResultError: (old) => ({ ...old, template: `${old.template}\nPlease reformat your input.`, - })), - toolNotFoundError: new PromptTemplate({ + }), + toolNotFoundError: { schema: z .object({ tools: z.array(z.object({ name: z.string() }).passthrough()), @@ -62,10 +65,7 @@ const agent = new BeeAgent({ {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} {{/tools.length}}`, - }), - schemaError: BeeSchemaErrorPrompt, - assistant: BeeAssistantPrompt, - userEmpty: BeeUserEmptyPrompt, + }, }, tools: [ new DuckDuckGoSearchTool({ diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index 5ed76105..d24556b0 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -36,13 +36,18 @@ import { BaseRunner } from "@/agents/bee/runners/base.js"; import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { ValueError } from "@/errors.js"; +import { PromptTemplate, PromptTemplateSimpleForkInput } from "@/template.js"; export interface BeeInput { llm: ChatLLM; tools: AnyTool[]; memory: BaseMemory; meta?: Omit; - templates?: Partial; + templatesUpdate?: Partial<{ + [K in keyof BeeAgentTemplates]: PromptTemplateSimpleForkInput< + BeeAgentTemplates[K] extends PromptTemplate ? T : never + >; + }>; execution?: BeeAgentExecutionConfig; } diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index dc4ec42b..1b695bc6 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -31,6 +31,7 @@ import { shallowCopy } from "@/serializer/utils.js"; import { BaseMemory } from "@/memory/base.js"; import { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; +import * as R from "remeda"; export interface BeeRunnerLLMInput { meta: BeeMeta; @@ -48,7 +49,9 @@ export interface BeeRunnerToolInput { export abstract class BaseRunner extends Serializable { public memory!: BaseMemory; public readonly iterations: BeeAgentRunIteration[] = []; - protected readonly failedAttemptsCounter: RetryCounter; + protected failedAttemptsCounter: RetryCounter; + + public templates: BeeAgentTemplates; constructor( protected readonly input: BeeInput, @@ -56,7 +59,11 @@ export abstract class BaseRunner extends Serializable { protected readonly run: GetRunContext, ) { super(); - this.failedAttemptsCounter = new RetryCounter(options?.execution?.totalMaxRetries, AgentError); + this.failedAttemptsCounter = new RetryCounter( + this.options?.execution?.totalMaxRetries, + AgentError, + ); + this.templates = this._resolveTemplates(); } async createIteration() { @@ -89,13 +96,33 @@ export abstract class BaseRunner extends Serializable { async init(input: BeeRunInput) { this.memory = await this.initMemory(input); + this.failedAttemptsCounter = new RetryCounter( + this.options?.execution?.totalMaxRetries, + AgentError, + ); } abstract llm(input: BeeRunnerLLMInput): Promise; abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; - abstract get templates(): BeeAgentTemplates; + abstract get _defaultTemplates(): BeeAgentTemplates; + + _resolveTemplates(): BeeAgentTemplates { + const templatesUpdate = this.input.templatesUpdate; + + if (!templatesUpdate) { + return this._defaultTemplates; + } + + return R.mapValues(this._defaultTemplates, (template, key) => { + if (!templatesUpdate?.[key]) { + return template; + } + // @ts-expect-error not sure how avoid "incompatible union signatures" + return this._defaultTemplates[key].fork(templatesUpdate[key]); + }); + } protected abstract initMemory(input: BeeRunInput): Promise; diff --git a/src/agents/bee/runners/default/runner.spec.ts b/src/agents/bee/runners/default/runner.spec.ts index cdea62d0..cf20de3b 100644 --- a/src/agents/bee/runners/default/runner.spec.ts +++ b/src/agents/bee/runners/default/runner.spec.ts @@ -18,10 +18,10 @@ import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js"; import { BaseMessage, Role } from "@/llms/primitives/message.js"; import { BaseMemory } from "@/memory/base.js"; -import { BeeUserPrompt } from "@/agents/bee/prompts.js"; import { zip } from "remeda"; import { RunContext } from "@/context.js"; -import { BeeAgent } from "@/agents/bee/agent.js"; +import { BeeAgent, BeeInput } from "@/agents/bee/agent.js"; +import { PromptTemplateInput } from "@/template.js"; vi.mock("@/memory/tokenMemory.js", async () => { const { UnconstrainedMemory } = await import("@/memory/unconstrainedMemory.js"); @@ -61,7 +61,6 @@ describe("Bee Agent Runner", () => { llm: expect.any(Function), memory, tools: [], - templates: {}, }, {}, new RunContext({} as any, {} as any), @@ -83,14 +82,16 @@ describe("Bee Agent Runner", () => { }); it.each([ - BeeUserPrompt.fork((old) => ({ - ...old, - functions: { ...old.functions, formatMeta: () => "" }, - })), - BeeUserPrompt.fork((old) => ({ ...old, template: `{{input}}` })), - BeeUserPrompt.fork((old) => ({ ...old, template: `User: {{input}}` })), - BeeUserPrompt.fork((old) => ({ ...old, template: `` })), - ])("Correctly formats user input", async (template: typeof BeeUserPrompt) => { + { + user: (old: PromptTemplateInput) => ({ + ...old, + functions: { ...old.functions, formatMeta: () => "" }, + }), + }, + { user: { template: `{{input}}` } }, + { user: { template: `User: {{input}}` } }, + { user: { template: `` } }, + ])("Correctly formats user input", async (templatesUpdate: BeeInput["templatesUpdate"]) => { const memory = new UnconstrainedMemory(); await memory.addMany([ BaseMessage.of({ @@ -117,9 +118,7 @@ describe("Bee Agent Runner", () => { llm: expect.any(Function), memory, tools: [], - templates: { - user: template, - }, + templatesUpdate, }, {}, new RunContext({} as any, {} as any), @@ -133,7 +132,9 @@ describe("Bee Agent Runner", () => { ], instance.memory.messages.filter((msg) => msg.role === Role.USER), )) { - expect(template.render({ input: a.text, meta: undefined })).toStrictEqual(b.text); + expect(instance.templates.user.render({ input: a.text, meta: undefined })).toStrictEqual( + b.text, + ); } }); }); diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index 4b786959..18d7d510 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -15,12 +15,7 @@ */ import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js"; -import type { - BeeAgentRunIteration, - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, -} from "@/agents/bee/types.js"; +import type { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js"; import { Retryable } from "@/internals/helpers/retryable.js"; import { AgentError } from "@/agents/base.js"; import { @@ -48,6 +43,20 @@ import { Cache } from "@/cache/decoratorCache.js"; import { shallowCopy } from "@/serializer/utils.js"; export class DefaultRunner extends BaseRunner { + get _defaultTemplates() { + return { + system: BeeSystemPrompt, + assistant: BeeAssistantPrompt, + user: BeeUserPrompt, + schemaError: BeeSchemaErrorPrompt, + toolNotFoundError: BeeToolNotFoundPrompt, + toolError: BeeToolErrorPrompt, + toolInputError: BeeToolInputErrorPrompt, + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -369,23 +378,6 @@ export class DefaultRunner extends BaseRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - system: customTemplates.system ?? BeeSystemPrompt, - assistant: customTemplates.assistant ?? BeeAssistantPrompt, - user: customTemplates.user ?? BeeUserPrompt, - userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt, - toolError: customTemplates.toolError ?? BeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt, - toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt, - schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const parserRegex = isEmpty(tools) ? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`) diff --git a/src/agents/bee/runners/granite/prompts.ts b/src/agents/bee/runners/granite/prompts.ts index 97fa1cf5..b6f47915 100644 --- a/src/agents/bee/runners/granite/prompts.ts +++ b/src/agents/bee/runners/granite/prompts.ts @@ -24,19 +24,15 @@ import { BeeUserPrompt, } from "@/agents/bee/prompts.js"; -export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => ({ - ...config, +export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork({ template: `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`, -})); +}); -export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => ({ - ...config, +export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork({ defaults: { - ...config.defaults, instructions: "", }, functions: { - ...config.functions, formatDate: function () { const date = this.createdAt ? new Date(this.createdAt) : new Date(); return new Intl.DateTimeFormat("en-US", { @@ -86,37 +82,32 @@ You do not need a tool to get the current Date and Time. Use the information ava {{.}} {{/instructions}} `, -})); +}); -export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => ({ - ...config, +export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork({ template: `Error: The generated response does not adhere to the communication structure mentioned in the system prompt. You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`, -})); +}); -export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => ({ - ...config, +export const GraniteBeeUserPrompt = BeeUserPrompt.fork({ template: `{{input}}`, -})); +}); -export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => ({ - ...config, +export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork({ template: `Tool does not exist! {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} {{/tools.length}}`, -})); +}); -export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => ({ - ...config, +export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork({ template: `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. {{reason}}`, -})); +}); -export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => ({ - ...config, +export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork({ template: `{{reason}} HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`, -})); +}); diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index e928ff34..158c8cf6 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -15,16 +15,11 @@ */ import { BaseMessage, Role } from "@/llms/primitives/message.js"; -import type { AnyTool } from "@/tools/base.js"; import { isEmpty } from "remeda"; +import type { AnyTool } from "@/tools/base.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { BaseMemory } from "@/memory/base.js"; -import type { - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, - BeeRunOptions, -} from "@/agents/bee/types.js"; +import type { BeeParserInput, BeeRunInput, BeeRunOptions } from "@/agents/bee/types.js"; import { BeeAgent, BeeInput } from "@/agents/bee/agent.js"; import type { GetRunContext } from "@/context.js"; import { @@ -36,9 +31,24 @@ import { GraniteBeeToolNotFoundPrompt, GraniteBeeUserPrompt, } from "@/agents/bee/runners/granite/prompts.js"; -import { Cache } from "@/cache/decoratorCache.js"; +import { BeeToolNoResultsPrompt, BeeUserEmptyPrompt } from "@/agents/bee/prompts.js"; export class GraniteRunner extends DefaultRunner { + get _defaultTemplates() { + return { + system: GraniteBeeSystemPrompt, + assistant: GraniteBeeAssistantPrompt, + user: GraniteBeeUserPrompt, + schemaError: GraniteBeeSchemaErrorPrompt, + toolNotFoundError: GraniteBeeToolNotFoundPrompt, + toolError: GraniteBeeToolErrorPrompt, + toolInputError: GraniteBeeToolInputErrorPrompt, + // Note: These are from bee + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -89,22 +99,6 @@ export class GraniteRunner extends DefaultRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - ...super.templates, - user: customTemplates.user ?? GraniteBeeUserPrompt, - system: customTemplates.system ?? GraniteBeeSystemPrompt, - assistant: customTemplates.assistant ?? GraniteBeeAssistantPrompt, - schemaError: customTemplates.schemaError ?? GraniteBeeSchemaErrorPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? GraniteBeeToolNotFoundPrompt, - toolError: customTemplates.toolError ?? GraniteBeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? GraniteBeeToolInputErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const { parser } = super.createParser(tools); diff --git a/src/experimental/workflows/agent.ts b/src/experimental/workflows/agent.ts index 7d7acd40..fc6c37bf 100644 --- a/src/experimental/workflows/agent.ts +++ b/src/experimental/workflows/agent.ts @@ -19,7 +19,6 @@ import { Workflow, WorkflowRunOptions } from "@/experimental/workflows/workflow. import { BaseMessage } from "@/llms/primitives/message.js"; import { AnyTool } from "@/tools/base.js"; import { AnyChatLLM } from "@/llms/chat.js"; -import { BeeSystemPrompt } from "@/agents/bee/prompts.js"; import { BaseMemory, ReadOnlyMemory } from "@/memory/base.js"; import { z } from "zod"; import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js"; @@ -99,14 +98,14 @@ export class AgentWorkflow { }, execution: input.execution, ...(input.instructions && { - templates: { - system: BeeSystemPrompt.fork((config) => ({ + templatesUpdate: { + system: (config) => ({ ...config, defaults: { ...config.defaults, instructions: input.instructions || config.defaults.instructions, }, - })), + }), }, }), }); diff --git a/src/internals/types.ts b/src/internals/types.ts index 251afe14..e9aad87c 100644 --- a/src/internals/types.ts +++ b/src/internals/types.ts @@ -97,6 +97,10 @@ export type NoPromise = T extends Promise ? never : T; export type TypedFn

= (...args: P) => R; +export type DeepPartial = { + [P in keyof T]?: T[P] extends object ? DeepPartial : T[P]; +}; + export function narrowTo(value: unknown, fn: boolean | ((value: T) => boolean)): value is T { if (typeof fn === "function") { return fn(value as T); diff --git a/src/template.ts b/src/template.ts index 05f3dbcc..5b0227ed 100644 --- a/src/template.ts +++ b/src/template.ts @@ -15,7 +15,7 @@ */ import { FrameworkError } from "@/errors.js"; -import { ObjectLike, PlainObject } from "@/internals/types.js"; +import { DeepPartial, ObjectLike, PlainObject } from "@/internals/types.js"; import * as R from "remeda"; import Mustache from "mustache"; import { Serializable } from "@/internals/serializable.js"; @@ -59,6 +59,10 @@ type Customizer = ( config: Required>, ) => PromptTemplateConstructor; +export type PromptTemplateSimpleForkInput = + | Customizer + | DeepPartial>; + export class PromptTemplateError extends FrameworkError { template: PromptTemplate; @@ -73,7 +77,7 @@ export class PromptTemplateError extends FrameworkError { export class ValidationPromptTemplateError extends PromptTemplateError {} export class PromptTemplate extends Serializable { - protected config: Required>; + public config: Required>; public static functions = { trim: () => (text: string, render: (value: string) => string) => { @@ -117,13 +121,21 @@ export class PromptTemplate extends Serializable { } } - fork(customizer: Customizer): PromptTemplate; + fork(customizer: PromptTemplateSimpleForkInput): PromptTemplate; fork(customizer: Customizer): PromptTemplate; fork( - customizer: Customizer | Customizer, + customizerOrUpdate: + | Customizer + | Customizer + | DeepPartial>, ): PromptTemplate { const config = shallowCopy(this.config); - const newConfig = customizer?.(config) ?? config; + + if (R.isPlainObject(customizerOrUpdate)) { + return new PromptTemplate(R.mergeDeep(config, customizerOrUpdate) as typeof config); + } + + const newConfig = customizerOrUpdate?.(config) ?? config; return new PromptTemplate(newConfig); }