Skip to content

Commit

Permalink
feat(agents): improve template overriding
Browse files Browse the repository at this point in the history
Signed-off-by: Radek Ježek <[email protected]>
  • Loading branch information
jezekra1 committed Jan 9, 2025
1 parent 693e899 commit 88461c0
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 123 deletions.
34 changes: 9 additions & 25 deletions examples/agents/bee_advanced.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@ import {
DuckDuckGoSearchToolSearchType,
} from "bee-agent-framework/tools/search/duckDuckGoSearch";
import { OpenMeteoTool } from "bee-agent-framework/tools/weather/openMeteo";
import {
BeeAssistantPrompt,
BeeSchemaErrorPrompt,
BeeSystemPrompt,
BeeToolErrorPrompt,
BeeToolInputErrorPrompt,
BeeToolNoResultsPrompt,
BeeUserEmptyPrompt,
} from "bee-agent-framework/agents/bee/prompts";
import { PromptTemplate } from "bee-agent-framework/template";
import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";
import { z } from "zod";
Expand All @@ -31,28 +21,25 @@ const agent = new BeeAgent({
llm,
memory: new UnconstrainedMemory(),
// You can override internal templates
templates: {
user: new PromptTemplate({
templatesUpdate: {
user: {
schema: z
.object({
input: z.string(),
})
.passthrough(),
template: `User: {{input}}`,
}),
system: BeeSystemPrompt.fork((old) => ({
...old,
},
system: {
defaults: {
instructions: "You are a helpful assistant that uses tools to answer questions.",
},
})),
toolError: BeeToolErrorPrompt,
toolInputError: BeeToolInputErrorPrompt,
toolNoResultError: BeeToolNoResultsPrompt.fork((old) => ({
},
toolNoResultError: (old) => ({
...old,
template: `${old.template}\nPlease reformat your input.`,
})),
toolNotFoundError: new PromptTemplate({
}),
toolNotFoundError: {
schema: z
.object({
tools: z.array(z.object({ name: z.string() }).passthrough()),
Expand All @@ -62,10 +49,7 @@ const agent = new BeeAgent({
{{#tools.length}}
Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}}
{{/tools.length}}`,
}),
schemaError: BeeSchemaErrorPrompt,
assistant: BeeAssistantPrompt,
userEmpty: BeeUserEmptyPrompt,
},
},
tools: [
new DuckDuckGoSearchTool({
Expand Down
7 changes: 6 additions & 1 deletion src/agents/bee/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,18 @@ import { BaseRunner } from "@/agents/bee/runners/base.js";
import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js";
import { DefaultRunner } from "@/agents/bee/runners/default/runner.js";
import { ValueError } from "@/errors.js";
import { PromptTemplate, PromptTemplateSimpleForkInput } from "@/template.js";

export interface BeeInput {
llm: ChatLLM<ChatLLMOutput>;
tools: AnyTool[];
memory: BaseMemory;
meta?: Omit<AgentMeta, "tools">;
templates?: Partial<BeeAgentTemplates>;
templatesUpdate?: Partial<{
[K in keyof BeeAgentTemplates]: PromptTemplateSimpleForkInput<
BeeAgentTemplates[K] extends PromptTemplate<infer T> ? T : never
>;
}>;
execution?: BeeAgentExecutionConfig;
}

Expand Down
33 changes: 30 additions & 3 deletions src/agents/bee/runners/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { shallowCopy } from "@/serializer/utils.js";
import { BaseMemory } from "@/memory/base.js";
import { GetRunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
import * as R from "remeda";

export interface BeeRunnerLLMInput {
meta: BeeMeta;
Expand All @@ -48,15 +49,21 @@ export interface BeeRunnerToolInput {
export abstract class BaseRunner extends Serializable {
public memory!: BaseMemory;
public readonly iterations: BeeAgentRunIteration[] = [];
protected readonly failedAttemptsCounter: RetryCounter;
protected failedAttemptsCounter: RetryCounter;

public templates: BeeAgentTemplates;

constructor(
protected readonly input: BeeInput,
protected readonly options: BeeRunOptions,
protected readonly run: GetRunContext<BeeAgent>,
) {
super();
this.failedAttemptsCounter = new RetryCounter(options?.execution?.totalMaxRetries, AgentError);
this.failedAttemptsCounter = new RetryCounter(
this.options?.execution?.totalMaxRetries,
AgentError,
);
this.templates = this._resolveTemplates();
}

async createIteration() {
Expand Down Expand Up @@ -89,13 +96,33 @@ export abstract class BaseRunner extends Serializable {

async init(input: BeeRunInput) {
this.memory = await this.initMemory(input);
this.failedAttemptsCounter = new RetryCounter(
this.options?.execution?.totalMaxRetries,
AgentError,
);
}

abstract llm(input: BeeRunnerLLMInput): Promise<BeeAgentRunIteration>;

abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>;

abstract get templates(): BeeAgentTemplates;
abstract get _defaultTemplates(): BeeAgentTemplates;

_resolveTemplates(): BeeAgentTemplates {
const templatesUpdate = this.input.templatesUpdate;

if (!templatesUpdate) {
return this._defaultTemplates;
}

return R.mapValues(this._defaultTemplates, (template, key) => {
if (!templatesUpdate?.[key]) {
return template;
}
// @ts-expect-error not sure how avoid "incompatible union signatures"
return this._defaultTemplates[key].fork(templatesUpdate[key]);
});
}

protected abstract initMemory(input: BeeRunInput): Promise<BaseMemory>;

Expand Down
31 changes: 16 additions & 15 deletions src/agents/bee/runners/default/runner.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import { DefaultRunner } from "@/agents/bee/runners/default/runner.js";
import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { BaseMemory } from "@/memory/base.js";
import { BeeUserPrompt } from "@/agents/bee/prompts.js";
import { zip } from "remeda";
import { RunContext } from "@/context.js";
import { BeeAgent } from "@/agents/bee/agent.js";
import { BeeAgent, BeeInput } from "@/agents/bee/agent.js";
import { PromptTemplateInput } from "@/template.js";

vi.mock("@/memory/tokenMemory.js", async () => {
const { UnconstrainedMemory } = await import("@/memory/unconstrainedMemory.js");
Expand Down Expand Up @@ -61,7 +61,6 @@ describe("Bee Agent Runner", () => {
llm: expect.any(Function),
memory,
tools: [],
templates: {},
},
{},
new RunContext<BeeAgent, any>({} as any, {} as any),
Expand All @@ -83,14 +82,16 @@ describe("Bee Agent Runner", () => {
});

it.each([
BeeUserPrompt.fork((old) => ({
...old,
functions: { ...old.functions, formatMeta: () => "" },
})),
BeeUserPrompt.fork((old) => ({ ...old, template: `{{input}}` })),
BeeUserPrompt.fork((old) => ({ ...old, template: `User: {{input}}` })),
BeeUserPrompt.fork((old) => ({ ...old, template: `` })),
])("Correctly formats user input", async (template: typeof BeeUserPrompt) => {
{
user: (old: PromptTemplateInput<any>) => ({
...old,
functions: { ...old.functions, formatMeta: () => "" },
}),
},
{ user: { template: `{{input}}` } },
{ user: { template: `User: {{input}}` } },
{ user: { template: `` } },
])("Correctly formats user input", async (templatesUpdate: BeeInput["templatesUpdate"]) => {
const memory = new UnconstrainedMemory();
await memory.addMany([
BaseMessage.of({
Expand All @@ -117,9 +118,7 @@ describe("Bee Agent Runner", () => {
llm: expect.any(Function),
memory,
tools: [],
templates: {
user: template,
},
templatesUpdate,
},
{},
new RunContext<BeeAgent, any>({} as any, {} as any),
Expand All @@ -133,7 +132,9 @@ describe("Bee Agent Runner", () => {
],
instance.memory.messages.filter((msg) => msg.role === Role.USER),
)) {
expect(template.render({ input: a.text, meta: undefined })).toStrictEqual(b.text);
expect(instance.templates.user.render({ input: a.text, meta: undefined })).toStrictEqual(
b.text,
);
}
});
});
38 changes: 15 additions & 23 deletions src/agents/bee/runners/default/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
*/

import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js";
import type {
BeeAgentRunIteration,
BeeAgentTemplates,
BeeParserInput,
BeeRunInput,
} from "@/agents/bee/types.js";
import type { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { AgentError } from "@/agents/base.js";
import {
Expand Down Expand Up @@ -48,6 +43,20 @@ import { Cache } from "@/cache/decoratorCache.js";
import { shallowCopy } from "@/serializer/utils.js";

export class DefaultRunner extends BaseRunner {
get _defaultTemplates() {
return {
system: BeeSystemPrompt,
assistant: BeeAssistantPrompt,
user: BeeUserPrompt,
schemaError: BeeSchemaErrorPrompt,
toolNotFoundError: BeeToolNotFoundPrompt,
toolError: BeeToolErrorPrompt,
toolInputError: BeeToolInputErrorPrompt,
userEmpty: BeeUserEmptyPrompt,
toolNoResultError: BeeToolNoResultsPrompt,
};
}

static {
this.register();
}
Expand Down Expand Up @@ -369,23 +378,6 @@ export class DefaultRunner extends BaseRunner {
return memory;
}

@Cache({ enumerable: false })
get templates(): BeeAgentTemplates {
const customTemplates = this.input.templates ?? {};

return {
system: customTemplates.system ?? BeeSystemPrompt,
assistant: customTemplates.assistant ?? BeeAssistantPrompt,
user: customTemplates.user ?? BeeUserPrompt,
userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt,
toolError: customTemplates.toolError ?? BeeToolErrorPrompt,
toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt,
toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt,
toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt,
schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt,
};
}

protected createParser(tools: AnyTool[]) {
const parserRegex = isEmpty(tools)
? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`)
Expand Down
37 changes: 14 additions & 23 deletions src/agents/bee/runners/granite/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@ import {
BeeUserPrompt,
} from "@/agents/bee/prompts.js";

export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => ({
...config,
export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork({
template: `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`,
}));
});

export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => ({
...config,
export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork({
defaults: {
...config.defaults,
instructions: "",
},
functions: {
...config.functions,
formatDate: function () {
const date = this.createdAt ? new Date(this.createdAt) : new Date();
return new Intl.DateTimeFormat("en-US", {
Expand Down Expand Up @@ -86,37 +82,32 @@ You do not need a tool to get the current Date and Time. Use the information ava
{{.}}
{{/instructions}}
`,
}));
});

export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => ({
...config,
export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork({
template: `Error: The generated response does not adhere to the communication structure mentioned in the system prompt.
You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`,
}));
});

export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => ({
...config,
export const GraniteBeeUserPrompt = BeeUserPrompt.fork({
template: `{{input}}`,
}));
});

export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => ({
...config,
export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork({
template: `Tool does not exist!
{{#tools.length}}
Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}}
{{/tools.length}}`,
}));
});

export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => ({
...config,
export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork({
template: `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it.
{{reason}}`,
}));
});

export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => ({
...config,
export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork({
template: `{{reason}}
HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`,
}));
});
Loading

0 comments on commit 88461c0

Please sign in to comment.