Skip to content

Commit

Permalink
feat(tools): make tools composable (#169)
Browse files Browse the repository at this point in the history
Contributes to #166
  • Loading branch information
Tomas2D authored Nov 18, 2024
1 parent b333594 commit 33f1db5
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 34 deletions.
32 changes: 32 additions & 0 deletions examples/tools/custom/extending.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { z } from "zod";
import { DuckDuckGoSearchTool } from "bee-agent-framework/tools/search/duckDuckGoSearch";
import { setProp } from "bee-agent-framework/internals/helpers/object";
import { SafeSearchType } from "duck-duck-scrape/src/util.js";

const searchTool = new DuckDuckGoSearchTool();

const customSearchTool = searchTool.extend(
z.object({
query: z.string(),
safeSearch: z.boolean().default(true),
}),
(input, options) => {
setProp(
options,
["search", "safeSearch"],
input.safeSearch ? SafeSearchType.STRICT : SafeSearchType.OFF,
);
return { query: input.query };
},
);

const response = await customSearchTool.run(
{
query: "News in the world!",
safeSearch: true,
},
{
signal: AbortSignal.timeout(10_000),
},
);
console.info(response);
42 changes: 42 additions & 0 deletions examples/tools/custom/piping.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { WikipediaTool } from "bee-agent-framework/tools/search/wikipedia";
import { SimilarityTool } from "bee-agent-framework/tools/similarity";
import { splitString } from "bee-agent-framework/internals/helpers/string";
import { z } from "zod";

const wikipedia = new WikipediaTool();
const similarity = new SimilarityTool({
maxResults: 5,
provider: async (input) =>
input.documents.map((document) => ({
score: document.text
.toLowerCase()
.split(" ")
.reduce((acc, word) => acc + (input.query.toLowerCase().includes(word) ? 1 : 0), 0),
})),
});

const wikipediaWithSimilarity = wikipedia
.extend(
z.object({
page: z.string().describe("Wikipedia page"),
query: z.string().describe("Search query"),
}),
(newInput) => ({ query: newInput.page }),
)
.pipe(similarity, (input, output) => ({
query: input.query,
documents: output.results.flatMap((document) =>
Array.from(splitString(document.fields.markdown as string, { size: 1000, overlap: 50 })).map(
(chunk) => ({
text: chunk,
source: document,
}),
),
),
}));

const response = await wikipediaWithSimilarity.run({
page: "JavaScript",
query: "engine",
});
console.info(response);
51 changes: 51 additions & 0 deletions src/tools/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ export interface ToolSnapshot<TOutput extends ToolOutput, TOptions extends BaseT
emitter: Emitter<any>;
}

export type InferToolOutput<T extends AnyTool> = T extends Tool<infer A, any, any> ? A : never;
export type ToolInput<T extends AnyTool> = FromSchemaLike<Awaited<ReturnType<T["inputSchema"]>>>;
export type ToolInputRaw<T extends AnyTool> = FromSchemaLikeRaw<
Awaited<ReturnType<T["inputSchema"]>>
Expand Down Expand Up @@ -383,6 +384,56 @@ export abstract class Tool<
loadSnapshot(snapshot: ToolSnapshot<TOutput, TOptions>): void {
Object.assign(this, snapshot);
}

pipe<S extends AnyTool, T extends AnyTool>(
this: S,
tool: T,
mapper: (
input: ToolInputRaw<S>,
output: TOutput,
options: TRunOptions | undefined,
run: RunContext<
DynamicTool<TOutput, ZodSchema<ToolInput<S>>, TOptions, TRunOptions, ToolInput<S>>
>,
) => ToolInputRaw<typeof tool>,
) {
return new DynamicTool<TOutput, ZodSchema<ToolInput<S>>, TOptions, TRunOptions, ToolInput<S>>({
name: this.name,
description: this.description,
options: this.options,
inputSchema: this.inputSchema() as ZodSchema<ToolInput<S>>,
handler: async (input: ToolInputRaw<S>, options, run): Promise<TOutput> => {
const selfOutput = await this.run(input, options);
const wrappedInput = mapper(input, selfOutput, options, run);
return await tool.run(wrappedInput);
},
} as const);
}

extend<S extends AnyTool, TS extends ZodSchema>(
this: S,
schema: TS,
mapper: (
input: z.output<TS>,
options: TRunOptions | undefined,
run: RunContext<DynamicTool<TOutput, TS, TOptions, TRunOptions, z.output<TS>>>,
) => ToolInputRaw<S>,
overrides: {
name?: string;
description?: string;
} = {},
) {
return new DynamicTool<TOutput, TS, TOptions, TRunOptions, z.output<TS>>({
name: overrides?.name || this.name,
description: overrides?.name || this.description,
options: shallowCopy(this.options),
inputSchema: schema,
handler: async (input: ToolInputRaw<S>, options, run): Promise<TOutput> => {
const wrappedInput = mapper(input, options, run);
return await this.run(wrappedInput, options);
},
} as const);
}
}

export type AnyTool = Tool<any, any, any>;
Expand Down
72 changes: 38 additions & 34 deletions src/tools/similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,26 @@

import { BaseToolOptions, BaseToolRunOptions, JSONToolOutput, Tool, ToolInput } from "./base.js";
import { string, z } from "zod";
import * as R from "remeda";
import { RunContext } from "@/context.js";
import { map, pipe, prop, sortBy, take } from "remeda";

const documentSchema = z.object({ text: string() }).passthrough();

type Document = z.infer<typeof documentSchema>;

interface ProviderInput {
query: string;
documents: Document[];
}

type Provider<TProviderOptions> = (
input: ProviderInput,
options: TProviderOptions | undefined,
run: RunContext<SimilarityTool<TProviderOptions>>,
) => Promise<{ score: number }[]>;

export interface SimilarityToolOptions<TProviderOptions = unknown> extends BaseToolOptions {
provider: (
input: { query: string; documents: Document[] },
options?: TProviderOptions,
) => Promise<{ score: number }[]>;
provider: Provider<TProviderOptions>;
maxResults?: number;
}

Expand Down Expand Up @@ -60,38 +69,33 @@ export class SimilarityTool<TProviderOptions> extends Tool<
}

protected async _run(
input: ToolInput<this>,
options?: SimilarityToolRunOptions<TProviderOptions>,
{ query, documents }: ToolInput<this>,
options: SimilarityToolRunOptions<TProviderOptions> | undefined,
run: RunContext<this>,
) {
const { query, documents } = input;

const results = await this.options.provider(
{
query,
documents,
},
options?.provider,
);

const resultsWithDocumentIndices = results.map(({ score }, idx) => ({
documentIndex: idx,
score,
}));
const sortedResultsWithDocumentIndices = R.sortBy(resultsWithDocumentIndices, [
({ score }) => score,
"desc",
]);
const filteredResultsWithDocumentIndices = sortedResultsWithDocumentIndices.slice(
0,
options?.maxResults ?? this.options.maxResults,
);

return new SimilarityToolOutput(
filteredResultsWithDocumentIndices.map(({ documentIndex, score }) => ({
document: documents[documentIndex],
index: documentIndex,
return pipe(
await this.options.provider(
{
query,
documents,
},
options?.provider,
run,
),
map(({ score }, idx) => ({
documentIndex: idx,
score,
})),
sortBy([prop("score"), "desc"]),
take(options?.maxResults ?? this.options.maxResults ?? Infinity),
(data) =>
new SimilarityToolOutput(
data.map(({ documentIndex, score }) => ({
document: documents[documentIndex],
index: documentIndex,
score,
})),
),
);
}
}

0 comments on commit 33f1db5

Please sign in to comment.