Skip to content

Commit

Permalink
feat(tools): update Wikipedia tool, remove links, extend interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Sep 27, 2024
1 parent 0392ff0 commit ee651c3
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/tools/search/wikipedia.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ describe("Wikipedia", () => {
it("Serializes", async () => {
const instance = new WikipediaTool({
extraction: {
fields: ["infobox"],
fields: { infobox: {} },
},
});
await instance.run({ query: "Prague" });
Expand Down
116 changes: 79 additions & 37 deletions src/tools/search/wikipedia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { Tool, ToolInput } from "@/tools/base.js";
import Turndown from "turndown";
// @ts-expect-error missing types
import turndownPlugin from "joplin-turndown-plugin-gfm";
import { keys, mapValues } from "remeda";

wiki.default.setLang("en");

Expand All @@ -42,10 +43,20 @@ export interface FilterOptions {
minPageNameSimilarity?: number;
}

export type PageFunctions = pageFunctions | "markdown";
export type PageFunctions = Record<
pageFunctions,
{
transform?: <T>(output: T) => T;
}
> & {
markdown: {
transform?: <T>(output: T) => T;
filter?: (node: HTMLElement) => boolean;
};
};

export interface ExtractionOptions {
fields?: PageFunctions[];
fields?: Partial<PageFunctions>;
}

export interface OutputOptions {
Expand All @@ -68,7 +79,7 @@ export interface WikipediaToolRunOptions extends SearchToolRunOptions {
}

export interface WikipediaToolResult extends SearchToolResult {
fields: Record<string, unknown>;
fields: Partial<Record<keyof PageFunctions, unknown>>;
}

export class WikipediaToolOutput extends SearchToolOutput<WikipediaToolResult> {
Expand Down Expand Up @@ -133,14 +144,14 @@ export class WikipediaTool extends Tool<

@Cache()
protected get _mappers(): Record<
PageFunctions,
keyof PageFunctions,
(page: Page, runOptions: WikipediaToolRunOptions) => Promise<any>
> {
return {
categories: (page) => page.categories(),
content: (page) => page.content(),
html: (page) => page.html(),
markdown: async (page) => {
markdown: async (page, runOptions) => {
const html = await page.html().then((result) => {
const url = new URL(page.fullurl);
const base = `${url.protocol}//${[url.hostname, url.port].filter(Boolean).join(":")}`;
Expand All @@ -159,30 +170,7 @@ export class WikipediaTool extends Tool<
const service = new Turndown();
service.use(turndownPlugin.gfm);
return service
.remove((node): boolean => {
switch (node.tagName.toLowerCase()) {
case "link":
case "style":
return true;
default:
return (
[
"toc",
"reflist",
"mw-references-wrap",
"navbox",
"navbox-styles",
"mw-editsection",
"sistersitebox",
"navbox-inner",
"refbegin",
"notpageimage",
"mw-file-element",
].some((cls) => node.className.includes(cls)) ||
["navigation"].some((role) => node.role === role)
);
}
})
.remove((node) => runOptions.extraction?.fields?.markdown?.filter?.(node) === false)
.turndown(html);
},
images: (page) => page.images(),
Expand All @@ -200,9 +188,64 @@ export class WikipediaTool extends Tool<

@Cache()
protected get _defaultRunOptions(): WikipediaToolRunOptions {
const ignoredTags = new Set([
"a",
"img",
"link",
"style",
"abbr",
"cite",
"input",
"sup",
"bdi",
"q",
"figure",
"audio",
"track",
"figcaption",
"small",
]);
const ignoredTagsSelector = Array.from(ignoredTags.values()).join(",");

return {
extraction: {
fields: ["markdown"],
fields: {
markdown: {
filter: (node) => {
const tagName = node.tagName.toLowerCase();
if (ignoredTags.has(tagName)) {
return false;
}

if (ignoredTagsSelector) {
for (const childNode of node.querySelectorAll(ignoredTagsSelector)) {
childNode.remove();
}
}

if (node.children.length === 0) {
return false;
}

return (
[
"toc",
"reflist",
"mw-references-wrap",
"navbox",
"navbox-styles",
"mw-editsection",
"sistersitebox",
"navbox-inner",
"refbegin",
"notpageimage",
"mw-file-element",
].every((cls) => !node.className.includes(cls)) &&
["navigation"].every((role) => node.role !== role)
);
},
},
},
},
filters: {
minPageNameSimilarity: 0.5,
Expand Down Expand Up @@ -275,9 +318,7 @@ export class WikipediaTool extends Tool<
const page = await wiki.default.page(pageId, {
redirect: true,
preload: false,
fields: (runOptions?.extraction?.fields ?? []).filter(
(field): field is pageFunctions => field !== "markdown",
),
fields: keys(runOptions.extraction?.fields ?? {}).filter((key) => key !== "markdown"),
});

return asyncProperties({
Expand All @@ -290,10 +331,11 @@ export class WikipediaTool extends Tool<
})(),
url: page.fullurl,
fields: asyncProperties(
R.mapToObj(runOptions?.extraction?.fields || [], (key) => [
key,
this._mappers[key](page, runOptions).catch(() => null),
]),
mapValues(runOptions?.extraction?.fields ?? {}, (value, key) =>
this._mappers[key](page, runOptions)
.then((response) => (value.transform ? value.transform(response) : response))
.catch(() => null),
),
),
});
}),
Expand Down
40 changes: 40 additions & 0 deletions tests/e2e/tools/wikipedia.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { WikipediaTool } from "@/tools/search/wikipedia.js";
import { expect } from "vitest";

describe("Wikipedia", () => {
it("Retrieves data", async () => {
const instance = new WikipediaTool();
const response = await instance.run({ query: "Molecule" });

expect(response.results).toHaveLength(1);
const result = response.results[0];
expect(result).toBeTruthy();
expect(result).toMatchObject({
title: expect.any(String),
description: expect.any(String),
url: expect.any(String),
fields: expect.any(Object),
});

const markdown = response.results[0].fields!.markdown;
expect(markdown).toBeTruthy();
expect(markdown).not.match(/<a\s+[^>]*href=["'][^"']*["'][^>]*>/gim);
expect(markdown).not.match(/<img\s+[^>]*src=["'][^"']*["'][^>]*>/gim);
});
});

0 comments on commit ee651c3

Please sign in to comment.