Skip to content

Commit

Permalink
feat(tools)!: make run options always partial with fallback to an emp…
Browse files Browse the repository at this point in the history
…ty object

The "options" parameter in Tool's run method now fallbacks to an empty object to allow overrides.

Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed Dec 3, 2024
1 parent 6681c78 commit ff65e0c
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ _Source: [examples/tools/custom/openLibrary.ts](/examples/tools/custom/openLibra
<!-- eslint-skip -->

```ts
protected async _run(input: ToolInput<this>, options: BaseToolRunOptions | undefined, run: RunContext<this>) {
protected async _run(input: ToolInput<this>, options: Partial<BaseToolRunOptions>, run: RunContext<this>) {
// insert custom code here
// MUST: return an instance of the output type specified in the tool class definition
// MAY: throw an instance of ToolError upon unrecoverable error conditions encountered by the tool
Expand Down
11 changes: 5 additions & 6 deletions examples/tools/custom/extending.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
DuckDuckGoSearchTool,
DuckDuckGoSearchToolSearchType as SafeSearchType,
} from "bee-agent-framework/tools/search/duckDuckGoSearch";
import { setProp } from "bee-agent-framework/internals/helpers/object";

const searchTool = new DuckDuckGoSearchTool();

Expand All @@ -13,11 +12,11 @@ const customSearchTool = searchTool.extend(
safeSearch: z.boolean().default(true),
}),
(input, options) => {
setProp(
options,
["search", "safeSearch"],
input.safeSearch ? SafeSearchType.STRICT : SafeSearchType.OFF,
);
if (!options.search) {
options.search = {};
}
options.search.safeSearch = input.safeSearch ? SafeSearchType.STRICT : SafeSearchType.OFF;

return { query: input.query };
},
);
Expand Down
2 changes: 1 addition & 1 deletion src/tools/arxiv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ export class ArXivTool extends Tool<ArXivToolOutput, ToolOptions, ToolRunOptions

protected async _run(
input: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<this>,
) {
const params = this._prepareParams(input);
Expand Down
8 changes: 4 additions & 4 deletions src/tools/base.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ describe("Base Tool", () => {

protected async _run(
{ query }: ToolInput<this>,
options?: BaseToolRunOptions,
options: Partial<BaseToolRunOptions>,
): Promise<StringToolOutput> {
const result = await fn(query, options);
return new StringToolOutput(result);
Expand Down Expand Up @@ -89,7 +89,7 @@ describe("Base Tool", () => {
handler.mockResolvedValue(output);
await expect(tool.run({ query }, runOptions)).resolves.toBeTruthy();
expect(handler).toBeCalledTimes(1);
expect(handler).toBeCalledWith(query, runOptions);
expect(handler).toBeCalledWith(query, runOptions ?? {});
},
);

Expand Down Expand Up @@ -548,11 +548,11 @@ describe("Base Tool", () => {
"event": "tool.dummy.run.start",
},
{
"data": "{"input":{"query":"Hello!"}}",
"data": "{"input":{"query":"Hello!"},"options":{}}",
"event": "tool.dummy.start",
},
{
"data": "{"output":{"result":"Hey!"},"input":{"query":"Hello!"}}",
"data": "{"output":{"result":"Hey!"},"input":{"query":"Hello!"},"options":{}}",
"event": "tool.dummy.success",
},
{
Expand Down
15 changes: 8 additions & 7 deletions src/tools/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,9 @@ export abstract class Tool<
}
}

run(input: ToolInputRaw<this>, options?: TRunOptions) {
run(input: ToolInputRaw<this>, options: Partial<TRunOptions> = {}) {
input = shallowCopy(input);
options = shallowCopy(options);

return RunContext.enter(
this,
Expand Down Expand Up @@ -273,7 +274,7 @@ export abstract class Tool<

protected async _runCached(
input: ToolInput<this>,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: GetRunContext<this>,
): Promise<TOutput> {
const key = ObjectHashKeyFn({
Expand Down Expand Up @@ -303,7 +304,7 @@ export abstract class Tool<

protected abstract _run(
arg: ToolInput<this>,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: GetRunContext<typeof this>,
): Promise<TOutput>;

Expand Down Expand Up @@ -397,7 +398,7 @@ export abstract class Tool<
mapper: (
input: ToolInputRaw<S>,
output: TOutput,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: RunContext<
DynamicTool<TOutput, ZodSchema<ToolInput<S>>, TOptions, TRunOptions, ToolInput<S>>
>,
Expand All @@ -421,7 +422,7 @@ export abstract class Tool<
schema: TS,
mapper: (
input: z.output<TS>,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: RunContext<DynamicTool<TOutput, TS, TOptions, TRunOptions, z.output<TS>>>,
) => ToolInputRaw<S>,
overrides: {
Expand Down Expand Up @@ -471,7 +472,7 @@ export class DynamicTool<
inputSchema: TInputSchema;
handler: (
input: TInput,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: GetRunContext<DynamicTool<TOutput, TInputSchema, TOptions, TRunOptions, TInput>>,
) => Promise<TOutput>;
options?: TOptions;
Expand Down Expand Up @@ -507,7 +508,7 @@ export class DynamicTool<

protected _run(
arg: TInput,
options: TRunOptions | undefined,
options: Partial<TRunOptions>,
run: GetRunContext<DynamicTool<TOutput, TInputSchema, TOptions, TRunOptions, TInput>>,
): Promise<TOutput> {
return this.handler(arg, options, run);
Expand Down
2 changes: 1 addition & 1 deletion src/tools/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ export class CustomTool extends Tool<StringToolOutput, CustomToolOptions> {

protected async _run(
input: any,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<typeof this>,
) {
const { response } = await this.client.executeCustomTool(
Expand Down
2 changes: 1 addition & 1 deletion src/tools/database/elasticsearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ export class ElasticSearchTool extends Tool<

protected async _run(
input: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<this>,
): Promise<JSONToolOutput<any>> {
if (input.action === ElasticSearchAction.ListIndices) {
Expand Down
2 changes: 1 addition & 1 deletion src/tools/database/milvus.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export class MilvusDatabaseTool extends Tool<

protected async _run(
input: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
): Promise<JSONToolOutput<any>> {
switch (input.action) {
case MilvusAction.ListCollections: {
Expand Down
2 changes: 1 addition & 1 deletion src/tools/python/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {

protected async _run(
input: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<this>,
) {
const inputFiles = await pipe(
Expand Down
2 changes: 1 addition & 1 deletion src/tools/similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export class SimilarityTool<TProviderOptions> extends Tool<

protected async _run(
{ query, documents }: ToolInput<this>,
options: SimilarityToolRunOptions<TProviderOptions> | undefined,
options: Partial<SimilarityToolRunOptions<TProviderOptions>>,
run: RunContext<this>,
) {
return pipe(
Expand Down
2 changes: 1 addition & 1 deletion src/tools/weather/openMeteo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export class OpenMeteoTool extends Tool<OpenMeteoToolOutput, ToolOptions, ToolRu

protected async _run(
{ location, start_date: startDate, end_date: endDate, ...input }: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<this>,
) {
const { apiKey } = this.options;
Expand Down
2 changes: 1 addition & 1 deletion src/tools/web/webCrawler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ export class WebCrawlerTool extends Tool<WebCrawlerToolOutput, WebsiteCrawlerToo

protected async _run(
{ url }: ToolInput<this>,
_options: BaseToolRunOptions | undefined,
_options: Partial<BaseToolRunOptions>,
run: RunContext<this>,
) {
const response = await this.client(url, {
Expand Down

0 comments on commit ff65e0c

Please sign in to comment.