Skip to content

Commit

Permalink
feat(tool): improve wikipedia results filtering (#143)
Browse files Browse the repository at this point in the history
Ref: #142

Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D authored Nov 5, 2024
1 parent ea84808 commit 2529a0c
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 26 deletions.
47 changes: 41 additions & 6 deletions src/tools/search/wikipedia.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,48 @@
* limitations under the License.
*/

import { wikiSearchResult } from "wikipedia";

vitest.mock("wikipedia", () => {
const pages = [
{
title: "Tomáš Dvořák (ice hockey)",
pageid: "1",
},
{
title: "Dvorak",
pageid: "2",
},
{
title: "Tomáš",
pageid: "3",
},
{
title: "List of Czech sportspeople (section Ice hockey)",
pageid: "4",
},
];

return {
default: {
default: {
setLang(lang: string) {
return lang;
},
async search(input: string) {
async search() {
return {
results: [{ title: input }],
suggestion: [],
};
results: pages,
suggestion: "",
} as wikiSearchResult;
},
async page(title: string) {
async page(titleOrId: string | number) {
const page = pages.find((page) => page.title === titleOrId || page.pageid === titleOrId);
if (!page) {
throw new Error("No page found.");
}

return {
title,
...page,
content: async () => "Content",
infobox: async () => ({ text: "Infobox" }),
};
Expand All @@ -43,6 +69,15 @@ import { verifyDeserialization } from "@tests/e2e/utils.js";
import { WikipediaTool } from "@/tools/search/wikipedia.js";

describe("Wikipedia", () => {
it("Retrieves a correct page", async () => {
const instance = new WikipediaTool();
const response = await instance.run({
query: "tomas dvorak ice hockey",
});
expect(response.results.length).toBe(1);
expect(response.results[0].title).toBe("Tomáš Dvořák (ice hockey)");
});

it("Serializes", async () => {
const instance = new WikipediaTool({
extraction: {
Expand Down
71 changes: 51 additions & 20 deletions src/tools/search/wikipedia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import wiki from "wikipedia";
import { Cache } from "@/cache/decoratorCache.js";
import stringComparison from "string-comparison";
import * as R from "remeda";
import type { Page, pageFunctions, searchOptions } from "wikipedia";
import { ArrayKeys, Common } from "@/internals/types.js";
Expand All @@ -33,6 +32,8 @@ import Turndown from "turndown";
// @ts-expect-error missing types
import turndownPlugin from "joplin-turndown-plugin-gfm";
import { keys, mapValues } from "remeda";
import stringComparison from "string-comparison";
import { pageResult } from "wikipedia/dist/resultTypes.js";

wiki.default.setLang("en");

Expand Down Expand Up @@ -120,6 +121,11 @@ export class WikipediaToolOutput extends SearchToolOutput<WikipediaToolResult> {
}
}

interface SearchResponse {
results: Pick<pageResult, "title" | "pageid">[];
suggestion: string;
}

export class WikipediaTool extends Tool<
WikipediaToolOutput,
WikipediaToolOptions,
Expand Down Expand Up @@ -326,33 +332,27 @@ export class WikipediaTool extends Tool<
): Promise<WikipediaToolOutput> {
const runOptions = this._createRunOptions(_options);

const { results: searchRawResults, suggestion } = await wiki.default.search(input, {
suggestion: Boolean(_options?.search?.suggestion),
...runOptions.search,
});
const { results: searchRawResults, suggestion }: SearchResponse = await wiki.default.search(
input,
{
suggestion: Boolean(_options?.search?.suggestion),
...runOptions.search,
},
);

if (searchRawResults.length === 0 && suggestion && runOptions.search?.suggestion) {
return await this._run({ query: suggestion }, _options);
}

const bestCandidates = stringComparison.jaccardIndex
.sortMatch(
input,
searchRawResults.map((result) => result.title),
)
.map((result) => ({
pageId: searchRawResults[result.index].pageid,
score: result.rating,
}))
.filter((result) => result.score >= (runOptions.filters?.minPageNameSimilarity ?? 0))
.sort((a, b) => b.score - a.score);

if (bestCandidates.at(0)?.score === 1 && runOptions.filters?.excludeOthersOnExactMatch) {
bestCandidates.length = 1;
}
const bestCandidates = this.extractBestCandidates(
input,
searchRawResults,
runOptions?.filters ?? {},
);

const results = await Promise.all(
bestCandidates.map(async ({ pageId }) => {
// @ts-expect-error wrong library's typing, passing a string would lead to a classic text search instead of a concrete page retrieval
const page = await wiki.default.page(pageId, {
redirect: true,
preload: false,
Expand Down Expand Up @@ -381,6 +381,37 @@ export class WikipediaTool extends Tool<
return new WikipediaToolOutput(results, runOptions.output?.maxSerializedLength ?? Infinity);
}

protected extractBestCandidates(
query: string,
candidates: SearchResponse["results"],
options: FilterOptions,
) {
const normalize = (text: string) =>
text
.normalize("NFKD")
.replace(/[^\w| ]/g, "") // remove diacritics and special characters (except whitespace)
.replace(/\s\s+/g, " ") // collapse multiple whitespaces into one
.trim();

const bestCandidates = stringComparison.jaccardIndex
.sortMatch(
normalize(query),
candidates.map((candidate) => normalize(candidate.title)),
)
.map((result) => ({
pageId: candidates[result.index].pageid,
score: result.rating,
}))
.filter((result) => result.score >= (options.minPageNameSimilarity ?? 0))
.sort((a, b) => b.score - a.score);

if (bestCandidates.at(0)?.score === 1 && options.excludeOthersOnExactMatch) {
bestCandidates.length = 1;
}

return bestCandidates;
}

createSnapshot() {
return {
...super.createSnapshot(),
Expand Down

0 comments on commit 2529a0c

Please sign in to comment.