-
Notifications
You must be signed in to change notification settings - Fork 145
/
runner.ts
121 lines (112 loc) · 3.73 KB
/
runner.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import type { AnyTool } from "@/tools/base.js";
import { isEmpty } from "remeda";
import { DefaultRunner } from "@/agents/bee/runners/default/runner.js";
import { BaseMemory } from "@/memory/base.js";
import type {
BeeAgentTemplates,
BeeParserInput,
BeeRunInput,
BeeRunOptions,
} from "@/agents/bee/types.js";
import { BeeAgent, BeeInput } from "@/agents/bee/agent.js";
import type { GetRunContext } from "@/context.js";
import {
GraniteBeeAssistantPrompt,
GraniteBeeSchemaErrorPrompt,
GraniteBeeSystemPrompt,
} from "@/agents/bee/runners/granite/prompts.js";
/**
 * Bee agent runner that swaps in the Granite prompt templates and the
 * Granite-specific message roles / parser prefixes on top of `DefaultRunner`.
 *
 * Differences from the base runner that are visible in this class:
 * - the `system`, `assistant` and `schemaError` templates default to the Granite prompts;
 * - every `tool_output` update is stored in memory as a `tool_response` message;
 * - when tools are configured, an `available_tools` message is inserted right after
 *   the system message;
 * - the parser uses the `Thought:` / `Tool Name:` / `Tool Input:` / `Final Answer:` prefixes.
 */
export class GraniteRunner extends DefaultRunner {
  static {
    // Invokes the inherited static `register()` once, when the class is evaluated.
    this.register();
  }

  /**
   * @param input - Agent input; caller-supplied templates win over the Granite defaults.
   * @param options - Run options, forwarded untouched to the base runner.
   * @param run - Run context whose emitter is used to observe `update` events.
   */
  constructor(input: BeeInput, options: BeeRunOptions, run: GetRunContext<BeeAgent>) {
    super(
      {
        ...input,
        templates: {
          ...input.templates,
          system: input.templates?.system ?? GraniteBeeSystemPrompt,
          assistant: input.templates?.assistant ?? GraniteBeeAssistantPrompt,
          schemaError: input.templates?.schemaError ?? GraniteBeeSchemaErrorPrompt,
        },
      },
      options,
      run,
    );

    // Mirror every tool output into memory as a dedicated `tool_response` message.
    run.emitter.on(
      "update",
      async ({ update, meta, memory }) => {
        if (update.key === "tool_output") {
          await memory.add(
            BaseMessage.of({
              role: "tool_response",
              text: update.value,
              meta: { success: meta.success },
            }),
          );
        }
      },
      {
        // NOTE(review): presumably makes the emitter await this handler before
        // continuing, so the message is stored before the next step — confirm.
        isBlocking: true,
      },
    );
  }

  /**
   * Builds the base memory and, when at least one tool is configured, inserts an
   * `available_tools` message containing the pretty-printed JSON tool list.
   *
   * @param input - Run input, forwarded to the base implementation.
   * @returns The memory instance produced by the base runner, possibly extended.
   */
  protected async initMemory(input: BeeRunInput): Promise<BaseMemory> {
    const memory = await super.initMemory(input);

    if (!isEmpty(this.input.tools)) {
      // Position right after the first system message; when there is none,
      // `findIndex` yields -1 and the message goes to index 0.
      const index = memory.messages.findIndex((msg) => msg.role === Role.SYSTEM) + 1;
      await memory.add(
        BaseMessage.of({
          role: "available_tools",
          text: JSON.stringify(await this.renderers.system.variables.tools(), null, 4),
        }),
        index,
      );
    }

    return memory;
  }

  /**
   * Templates of the base runner with the `system`, `assistant` and `schemaError`
   * entries resolved to the caller's template or, failing that, the Granite prompt.
   */
  get templates(): BeeAgentTemplates {
    return {
      ...super.templates,
      system: this.input.templates?.system ?? GraniteBeeSystemPrompt,
      assistant: this.input.templates?.assistant ?? GraniteBeeAssistantPrompt,
      schemaError: this.input.templates?.schemaError ?? GraniteBeeSchemaErrorPrompt,
    };
  }

  /**
   * Creates the output parser together with the regular expression describing
   * the expected shape of a model response.
   *
   * @param tools - Tools available for this run; their names become the allowed
   *   alternatives after `Tool Name: `.
   * @returns `parserRegex` for the whole response and a fork of the base `parser`
   *   whose nodes use the Granite prefixes.
   */
  protected createParser(tools: AnyTool[]) {
    const { parser } = super.createParser(tools);

    // Tool names are spliced into a regular expression source. Escape regex
    // metacharacters so that a name such as "search.v2" or "c++" is matched
    // literally instead of altering (or breaking) the pattern.
    const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
    const toolNames = tools.map((tool) => escapeRegExp(tool.name)).join("|");

    return {
      parserRegex: isEmpty(tools)
        ? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`)
        : new RegExp(
            `Thought: .+\\n(?:Final Answer: [\\s\\S]+|Tool Name: (${toolNames})\\nTool Input: \\{.*\\})`,
          ),
      parser: parser.fork<BeeParserInput>((nodes, options) => ({
        options,
        nodes: {
          ...nodes,
          thought: { ...nodes.thought, prefix: "Thought:" },
          tool_name: { ...nodes.tool_name, prefix: "Tool Name:" },
          // The tool input is marked as an end node with no successors.
          tool_input: { ...nodes.tool_input, prefix: "Tool Input:", isEnd: true, next: [] },
          final_answer: { ...nodes.final_answer, prefix: "Final Answer:" },
        },
      })),
    };
  }
}