diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index 5f7d673d..3cbe0c1c 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -30,7 +30,6 @@ import { } from "@/agents/bee/types.js"; import { GetRunContext } from "@/context.js"; import { assign } from "@/internals/helpers/object.js"; -import { BeeAssistantPrompt } from "@/agents/bee/prompts.js"; import * as R from "remeda"; import { BaseRunner } from "@/agents/bee/runners/base.js"; import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js"; @@ -130,7 +129,7 @@ export class BeeAgent extends BaseAgent JSON.stringify(call)), diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index 6aa2e827..53c61193 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -17,6 +17,7 @@ import { Serializable } from "@/internals/serializable.js"; import { BeeAgentRunIteration, + BeeAgentTemplates, BeeCallbacks, BeeIterationToolResult, BeeMeta, @@ -90,6 +91,8 @@ export abstract class BaseRunner extends Serializable { abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; + abstract get templates(): BeeAgentTemplates; + protected abstract initMemory(input: BeeRunInput): Promise; createSnapshot() { diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index 64fb9bd1..979c99a2 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -14,10 +14,16 @@ * limitations under the License. */ import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js"; -import { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js"; +import type { + BeeAgentRunIteration, + BeeAgentTemplates, + BeeParserInput, + BeeRunInput, +} from "@/agents/bee/types.js"; import { Retryable } from "@/internals/helpers/retryable.js"; import { AgentError } from "@/agents/base.js"; import { + BeeAssistantPrompt, BeeSchemaErrorPrompt, BeeSystemPrompt, BeeToolErrorPrompt, @@ -69,7 +75,7 @@ export class DefaultRunner extends BaseRunner { await this.memory.add( BaseMessage.of({ role: Role.ASSISTANT, - text: (this.input.templates?.schemaError ?? BeeSchemaErrorPrompt).render({}), + text: this.templates.schemaError.render({}), meta: { [tempMessageKey]: true, }, @@ -151,10 +157,9 @@ export class DefaultRunner extends BaseRunner { }), ); - const template = this.input.templates?.toolNotFoundError ?? BeeToolNotFoundPrompt; return { success: false, - output: template.render({ + output: this.templates.toolNotFoundError.render({ tools: this.input.tools, }), }; @@ -202,8 +207,7 @@ export class DefaultRunner extends BaseRunner { }); if (toolOutput.isEmpty()) { - const template = this.input.templates?.toolNoResultError ?? BeeToolNoResultsPrompt; - return { output: template.render({}), success: true }; + return { output: this.templates.toolNoResultError.render({}), success: true }; } return { @@ -225,10 +229,9 @@ export class DefaultRunner extends BaseRunner { if (error instanceof ToolInputValidationError) { this.failedAttemptsCounter.use(error); - const template = this.input.templates?.toolInputError ?? BeeToolInputErrorPrompt; return { success: false, - output: template.render({ + output: this.templates.toolInputError.render({ reason: error.toString(), }), }; @@ -237,10 +240,9 @@ export class DefaultRunner extends BaseRunner { if (error instanceof ToolError) { this.failedAttemptsCounter.use(error); - const template = this.input.templates?.toolError ?? BeeToolErrorPrompt; return { success: false, - output: template.render({ + output: this.templates.toolError.render({ reason: error.explain(), }), }; @@ -288,7 +290,7 @@ export class DefaultRunner extends BaseRunner { message: async () => BaseMessage.of({ role: Role.SYSTEM, - text: (this.input.templates?.system ?? BeeSystemPrompt).render({ + text: this.templates.system.render({ tools: await self.system.variables.tools(), instructions: undefined, }), @@ -359,6 +361,23 @@ export class DefaultRunner extends BaseRunner { return memory; } + @Cache({ enumerable: false }) + get templates(): BeeAgentTemplates { + const customTemplates = this.input.templates ?? {}; + + return { + system: customTemplates.system ?? BeeSystemPrompt, + assistant: customTemplates.assistant ?? BeeAssistantPrompt, + user: customTemplates.user ?? BeeUserPrompt, + userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt, + toolError: customTemplates.toolError ?? BeeToolErrorPrompt, + toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt, + toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt, + toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt, + schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt, + }; + } + protected createParser(tools: AnyTool[]) { const parserRegex = isEmpty(tools) ? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`) diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index 52df06ff..37be4f38 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -18,7 +18,12 @@ import type { AnyTool } from "@/tools/base.js"; import { isEmpty } from "remeda"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { BaseMemory } from "@/memory/base.js"; -import type { BeeParserInput, BeeRunInput, BeeRunOptions } from "@/agents/bee/types.js"; +import type { + BeeAgentTemplates, + BeeParserInput, + BeeRunInput, + BeeRunOptions, +} from "@/agents/bee/types.js"; import { BeeAgent, BeeInput } from "@/agents/bee/agent.js"; import type { GetRunContext } from "@/context.js"; import { @@ -26,6 +31,7 @@ import { GraniteBeeSchemaErrorPrompt, GraniteBeeSystemPrompt, } from "@/agents/bee/runners/granite/prompts.js"; +import { Cache } from "@/cache/decoratorCache.js"; export class GraniteRunner extends DefaultRunner { static { @@ -33,19 +39,7 @@ export class GraniteRunner extends DefaultRunner { } constructor(input: BeeInput, options: BeeRunOptions, run: GetRunContext) { - super( - { - ...input, - templates: { - ...input.templates, - system: input.templates?.system ?? GraniteBeeSystemPrompt, - assistant: input.templates?.assistant ?? GraniteBeeAssistantPrompt, - schemaError: input.templates?.schemaError ?? GraniteBeeSchemaErrorPrompt, - }, - }, - options, - run, - ); + super(input, options, run); run.emitter.on( "update", @@ -83,6 +77,18 @@ export class GraniteRunner extends DefaultRunner { return memory; } + @Cache({ enumerable: false }) + get templates(): BeeAgentTemplates { + const customTemplates = this.input.templates ?? {}; + + return { + ...super.templates, + system: customTemplates.system ?? GraniteBeeSystemPrompt, + assistant: customTemplates.assistant ?? GraniteBeeAssistantPrompt, + schemaError: customTemplates.schemaError ?? GraniteBeeSchemaErrorPrompt, + }; + } + protected createParser(tools: AnyTool[]) { const { parser } = super.createParser(tools);