Skip to content

Commit

Permalink
feat: extend workflows behaviour
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed Jan 8, 2025
1 parent a3811d0 commit 7bd09f1
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 46 deletions.
15 changes: 7 additions & 8 deletions examples/flows/multiAgents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,12 @@ workflow.addAgent({
tools: [new WikipediaTool()],
llm: BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"),
});

const llm = BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct");
llm.emitter.on("start", (x) => {
console.info(llm.messagesToPrompt(x.input));
});
workflow.addAgent({
name: "Solver",
instructions:
"Your task is to provide the most useful final answer based on the assistants' responses which all are relevant. Ignore those where assistant do not know.",
tools: [],
llm,
llm: BAMChatLLM.fromPreset("meta-llama/llama-3-1-70b-instruct"),
});

const reader = createConsoleReader();
Expand All @@ -61,8 +56,12 @@ for await (const { prompt } of reader) {
);

const { result } = await workflow.run(memory.messages).observe((emitter) => {
emitter.on("success", (data) => {
reader.write(`-> ${data.step}`, data.response?.update?.finalAnswer ?? "-");
//emitter.on("success", (data) => {
// reader.write(`-> ${data.step}`, data.response?.update?.finalAnswer ?? "-");
//});
emitter.match("*.*", (_, event) => {
console.info(event.path);
// TODO
});
});
await memory.addMany(result.newMessages);
Expand Down
1 change: 1 addition & 0 deletions src/agents/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export abstract class BaseAgent<
this.emitter.destroy();
}

public abstract set memory(memory: BaseMemory);
public abstract get memory(): BaseMemory;

public get meta(): AgentMeta {
Expand Down
17 changes: 12 additions & 5 deletions src/agents/bee/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { AgentMeta } from "@/agents/types.js";
import { Emitter } from "@/emitter/emitter.js";
import {
BeeAgentExecutionConfig,
BeeAgentTemplates,
BeeCallbacks,
BeeRunInput,
Expand All @@ -42,6 +43,7 @@ export interface BeeInput {
memory: BaseMemory;
meta?: Omit<AgentMeta, "tools">;
templates?: Partial<BeeAgentTemplates>;
execution?: BeeAgentExecutionConfig;
}

export class BeeAgent extends BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions> {
Expand Down Expand Up @@ -71,6 +73,10 @@ export class BeeAgent extends BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions
this.register();
}

set memory(memory: BaseMemory) {
this.input.memory = memory;
}

get memory() {
return this.input.memory;
}
Expand Down Expand Up @@ -105,11 +111,12 @@ export class BeeAgent extends BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions
this.input,
{
...options,
execution: options?.execution ?? {
maxRetriesPerStep: 3,
totalMaxRetries: 20,
maxIterations: 10,
},
execution: this.input.execution ??
options?.execution ?? {
maxRetriesPerStep: 3,
totalMaxRetries: 20,
maxIterations: 10,
},
},
run,
);
Expand Down
4 changes: 4 additions & 0 deletions src/agents/experimental/replan/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,8 @@ export class RePlanAgent extends BaseAgent<RePlanRunInput, RePlanRunOutput> {
get memory() {
return this.input.memory;
}

set memory(memory: BaseMemory) {
this.input.memory = memory;
}
}
4 changes: 4 additions & 0 deletions src/agents/experimental/streamlit/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ export class StreamlitAgent extends BaseAgent<StreamlitRunInput, StreamlitRunOut
};
}

set memory(memory: BaseMemory) {
this.input.memory = memory;
}

public get memory() {
return this.input.memory;
}
Expand Down
8 changes: 7 additions & 1 deletion src/emitter/emitter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export class Emitter<T = Record<keyof any, Callback<unknown>>> extends Serializa
public readonly creator?: object;
public readonly context: object;
public readonly trace?: EventTrace;
protected readonly cleanups: CleanupFn[] = [];

constructor(input: EmitterInput = {}) {
super();
Expand Down Expand Up @@ -93,7 +94,9 @@ export class Emitter<T = Record<keyof any, Callback<unknown>>> extends Serializa
: this.namespace.slice(),
});

child.pipe(this);
const cleanup = child.pipe(this);
this.cleanups.push(cleanup);

return child;
}

Expand All @@ -116,6 +119,8 @@ export class Emitter<T = Record<keyof any, Callback<unknown>>> extends Serializa

destroy() {
this.listeners.clear();
this.cleanups.forEach((child) => child());
this.cleanups.length = 0;
}

reset() {
Expand Down Expand Up @@ -228,6 +233,7 @@ export class Emitter<T = Record<keyof any, Callback<unknown>>> extends Serializa
context: this.context,
trace: this.trace,
listeners: Array.from(this.listeners).map(pick(["raw", "options", "callback"])),
cleanups: this.cleanups,
};
}

Expand Down
100 changes: 69 additions & 31 deletions src/experimental/workflows/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,32 @@
*/

import { BeeAgent } from "@/agents/bee/agent.js";
import { Workflow } from "@/experimental/workflows/workflow.js";
import { Workflow, WorkflowRunOptions } from "@/experimental/workflows/workflow.js";
import { BaseMessage } from "@/llms/primitives/message.js";
import { AnyTool } from "@/tools/base.js";
import { AnyChatLLM } from "@/llms/chat.js";
import { BeeSystemPrompt } from "@/agents/bee/prompts.js";
import { ReadOnlyMemory } from "@/memory/base.js";
import { BaseMemory, ReadOnlyMemory } from "@/memory/base.js";
import { z } from "zod";
import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js";
import { BaseAgent } from "@/agents/base.js";
import {
BeeAgentExecutionConfig,
BeeRunInput,
BeeRunOptions,
BeeRunOutput,
} from "@/agents/bee/types.js";
import { isFunction, randomString } from "remeda";

type AgentInstance = BaseAgent<BeeRunInput, BeeRunOutput, BeeRunOptions>;
type AgentFactory = (memory: ReadOnlyMemory) => AgentInstance | Promise<AgentInstance>;
interface AgentFactoryInput {
name: string;
llm: AnyChatLLM;
instructions?: string;
tools?: AnyTool[];
execution?: BeeAgentExecutionConfig;
}

export class AgentWorkflow {
protected readonly workflow;
Expand All @@ -42,39 +60,64 @@ export class AgentWorkflow {
});
}

addAgent(agent: { name: string; instructions?: string; tools: AnyTool[]; llm: AnyChatLLM }) {
return this.addRawAgent(agent.name, (memory) => {
return new BeeAgent({
llm: agent.llm,
tools: agent.tools,
memory,
meta: {
name: agent.name,
description: agent.instructions ?? "",
},
templates: {
system: BeeSystemPrompt.fork((config) => ({
...config,
defaults: {
...config.defaults,
instructions: agent.instructions || config.defaults.instructions,
},
})),
},
});
});
run(messages: BaseMessage[], options: WorkflowRunOptions<string> = {}) {
return this.workflow.run(
{
messages,
},
options,
);
}

addAgent(agent: AgentInstance | AgentFactory | AgentFactoryInput) {
if (agent instanceof BaseAgent) {
const clone = agent.clone();
const factory: AgentFactory = (memory) => {
clone.memory = memory;
return clone;
};
return this._add(clone.meta.name, factory);
}

const name = agent.name || `Agent${randomString(4)}`;
return this._add(name, isFunction(agent) ? agent : this._createFactory(agent));
}

delAgent(name: string) {
return this.workflow.delStep(name);
}

addRawAgent(name: string, factory: (memory: ReadOnlyMemory) => BeeAgent) {
protected _createFactory(input: AgentFactoryInput): AgentFactory {
return (memory: BaseMemory) =>
new BeeAgent({
llm: input.llm,
tools: input.tools ?? [],
memory,
meta: {
name: input.name,
description: input.instructions ?? "",
},
execution: input.execution,
...(input.instructions && {
templates: {
system: BeeSystemPrompt.fork((config) => ({
...config,
defaults: {
...config.defaults,
instructions: input.instructions || config.defaults.instructions,
},
})),
},
}),
});
}

protected _add(name: string, factory: AgentFactory) {
this.workflow.addStep(name, async (state, ctx) => {
const memory = new UnconstrainedMemory();
await memory.addMany([...state.messages, ...state.newMessages]);

const agent = factory(memory.asReadOnly());
const agent = await factory(memory.asReadOnly());
const { result } = await agent.run({ prompt: null }, { signal: ctx.signal });

return {
Expand All @@ -92,11 +135,6 @@ export class AgentWorkflow {
},
};
});
}

run(messages: BaseMessage[]) {
return this.workflow.run({
messages,
});
return this;
}
}
2 changes: 1 addition & 1 deletion src/experimental/workflows/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ export class Workflow<
.parseAsync(run.state)
.catch((err) => {
throw new WorkflowError(
`Flow has ended but it's state does not adhere to the flow's output schema.`,
`Workflow has ended but it's state does not adhere to the workflow's output schema.`,
{ run: shallowCopy(run), errors: [err] },
);
});
Expand Down

0 comments on commit 7bd09f1

Please sign in to comment.