From 33f1db5971d269935878a4a83a484a44371c3cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Dvo=C5=99=C3=A1k?= Date: Mon, 18 Nov 2024 04:10:52 -0800 Subject: [PATCH] feat(tools): make tools composable (#169) Contributes to #166 --- examples/tools/custom/extending.ts | 32 +++++++++++++ examples/tools/custom/piping.ts | 42 +++++++++++++++++ src/tools/base.ts | 51 +++++++++++++++++++++ src/tools/similarity.ts | 72 ++++++++++++++++-------------- 4 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 examples/tools/custom/extending.ts create mode 100644 examples/tools/custom/piping.ts diff --git a/examples/tools/custom/extending.ts b/examples/tools/custom/extending.ts new file mode 100644 index 00000000..e7e339f3 --- /dev/null +++ b/examples/tools/custom/extending.ts @@ -0,0 +1,32 @@ +import { z } from "zod"; +import { DuckDuckGoSearchTool } from "bee-agent-framework/tools/search/duckDuckGoSearch"; +import { setProp } from "bee-agent-framework/internals/helpers/object"; +import { SafeSearchType } from "duck-duck-scrape/src/util.js"; + +const searchTool = new DuckDuckGoSearchTool(); + +const customSearchTool = searchTool.extend( + z.object({ + query: z.string(), + safeSearch: z.boolean().default(true), + }), + (input, options) => { + setProp( + options, + ["search", "safeSearch"], + input.safeSearch ? SafeSearchType.STRICT : SafeSearchType.OFF, + ); + return { query: input.query }; + }, +); + +const response = await customSearchTool.run( + { + query: "News in the world!", + safeSearch: true, + }, + { + signal: AbortSignal.timeout(10_000), + }, +); +console.info(response); diff --git a/examples/tools/custom/piping.ts b/examples/tools/custom/piping.ts new file mode 100644 index 00000000..f4b7dbac --- /dev/null +++ b/examples/tools/custom/piping.ts @@ -0,0 +1,42 @@ +import { WikipediaTool } from "bee-agent-framework/tools/search/wikipedia"; +import { SimilarityTool } from "bee-agent-framework/tools/similarity"; +import { splitString } from "bee-agent-framework/internals/helpers/string"; +import { z } from "zod"; + +const wikipedia = new WikipediaTool(); +const similarity = new SimilarityTool({ + maxResults: 5, + provider: async (input) => + input.documents.map((document) => ({ + score: document.text + .toLowerCase() + .split(" ") + .reduce((acc, word) => acc + (input.query.toLowerCase().includes(word) ? 1 : 0), 0), + })), +}); + +const wikipediaWithSimilarity = wikipedia + .extend( + z.object({ + page: z.string().describe("Wikipedia page"), + query: z.string().describe("Search query"), + }), + (newInput) => ({ query: newInput.page }), + ) + .pipe(similarity, (input, output) => ({ + query: input.query, + documents: output.results.flatMap((document) => + Array.from(splitString(document.fields.markdown as string, { size: 1000, overlap: 50 })).map( + (chunk) => ({ + text: chunk, + source: document, + }), + ), + ), + })); + +const response = await wikipediaWithSimilarity.run({ + page: "JavaScript", + query: "engine", +}); +console.info(response); diff --git a/src/tools/base.ts b/src/tools/base.ts index ee6b1722..158a33fc 100644 --- a/src/tools/base.ts +++ b/src/tools/base.ts @@ -152,6 +152,7 @@ export interface ToolSnapshot; } +export type InferToolOutput = T extends Tool ? A : never; export type ToolInput = FromSchemaLike>>; export type ToolInputRaw = FromSchemaLikeRaw< Awaited> @@ -383,6 +384,56 @@ export abstract class Tool< loadSnapshot(snapshot: ToolSnapshot): void { Object.assign(this, snapshot); } + + pipe( + this: S, + tool: T, + mapper: ( + input: ToolInputRaw, + output: TOutput, + options: TRunOptions | undefined, + run: RunContext< + DynamicTool>, TOptions, TRunOptions, ToolInput> + >, + ) => ToolInputRaw, + ) { + return new DynamicTool>, TOptions, TRunOptions, ToolInput>({ + name: this.name, + description: this.description, + options: this.options, + inputSchema: this.inputSchema() as ZodSchema>, + handler: async (input: ToolInputRaw, options, run): Promise => { + const selfOutput = await this.run(input, options); + const wrappedInput = mapper(input, selfOutput, options, run); + return await tool.run(wrappedInput); + }, + } as const); + } + + extend( + this: S, + schema: TS, + mapper: ( + input: z.output, + options: TRunOptions | undefined, + run: RunContext>>, + ) => ToolInputRaw, + overrides: { + name?: string; + description?: string; + } = {}, + ) { + return new DynamicTool>({ + name: overrides?.name || this.name, + description: overrides?.name || this.description, + options: shallowCopy(this.options), + inputSchema: schema, + handler: async (input: ToolInputRaw, options, run): Promise => { + const wrappedInput = mapper(input, options, run); + return await this.run(wrappedInput, options); + }, + } as const); + } } export type AnyTool = Tool; diff --git a/src/tools/similarity.ts b/src/tools/similarity.ts index 7966c403..e6f50621 100644 --- a/src/tools/similarity.ts +++ b/src/tools/similarity.ts @@ -16,17 +16,26 @@ import { BaseToolOptions, BaseToolRunOptions, JSONToolOutput, Tool, ToolInput } from "./base.js"; import { string, z } from "zod"; -import * as R from "remeda"; +import { RunContext } from "@/context.js"; +import { map, pipe, prop, sortBy, take } from "remeda"; const documentSchema = z.object({ text: string() }).passthrough(); type Document = z.infer; +interface ProviderInput { + query: string; + documents: Document[]; +} + +type Provider = ( + input: ProviderInput, + options: TProviderOptions | undefined, + run: RunContext>, +) => Promise<{ score: number }[]>; + export interface SimilarityToolOptions extends BaseToolOptions { - provider: ( - input: { query: string; documents: Document[] }, - options?: TProviderOptions, - ) => Promise<{ score: number }[]>; + provider: Provider; maxResults?: number; } @@ -60,38 +69,33 @@ export class SimilarityTool extends Tool< } protected async _run( - input: ToolInput, - options?: SimilarityToolRunOptions, + { query, documents }: ToolInput, + options: SimilarityToolRunOptions | undefined, + run: RunContext, ) { - const { query, documents } = input; - - const results = await this.options.provider( - { - query, - documents, - }, - options?.provider, - ); - - const resultsWithDocumentIndices = results.map(({ score }, idx) => ({ - documentIndex: idx, - score, - })); - const sortedResultsWithDocumentIndices = R.sortBy(resultsWithDocumentIndices, [ - ({ score }) => score, - "desc", - ]); - const filteredResultsWithDocumentIndices = sortedResultsWithDocumentIndices.slice( - 0, - options?.maxResults ?? this.options.maxResults, - ); - - return new SimilarityToolOutput( - filteredResultsWithDocumentIndices.map(({ documentIndex, score }) => ({ - document: documents[documentIndex], - index: documentIndex, + return pipe( + await this.options.provider( + { + query, + documents, + }, + options?.provider, + run, + ), + map(({ score }, idx) => ({ + documentIndex: idx, score, })), + sortBy([prop("score"), "desc"]), + take(options?.maxResults ?? this.options.maxResults ?? Infinity), + (data) => + new SimilarityToolOutput( + data.map(({ documentIndex, score }) => ({ + document: documents[documentIndex], + index: documentIndex, + score, + })), + ), ); } }