Skip to content

Commit

Permalink
feat(agents): use prompt templates from a concrete runner (i-am-bee#223)
Browse files Browse the repository at this point in the history
Ref: i-am-bee#219

Signed-off-by: Tomas Dvorak <[email protected]>
Signed-off-by: Matias Molinas <[email protected]>
  • Loading branch information
Tomas2D authored and matiasmolinas committed Dec 7, 2024
1 parent 6ccb3f1 commit 9980b5f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 27 deletions.
3 changes: 1 addition & 2 deletions src/agents/bee/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import {
} from "@/agents/bee/types.js";
import { GetRunContext } from "@/context.js";
import { assign } from "@/internals/helpers/object.js";
import { BeeAssistantPrompt } from "@/agents/bee/prompts.js";
import * as R from "remeda";
import { BaseRunner } from "@/agents/bee/runners/base.js";
import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js";
Expand Down Expand Up @@ -130,7 +129,7 @@ export class BeeAgent extends BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions
await runner.memory.add(
BaseMessage.of({
role: Role.ASSISTANT,
text: (this.input.templates?.assistant ?? BeeAssistantPrompt).render({
text: runner.templates.assistant.render({
thought: [state.thought].filter(R.isTruthy),
toolName: [state.tool_name].filter(R.isTruthy),
toolInput: [state.tool_input].filter(R.isTruthy).map((call) => JSON.stringify(call)),
Expand Down
3 changes: 3 additions & 0 deletions src/agents/bee/runners/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import { Serializable } from "@/internals/serializable.js";
import {
BeeAgentRunIteration,
BeeAgentTemplates,
BeeCallbacks,
BeeIterationToolResult,
BeeMeta,
Expand Down Expand Up @@ -90,6 +91,8 @@ export abstract class BaseRunner extends Serializable {

abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>;

abstract get templates(): BeeAgentTemplates;

protected abstract initMemory(input: BeeRunInput): Promise<BaseMemory>;

createSnapshot() {
Expand Down
41 changes: 30 additions & 11 deletions src/agents/bee/runners/default/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
* limitations under the License.
*/
import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js";
import { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js";
import type {
BeeAgentRunIteration,
BeeAgentTemplates,
BeeParserInput,
BeeRunInput,
} from "@/agents/bee/types.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { AgentError } from "@/agents/base.js";
import {
BeeAssistantPrompt,
BeeSchemaErrorPrompt,
BeeSystemPrompt,
BeeToolErrorPrompt,
Expand Down Expand Up @@ -69,7 +75,7 @@ export class DefaultRunner extends BaseRunner {
await this.memory.add(
BaseMessage.of({
role: Role.ASSISTANT,
text: (this.input.templates?.schemaError ?? BeeSchemaErrorPrompt).render({}),
text: this.templates.schemaError.render({}),
meta: {
[tempMessageKey]: true,
},
Expand Down Expand Up @@ -151,10 +157,9 @@ export class DefaultRunner extends BaseRunner {
}),
);

const template = this.input.templates?.toolNotFoundError ?? BeeToolNotFoundPrompt;
return {
success: false,
output: template.render({
output: this.templates.toolNotFoundError.render({
tools: this.input.tools,
}),
};
Expand Down Expand Up @@ -202,8 +207,7 @@ export class DefaultRunner extends BaseRunner {
});

if (toolOutput.isEmpty()) {
const template = this.input.templates?.toolNoResultError ?? BeeToolNoResultsPrompt;
return { output: template.render({}), success: true };
return { output: this.templates.toolNoResultError.render({}), success: true };
}

return {
Expand All @@ -225,10 +229,9 @@ export class DefaultRunner extends BaseRunner {
if (error instanceof ToolInputValidationError) {
this.failedAttemptsCounter.use(error);

const template = this.input.templates?.toolInputError ?? BeeToolInputErrorPrompt;
return {
success: false,
output: template.render({
output: this.templates.toolInputError.render({
reason: error.toString(),
}),
};
Expand All @@ -237,10 +240,9 @@ export class DefaultRunner extends BaseRunner {
if (error instanceof ToolError) {
this.failedAttemptsCounter.use(error);

const template = this.input.templates?.toolError ?? BeeToolErrorPrompt;
return {
success: false,
output: template.render({
output: this.templates.toolError.render({
reason: error.explain(),
}),
};
Expand Down Expand Up @@ -288,7 +290,7 @@ export class DefaultRunner extends BaseRunner {
message: async () =>
BaseMessage.of({
role: Role.SYSTEM,
text: (this.input.templates?.system ?? BeeSystemPrompt).render({
text: this.templates.system.render({
tools: await self.system.variables.tools(),
instructions: undefined,
}),
Expand Down Expand Up @@ -359,6 +361,23 @@ export class DefaultRunner extends BaseRunner {
return memory;
}

@Cache({ enumerable: false })
get templates(): BeeAgentTemplates {
const customTemplates = this.input.templates ?? {};

return {
system: customTemplates.system ?? BeeSystemPrompt,
assistant: customTemplates.assistant ?? BeeAssistantPrompt,
user: customTemplates.user ?? BeeUserPrompt,
userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt,
toolError: customTemplates.toolError ?? BeeToolErrorPrompt,
toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt,
toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt,
toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt,
schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt,
};
}

protected createParser(tools: AnyTool[]) {
const parserRegex = isEmpty(tools)
? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`)
Expand Down
34 changes: 20 additions & 14 deletions src/agents/bee/runners/granite/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,28 @@ import type { AnyTool } from "@/tools/base.js";
import { isEmpty } from "remeda";
import { DefaultRunner } from "@/agents/bee/runners/default/runner.js";
import { BaseMemory } from "@/memory/base.js";
import type { BeeParserInput, BeeRunInput, BeeRunOptions } from "@/agents/bee/types.js";
import type {
BeeAgentTemplates,
BeeParserInput,
BeeRunInput,
BeeRunOptions,
} from "@/agents/bee/types.js";
import { BeeAgent, BeeInput } from "@/agents/bee/agent.js";
import type { GetRunContext } from "@/context.js";
import {
GraniteBeeAssistantPrompt,
GraniteBeeSchemaErrorPrompt,
GraniteBeeSystemPrompt,
} from "@/agents/bee/runners/granite/prompts.js";
import { Cache } from "@/cache/decoratorCache.js";

export class GraniteRunner extends DefaultRunner {
static {
this.register();
}

constructor(input: BeeInput, options: BeeRunOptions, run: GetRunContext<BeeAgent>) {
super(
{
...input,
templates: {
...input.templates,
system: input.templates?.system ?? GraniteBeeSystemPrompt,
assistant: input.templates?.assistant ?? GraniteBeeAssistantPrompt,
schemaError: input.templates?.schemaError ?? GraniteBeeSchemaErrorPrompt,
},
},
options,
run,
);
super(input, options, run);

run.emitter.on(
"update",
Expand Down Expand Up @@ -83,6 +77,18 @@ export class GraniteRunner extends DefaultRunner {
return memory;
}

@Cache({ enumerable: false })
get templates(): BeeAgentTemplates {
const customTemplates = this.input.templates ?? {};

return {
...super.templates,
system: customTemplates.system ?? GraniteBeeSystemPrompt,
assistant: customTemplates.assistant ?? GraniteBeeAssistantPrompt,
schemaError: customTemplates.schemaError ?? GraniteBeeSchemaErrorPrompt,
};
}

protected createParser(tools: AnyTool[]) {
const { parser } = super.createParser(tools);

Expand Down

0 comments on commit 9980b5f

Please sign in to comment.