Skip to content

Commit

Permalink
feat(tool)!: remove pagination functionality from GoogleSearch and Du…
Browse files Browse the repository at this point in the history
…ckDuckGo (#152)

Ref: #151

Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D authored Nov 7, 2024
1 parent 34156b7 commit 59424de
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 49 deletions.
2 changes: 1 addition & 1 deletion examples/agents/bee_advanced.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}}
},
tools: [
new DuckDuckGoSearchTool({
maxResultsPerPage: 10,
maxResults: 10,
search: {
safeSearch: DuckDuckGoSearchToolSearchType.STRICT,
},
Expand Down
60 changes: 60 additions & 0 deletions src/internals/helpers/paginate.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { paginate, PaginateInput } from "@/internals/helpers/paginate.js";

describe("paginate", () => {
it.each([
{
size: 1,
chunkSize: 1,
items: Array(100).fill(1),
},
{
size: 10,
chunkSize: 1,
items: [],
},
{
size: 11,
chunkSize: 10,
items: Array(100).fill(1),
},
{
size: 25,
chunkSize: 1,
items: Array(20).fill(1),
},
])("Works %#", async ({ size, items, chunkSize }) => {
const fn: PaginateInput<number>["handler"] = vi.fn().mockImplementation(async ({ offset }) => {
const chunk = items.slice(offset, offset + chunkSize);
return { done: offset + chunk.length >= items.length, data: chunk };
});

const results = await paginate({
size,
handler: fn,
});

const maxItemsToBeRetrieved = Math.min(size, items.length);
let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize);
if (expectedCalls === 0 && size > 0) {
expectedCalls = 1;
}
expect(fn).toBeCalledTimes(expectedCalls);
expect(results).toHaveLength(maxItemsToBeRetrieved);
});
});
42 changes: 42 additions & 0 deletions src/internals/helpers/paginate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

export interface PaginateInput<T> {
size: number;
handler: (data: { offset: number; limit: number }) => Promise<{ data: T[]; done: boolean }>;
}

export async function paginate<T>(input: PaginateInput<T>): Promise<T[]> {
const acc: T[] = [];

while (acc.length < input.size) {
const { data, done } = await input.handler({
offset: acc.length,
limit: input.size - acc.length,
});
acc.push(...data);

if (done || data.length === 0) {
break;
}
}

if (acc.length > input.size) {
acc.length = input.size;
}

return acc;
}
8 changes: 4 additions & 4 deletions src/tools/search/duckDuckGoSearch.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ describe("DuckDuckGoSearch Tool", () => {
options: DuckDuckGoSearchToolOptions;
}
it.each([
{ query: "LLM", options: { maxResultsPerPage: 1 } },
{ query: "LLM", options: { maxResults: 1 } },
{ query: "IBM Research" },
{ query: "NLP", options: { maxResultsPerPage: 3 } },
{ query: "NLP", options: { maxResults: 3 } },
] as RetrieveDataInput[])("Retrieves data (%o)", async (input) => {
const globalMaxResults = 10;
const maxResultsPerPage = (input as any).options?.maxResultsPerPage ?? globalMaxResults;

const tool = new DuckDuckGoSearchTool({
maxResultsPerPage: globalMaxResults,
maxResults: globalMaxResults,
cache: false,
throttle: false,
});
Expand All @@ -97,7 +97,7 @@ describe("DuckDuckGoSearch Tool", () => {
size: 10,
ttl: 1000,
}),
maxResultsPerPage: 1,
maxResults: 1,
});

await tool.cache!.set(
Expand Down
51 changes: 31 additions & 20 deletions src/tools/search/duckDuckGoSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ import { HeaderGenerator } from "header-generator";
import type { NeedleOptions } from "needle";
import { z } from "zod";
import { Cache } from "@/cache/decoratorCache.js";
import { RunContext } from "@/context.js";
import { paginate } from "@/internals/helpers/paginate.js";

export { SafeSearchType as DuckDuckGoSearchToolSearchType };

export interface DuckDuckGoSearchToolOptions extends SearchToolOptions {
search?: SearchOptions;
throttle?: ThrottleOptions | false;
httpClientOptions?: NeedleOptions;
maxResultsPerPage: number;
maxResults: number;
}

export interface DuckDuckGoSearchToolRunOptions extends SearchToolRunOptions {
Expand Down Expand Up @@ -84,7 +86,7 @@ export class DuckDuckGoSearchTool extends Tool<
}

public constructor(options: Partial<DuckDuckGoSearchToolOptions> = {}) {
super({ ...options, maxResultsPerPage: options?.maxResultsPerPage ?? 15 });
super({ ...options, maxResults: options?.maxResults ?? 15 });

this.client = this._createClient();
}
Expand All @@ -107,28 +109,37 @@ export class DuckDuckGoSearchTool extends Tool<

protected async _run(
{ query: input }: ToolInput<this>,
options?: DuckDuckGoSearchToolRunOptions,
options: DuckDuckGoSearchToolRunOptions | undefined,
run: RunContext<this>,
) {
const headers = new HeaderGenerator().getHeaders();

const { results } = await this.client(
input,
{
safeSearch: SafeSearchType.MODERATE,
...this.options.search,
...options?.search,
const results = await paginate({
size: this.options.maxResults,
handler: async ({ offset }) => {
const { results: data, noResults: done } = await this.client(
input,
{
safeSearch: SafeSearchType.MODERATE,
...this.options.search,
...options?.search,
offset,
},
{
headers,
user_agent: headers["user-agent"],
...this.options?.httpClientOptions,
...options?.httpClientOptions,
signal: run.signal,
},
);

return {
data,
done,
};
},
{
headers,
user_agent: headers["user-agent"],
...this.options?.httpClientOptions,
...options?.httpClientOptions,
},
);

if (results.length > this.options.maxResultsPerPage) {
results.length = this.options.maxResultsPerPage;
}
});

return new DuckDuckGoSearchToolOutput(
results.map((result) => ({
Expand Down
11 changes: 6 additions & 5 deletions src/tools/search/googleSearch.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ describe("GoogleCustomSearch Tool", () => {
googleSearchTool = new GoogleSearchTool({
apiKey: "test-api-key",
cseId: "test-cse-id",
maxResultsPerPage: 10,
maxResults: 10,
});

Object.defineProperty(googleSearchTool, "client", {
Expand Down Expand Up @@ -77,6 +77,7 @@ describe("GoogleCustomSearch Tool", () => {
{
cx: "test-cse-id",
q: query,
start: 0,
num: 10,
safe: "active",
},
Expand All @@ -86,21 +87,21 @@ describe("GoogleCustomSearch Tool", () => {
);
});

it("validates maxResultsPerPage range", () => {
it("validates maxResults range", () => {
expect(
() =>
new GoogleSearchTool({
apiKey: "test-api-key",
cseId: "test-cse-id",
maxResultsPerPage: 0,
maxResults: 0,
}),
).toThrowError("validation failed");
expect(
() =>
new GoogleSearchTool({
apiKey: "test-api-key",
cseId: "test-cse-id",
maxResultsPerPage: 11,
maxResults: 111,
}),
).toThrowError("validation failed");
});
Expand All @@ -109,7 +110,7 @@ describe("GoogleCustomSearch Tool", () => {
const tool = new GoogleSearchTool({
apiKey: "test-api-key",
cseId: "test-cse-id",
maxResultsPerPage: 1,
maxResults: 1,
cache: new SlidingCache({
size: 10,
ttl: 1000,
Expand Down
50 changes: 32 additions & 18 deletions src/tools/search/googleSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ import { Tool, ToolInput } from "@/tools/base.js";
import { z } from "zod";
import { Cache } from "@/cache/decoratorCache.js";
import { ValueError } from "@/errors.js";
import { ValidationError } from "ajv";
import { parseEnv } from "@/internals/env.js";
import { RunContext } from "@/context.js";
import { paginate } from "@/internals/helpers/paginate.js";
import { ValidationError } from "ajv";

export interface GoogleSearchToolOptions extends SearchToolOptions {
apiKey?: string;
cseId?: string;
maxResultsPerPage: number;
maxResults: number;
}

type GoogleSearchToolRunOptions = SearchToolRunOptions;
Expand Down Expand Up @@ -77,7 +78,7 @@ export class GoogleSearchTool extends Tool<
protected apiKey: string;
protected cseId: string;

public constructor(options: GoogleSearchToolOptions = { maxResultsPerPage: 10 }) {
public constructor(options: GoogleSearchToolOptions = { maxResults: 10 }) {
super(options);

this.apiKey = options.apiKey || parseEnv("GOOGLE_API_KEY", z.string());
Expand All @@ -92,11 +93,11 @@ export class GoogleSearchTool extends Tool<
);
}

if (options.maxResultsPerPage < 1 || options.maxResultsPerPage > 10) {
if (options.maxResults < 1 || options.maxResults > 100) {
throw new ValidationError([
{
message: "Property range must be between 1 and 10",
propertyName: "options.maxResultsPerPage",
message: "Property 'maxResults' must be between 1 and 100",
propertyName: "options.maxResults",
},
]);
}
Expand All @@ -118,19 +119,32 @@ export class GoogleSearchTool extends Tool<
_options: GoogleSearchToolRunOptions | undefined,
run: RunContext<this>,
) {
const response = await this.client.cse.list(
{
cx: this.cseId,
q: input,
num: this.options.maxResultsPerPage,
safe: "active",
},
{
signal: run.signal,
const results = await paginate({
size: this.options.maxResults,
handler: async ({ offset, limit }) => {
const maxChunkSize = 10;

const {
data: { items = [] },
} = await this.client.cse.list(
{
cx: this.cseId,
q: input,
start: offset,
num: Math.min(limit, maxChunkSize),
safe: "active",
},
{
signal: run.signal,
},
);

return {
data: items,
done: items.length < maxChunkSize,
};
},
);

const results = response.data.items || [];
});

return new GoogleSearchToolOutput(
results.map((result) => ({
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/agents/bee.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe("Bee Agent", () => {
tools: [
new DuckDuckGoSearchTool({
cache: new UnconstrainedCache(),
maxResultsPerPage: 10,
maxResults: 10,
throttle: {
interval: 5000,
limit: 1,
Expand Down
26 changes: 26 additions & 0 deletions tests/e2e/tools/duckDuckGoSearch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { expect } from "vitest";
import { DuckDuckGoSearchTool } from "@/tools/search/duckDuckGoSearch.js";

describe("DuckDuckGo", () => {
it("Retrieves data", async () => {
const instance = new DuckDuckGoSearchTool();
const response = await instance.run({ query: "Bee Agent Framework" });
expect(response.results.length).toBeGreaterThan(0);
});
});

0 comments on commit 59424de

Please sign in to comment.