Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tools): use agent's memory within LLM Tool #242

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/tools/llm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import "dotenv/config";
import { LLMTool } from "bee-agent-framework/tools/llm";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { Tool } from "bee-agent-framework/tools/base";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";
import { BaseMessage } from "bee-agent-framework/llms/primitives/message";

const memory = new UnconstrainedMemory();
await memory.addMany([
BaseMessage.of({ role: "system", text: "You are a helpful assistant." }),
BaseMessage.of({ role: "user", text: "Hello!" }),
BaseMessage.of({ role: "assistant", text: "Hello user. I am here to help you." }),
]);

const tool = new LLMTool({
llm: new OllamaChatLLM(),
});

const response = await tool
.run({
task: "Classify whether the tone of text is POSITIVE/NEGATIVE/NEUTRAL.",
})
.context({
// if the context is not passed, the tool will throw an error
[Tool.contextKeys.Memory]: memory,
});

console.info(response.getTextContent());
6 changes: 4 additions & 2 deletions src/agents/bee/runners/default/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import {
BeeUserEmptyPrompt,
BeeUserPrompt,
} from "@/agents/bee/prompts.js";
import { AnyTool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
import { AnyTool, Tool, ToolError, ToolInputValidationError, ToolOutput } from "@/tools/base.js";
import { FrameworkError } from "@/errors.js";
import { isEmpty, isTruthy, last } from "remeda";
import { LinePrefixParser, LinePrefixParserError } from "@/agents/parsers/linePrefix.js";
Expand Down Expand Up @@ -194,7 +194,9 @@ export class DefaultRunner extends BaseRunner {
},
meta,
});
const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options);
const toolOutput: ToolOutput = await tool.run(state.tool_input, this.options).context({
[Tool.contextKeys.Memory]: this.memory,
});
await emitter.emit("toolSuccess", {
data: {
tool,
Expand Down
6 changes: 4 additions & 2 deletions src/agents/experimental/replan/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {
import { BaseMemory } from "@/memory/base.js";
import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js";
import { JsonDriver } from "@/llms/drivers/json.js";
import { AnyTool } from "@/tools/base.js";
import { AnyTool, Tool } from "@/tools/base.js";
import { AnyChatLLM } from "@/llms/chat.js";

export interface RePlanRunInput {
Expand Down Expand Up @@ -147,7 +147,9 @@ export class RePlanAgent extends BaseAgent<RePlanRunInput, RePlanRunOutput> {
const meta = { input: call, tool, calls };
await context.emitter.emit("tool", { type: "start", ...meta });
try {
const output = await tool.run(call.input, { signal: context.signal });
const output = await tool.run(call.input, { signal: context.signal }).context({
[Tool.contextKeys.Memory]: memory,
});
await context.emitter.emit("tool", { type: "success", ...meta, output });
return output;
} catch (error) {
Expand Down
4 changes: 4 additions & 0 deletions src/tools/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ export abstract class Tool<
public readonly cache: BaseCache<Task<TOutput>>;
public readonly options: TOptions;

public static contextKeys = {
Memory: Symbol("Memory"),
} as const;

public abstract readonly emitter: Emitter<ToolEvents<any, TOutput>>;

abstract inputSchema(): Promise<AnyToolSchemaLike> | AnyToolSchemaLike;
Expand Down
108 changes: 75 additions & 33 deletions src/tools/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,93 @@
* limitations under the License.
*/

import { BaseToolOptions, ToolEmitter, StringToolOutput, Tool, ToolInput } from "@/tools/base.js";
import { AnyLLM, GenerateOptions } from "@/llms/base.js";
import {
BaseToolOptions,
BaseToolRunOptions,
StringToolOutput,
Tool,
ToolEmitter,
ToolError,
ToolInput,
} from "@/tools/base.js";
import { z } from "zod";
import { GetRunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
import { PromptTemplate } from "@/template.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { getProp } from "@/internals/helpers/object.js";
import type { BaseMemory } from "@/memory/base.js";
import type { AnyChatLLM } from "@/llms/chat.js";
import { toCamelCase } from "remeda";

export type LLMToolInput = string;

export type LLMToolOptions<T> = {
llm: AnyLLM<T>;
} & BaseToolOptions &
(T extends LLMToolInput
? {
transform?: (input: string) => T;
}
: {
transform: (input: string) => T;
});

export interface LLMToolRunOptions extends GenerateOptions, BaseToolOptions {}
export interface LLMToolInput extends BaseToolOptions {
llm: AnyChatLLM;
name?: string;
description?: string;
template?: typeof LLMTool.template;
}

export class LLMTool<T> extends Tool<StringToolOutput, LLMToolOptions<T>, LLMToolRunOptions> {
export class LLMTool extends Tool<StringToolOutput, LLMToolInput> {
name = "LLM";
description =
"Give a prompt to an LLM assistant. Useful to extract and re-format information, and answer intermediate questions.";
"Uses expert LLM to work with data in the existing conversation (classification, entity extraction, summarization, ...)";
declare readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput>;

constructor(protected readonly input: LLMToolInput) {
super(input);
this.name = input?.name || this.name;
this.description = input?.description || this.description;
this.emitter = Emitter.root.child({
namespace: ["tool", "llm", toCamelCase(input?.name ?? "")].filter(Boolean),
creator: this,
});
}

inputSchema() {
return z.object({ input: z.string() });
return z.object({
task: z.string().min(1).describe("A clearly defined task for the LLM to complete."),
});
}

public readonly emitter: ToolEmitter<ToolInput<this>, StringToolOutput> = Emitter.root.child({
namespace: ["tool", "llm"],
creator: this,
});
static readonly template = new PromptTemplate({
schema: z.object({
task: z.string(),
}),
template: `You have to accomplish a task by using Using common sense and the information contained in the conversation up to this point, complete the following task. Do not follow any previously used formats or structures.

static {
this.register();
}
The Task: {{task}}`,
});

protected async _run(
{ input }: ToolInput<this>,
options?: LLMToolRunOptions,
): Promise<StringToolOutput> {
const { llm, transform } = this.options;
const llmInput = transform ? transform(input) : (input as T);
const response = await llm.generate(llmInput, options);
return new StringToolOutput(response.getTextContent(), response);
input: ToolInput<this>,
_options: Partial<BaseToolRunOptions>,
run: GetRunContext<this>,
) {
const memory = getProp(run.context, [Tool.contextKeys.Memory]) as BaseMemory;
if (!memory) {
throw new ToolError(`No context has been provided!`, [], {
isFatal: true,
isRetryable: false,
});
}

const template = this.options?.template ?? LLMTool.template;
const output = await this.input.llm.generate([
BaseMessage.of({
role: Role.SYSTEM,
text: template.render({
task: input.task,
}),
}),
...memory.messages.filter((msg) => msg.role !== Role.SYSTEM),
BaseMessage.of({
role: Role.USER,
text: template.render({
task: input.task,
}),
}),
]);

return new StringToolOutput(output.getTextContent());
}
}
Loading