Skip to content

Commit

Permalink
feat(agent): inject metadata to messages
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Sep 11, 2024
1 parent 5726691 commit a10d432
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
19 changes: 18 additions & 1 deletion src/agents/bee/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import { PromptTemplate } from "@/template.js";
import { BaseMessageMeta } from "@/llms/primitives/message.js";
import { z } from "zod";

export const BeeSystemPrompt = new PromptTemplate({
Expand Down Expand Up @@ -114,8 +115,24 @@ export const BeeAssistantPrompt = new PromptTemplate({
export const BeeUserPrompt = new PromptTemplate({
schema: z.object({
input: z.string(),
meta: z
.object({
createdAt: z.string().datetime().optional(),
})
.optional(),
}),
template: `Question: {{input}}`,
functions: {
formatMeta: function () {
const meta = this.meta as BaseMessageMeta;
if (!meta) {
return "";
}

const parts = [meta.createdAt && `Created At: ${meta.createdAt}`].filter(Boolean).join("\n");
return parts ? `\n\n${parts}` : parts;
},
},
template: `Question: {{input}}{{formatMeta}}`,
});

export const BeeUserEmptyPrompt = new PromptTemplate({
Expand Down
18 changes: 15 additions & 3 deletions src/agents/bee/runner.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ vi.mock("@/memory/tokenMemory.js", async () => {
});

describe("Bee Agent Runner", () => {
beforeEach(() => {
vi.useRealTimers();
});

it("Handles different prompt input source", async () => {
vi.useFakeTimers();
vi.setSystemTime(new Date("2024-09-10T19:51:46.954Z"));

const createMemory = async () => {
const memory = new UnconstrainedMemory();
await memory.addMany([
Expand Down Expand Up @@ -62,13 +69,18 @@ describe("Bee Agent Runner", () => {
const instance = await createInstance(memory, prompt);

const memory2 = await createMemory();
await memory2.add(BaseMessage.of({ role: Role.USER, text: prompt }));
await memory2.add(
BaseMessage.of({ role: Role.USER, text: prompt, meta: { createdAt: new Date() } }),
);
const instance2 = await createInstance(memory2, null);
expect(instance.memory.messages).toEqual(instance2.memory.messages);
});

it.each([
BeeUserPrompt,
BeeUserPrompt.fork((old) => ({
...old,
functions: { ...old.functions, formatMeta: () => "" },
})),
BeeUserPrompt.fork((old) => ({ ...old, template: `{{input}}` })),
BeeUserPrompt.fork((old) => ({ ...old, template: `User: {{input}}` })),
BeeUserPrompt.fork((old) => ({ ...old, template: `` })),
Expand Down Expand Up @@ -114,7 +126,7 @@ describe("Bee Agent Runner", () => {
],
instance.memory.messages.filter((msg) => msg.role === Role.USER),
)) {
expect(template.render({ input: a.text })).toStrictEqual(b.text);
expect(template.render({ input: a.text, meta: undefined })).toStrictEqual(b.text);
}
});
});
24 changes: 22 additions & 2 deletions src/agents/bee/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ export class BeeAgentRunner {
const isEmpty = !message.text.trim();
const text = isEmpty
? (input.templates?.userEmpty ?? BeeUserEmptyPrompt).render({})
: (input.templates?.user ?? BeeUserPrompt).render({ input: message.text });
: (input.templates?.user ?? BeeUserPrompt).render({
input: message.text,
meta: {
...message?.meta,
createdAt: message?.meta?.createdAt?.toISOString?.(),
},
});

return BaseMessage.of({
role: Role.USER,
Expand Down Expand Up @@ -118,12 +124,26 @@ export class BeeAgentRunner {
tool_names: input.tools.map((tool) => tool.name).join(","),
instructions: undefined,
}),
meta: {
createdAt: new Date(),
},
}),
...input.memory.messages.map(transformMessage),
]);

if (prompt !== null || input.memory.isEmpty()) {
await memory.add(transformMessage(BaseMessage.of({ role: Role.USER, text: prompt ?? "" })));
await memory.add(
transformMessage(
BaseMessage.of({
role: Role.USER,
text: prompt ?? "",
meta: {
// TODO: createdAt
createdAt: new Date(),
},
}),
),
);
}

return new BeeAgentRunner(input, options, memory);
Expand Down
9 changes: 7 additions & 2 deletions src/llms/primitives/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,22 @@ export const Role = {

export type RoleType = EnumLowerCaseValue<typeof Role> | string;

export interface BaseMessageMeta {
[key: string]: any;
createdAt?: Date;
}

export interface BaseMessageInput {
role: RoleType;
text: string;
meta?: Record<string, any>;
meta?: BaseMessageMeta;
}

export class BaseMessage extends Serializable {
constructor(
public readonly role: RoleType,
public readonly text: string,
public readonly meta?: Record<string, any>,
public readonly meta?: BaseMessageMeta,
) {
super();
}
Expand Down

0 comments on commit a10d432

Please sign in to comment.