From 0e3b32cc956943e5655307d12fa6f07334a6a038 Mon Sep 17 00:00:00 2001 From: Mahmoud Abughali Date: Mon, 4 Nov 2024 11:07:48 -0500 Subject: [PATCH] feat(tool): add elasticsearch tool --- .env.template | 4 + docs/tools.md | 29 ++-- examples/agents/elasticsearch.ts | 60 +++++++ package.json | 2 + src/tools/database/elasticsearch.test.ts | 128 +++++++++++++++ src/tools/database/elasticsearch.ts | 198 +++++++++++++++++++++++ tests/e2e/utils.ts | 2 + yarn.lock | 48 ++++++ 8 files changed, 457 insertions(+), 14 deletions(-) create mode 100644 examples/agents/elasticsearch.ts create mode 100644 src/tools/database/elasticsearch.test.ts create mode 100644 src/tools/database/elasticsearch.ts diff --git a/.env.template b/.env.template index a0ec02d2..8c49d093 100644 --- a/.env.template +++ b/.env.template @@ -21,3 +21,7 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false" # For Google Search Tool # GOOGLE_API_KEY=your-google-api-key # GOOGLE_CSE_ID=your-custom-search-engine-id + +# For Elasticsearch Tool +# ELASTICSEARCH_NODE= +# ELASTICSEARCH_API_KEY= \ No newline at end of file diff --git a/docs/tools.md b/docs/tools.md index d27ffc32..79b9fcde 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -10,20 +10,21 @@ These tools extend the agent's abilities, allowing it to interact with external ## Built-in tools -| Name | Description | -| ------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------- | -| `PythonTool` | Run arbitrary Python code in the remote environment. | -| `WikipediaTool` | Search for data on Wikipedia. | -| `GoogleSearchTool` | Search for data on Google using Custom Search Engine. | -| `DuckDuckGoTool` | Search for data on DuckDuckGo. | -| [`SQLTool`](./sql-tool.md) | Execute SQL queries against relational databases. Instructions can be found [here](./sql-tool.md). | -| `CustomTool` | Run your own Python function in the remote environment. | -| `LLMTool` | Use an LLM to process input data. | -| `DynamicTool` | Construct to create dynamic tools. | -| `ArXivTool` | Retrieve research articles published on arXiv. | -| `WebCrawlerTool` | Retrieve content of an arbitrary website. | -| `OpenMeteoTool` | Retrieve current, previous, or upcoming weather for a given destination. | -| ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | | +| Name | Description | +| ------------------------------------------------------------------------- | ------------------------------------------------------------------------ | +| `PythonTool` | Run arbitrary Python code in the remote environment. | +| `WikipediaTool` | Search for data on Wikipedia. | +| `GoogleSearchTool` | Search for data on Google using Custom Search Engine. | +| `DuckDuckGoTool` | Search for data on DuckDuckGo. | +| [`SQLTool`](./sql-tool.md) | Execute SQL queries against relational databases. | +| `ElasticSearchTool` | Perform search or aggregation queries against an ElasticSearch database. | +| `CustomTool` | Run your own Python function in the remote environment. | +| `LLMTool` | Use an LLM to process input data. | +| `DynamicTool` | Construct to create dynamic tools. | +| `ArXivTool` | Retrieve research articles published on arXiv. | +| `WebCrawlerTool` | Retrieve content of an arbitrary website. | +| `OpenMeteoTool` | Retrieve current, previous, or upcoming weather for a given destination. | +| ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | | All examples can be found [here](/examples/tools). diff --git a/examples/agents/elasticsearch.ts b/examples/agents/elasticsearch.ts new file mode 100644 index 00000000..617d6599 --- /dev/null +++ b/examples/agents/elasticsearch.ts @@ -0,0 +1,60 @@ +import "dotenv/config.js"; +import { BeeAgent } from "bee-agent-framework/agents/bee/agent"; +import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat"; +import { ElasticSearchTool } from "bee-agent-framework/tools/database/elasticsearch"; +import { FrameworkError } from "bee-agent-framework/errors"; +import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; + +const llm = new OpenAIChatLLM({ + parameters: { + temperature: 0, + }, +}); + +const elasticSearchTool = new ElasticSearchTool({ + connection: { + node: process.env.ELASTICSEARCH_NODE, + auth: { + apiKey: process.env.ELASTICSEARCH_API_KEY || "", + }, + }, +}); + +const agent = new BeeAgent({ + llm, + memory: new UnconstrainedMemory(), + tools: [elasticSearchTool], +}); + +const question = "what is the average ticket price of all flights from Cape Town to Venice"; + +try { + const response = await agent + .run( + { prompt: `${question}` }, + { + execution: { + maxRetriesPerStep: 5, + totalMaxRetries: 10, + maxIterations: 15, + }, + }, + ) + .observe((emitter) => { + emitter.on("error", ({ error }) => { + console.log(`Agent 🤖 : `, FrameworkError.ensure(error).dump()); + }); + emitter.on("retry", () => { + console.log(`Agent 🤖 : `, "retrying the action..."); + }); + emitter.on("update", async ({ data, update, meta }) => { + console.log(`Agent (${update.key}) 🤖 : `, update.value); + }); + }); + + console.log(`Agent 🤖 : `, response.result.text); +} catch (error) { + console.error(FrameworkError.ensure(error).dump()); +} finally { + process.exit(0); +} diff --git a/package.json b/package.json index fd0e3885..bebb4dab 100644 --- a/package.json +++ b/package.json @@ -171,6 +171,7 @@ "zod-to-json-schema": "^3.23.3" }, "peerDependencies": { + "@elastic/elasticsearch": "^8.15.1", "@googleapis/customsearch": "^3.2.0", "@grpc/grpc-js": "^1.11.3", "@grpc/proto-loader": "^0.7.13", @@ -187,6 +188,7 @@ "devDependencies": { "@commitlint/cli": "^19.5.0", "@commitlint/config-conventional": "^19.5.0", + "@elastic/elasticsearch": "^8.15.1", "@eslint/js": "^9.13.0", "@eslint/markdown": "^6.2.1", "@googleapis/customsearch": "^3.2.0", diff --git a/src/tools/database/elasticsearch.test.ts b/src/tools/database/elasticsearch.test.ts new file mode 100644 index 00000000..03aa5d80 --- /dev/null +++ b/src/tools/database/elasticsearch.test.ts @@ -0,0 +1,128 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, it, expect, beforeEach, vi } from "vitest"; +import { ElasticSearchTool, ElasticSearchToolOptions } from "@/tools/database/elasticsearch.js"; +import { verifyDeserialization } from "@tests/e2e/utils.js"; +import { JSONToolOutput } from "@/tools/base.js"; +import { SlidingCache } from "@/cache/slidingCache.js"; +import { Task } from "promise-based-task"; + +vi.mock("@elastic/elasticsearch"); + +describe("ElasticSearchTool", () => { + let elasticSearchTool: ElasticSearchTool; + const mockClient = { + cat: { indices: vi.fn() }, + indices: { getMapping: vi.fn() }, + search: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + elasticSearchTool = new ElasticSearchTool({ + connection: { node: "http://localhost:9200" }, + } as ElasticSearchToolOptions); + + Object.defineProperty(elasticSearchTool, "client", { + get: () => mockClient, + }); + }); + + it("lists indices correctly", async () => { + const mockIndices = [{ index: "index1" }, { index: "index2" }]; + mockClient.cat.indices.mockResolvedValueOnce(mockIndices); + + const response = await elasticSearchTool.run({ action: "LIST_INDICES" }); + expect(response.result).toEqual([{ index: "index1" }, { index: "index2" }]); + }); + + it("gets index details", async () => { + const indexName = "index1"; + const mockIndexDetails = { + [indexName]: { mappings: { properties: { field1: { type: "text" } } } }, + }; + mockClient.indices.getMapping.mockResolvedValueOnce(mockIndexDetails); + + const response = await elasticSearchTool.run({ action: "GET_INDEX_DETAILS", indexName }); + expect(response.result).toEqual(mockIndexDetails); + expect(mockClient.indices.getMapping).toHaveBeenCalledWith( + { index: indexName }, + { signal: undefined }, + ); + }); + + it("performs a search", async () => { + const indexName = "index1"; + const query = JSON.stringify({ query: { match_all: {} } }); + const mockSearchResponse = { hits: { hits: [{ _source: { field1: "value1" } }] } }; + mockClient.search.mockResolvedValueOnce(mockSearchResponse); + + const response = await elasticSearchTool.run({ + action: "SEARCH", + indexName, + query, + start: 0, + size: 1, + }); + expect(response.result).toEqual([{ field1: "value1" }]); + }); + + it("throws invalid JSON format error", async () => { + await expect(async () => { + await elasticSearchTool.run({ action: "SEARCH", indexName: "index1", query: "invalid" }); + }).rejects.toThrowError( + expect.objectContaining({ + message: expect.stringContaining("Invalid JSON format for query"), + }), + ); + }); + + it("throws missing index name error", async () => { + await expect(elasticSearchTool.run({ action: "GET_INDEX_DETAILS" })).rejects.toThrow( + "Index name is required for GET_INDEX_DETAILS action.", + ); + }); + + it("throws missing index and query error", async () => { + await expect(elasticSearchTool.run({ action: "SEARCH" })).rejects.toThrow( + "Both index name and query are required for SEARCH action.", + ); + }); + + it("serializes", async () => { + const elasticSearchTool = new ElasticSearchTool({ + connection: { node: "http://localhost:9200" }, + cache: new SlidingCache({ + size: 10, + ttl: 1000, + }), + }); + + await elasticSearchTool.cache!.set( + "connection", + Task.resolve(new JSONToolOutput([{ index: "index1", detail: "sample" }])), + ); + + const serialized = elasticSearchTool.serialize(); + const deserializedTool = ElasticSearchTool.fromSerialized(serialized); + + expect(await deserializedTool.cache.get("connection")).toStrictEqual( + await elasticSearchTool.cache.get("connection"), + ); + verifyDeserialization(elasticSearchTool, deserializedTool); + }); +}); diff --git a/src/tools/database/elasticsearch.ts b/src/tools/database/elasticsearch.ts new file mode 100644 index 00000000..54421b2c --- /dev/null +++ b/src/tools/database/elasticsearch.ts @@ -0,0 +1,198 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Tool, + ToolInput, + ToolError, + BaseToolOptions, + BaseToolRunOptions, + JSONToolOutput, +} from "@/tools/base.js"; +import { Cache } from "@/cache/decoratorCache.js"; +import { z } from "zod"; +import { Client, ClientOptions } from "@elastic/elasticsearch"; +import { + CatIndicesResponse, + IndicesGetMappingResponse, + SearchRequest, + SearchResponse, + SearchHit, +} from "@elastic/elasticsearch/lib/api/types.js"; +import { ValidationError } from "ajv"; + +type ToolRunOptions = BaseToolRunOptions; + +export interface ElasticSearchToolOptions extends BaseToolOptions { + connection: ClientOptions; +} + +export class ElasticSearchTool extends Tool< + JSONToolOutput, + ElasticSearchToolOptions, + ToolRunOptions +> { + name = "ElasticSearchTool"; + + description = `Can query data from an ElasticSearch database. IMPORTANT: strictly follow this order of actions: + 1. LIST_INDICES - retrieve a list of available indices + 2. GET_INDEX_DETAILS - get details of index fields + 3. SEARCH - perform search or aggregation query on a specific index or pass the original user query without modifications if it's a valid JSON ElasticSearch query`; + + inputSchema() { + return z.object({ + action: z + .enum(["LIST_INDICES", "GET_INDEX_DETAILS", "SEARCH"]) + .describe( + "The action to perform. LIST_INDICES lists all indices, GET_INDEX_DETAILS fetches details for a specified index, and SEARCH executes a search or aggregation query", + ), + indexName: z + .string() + .optional() + .describe("The name of the index to query, required for GET_INDEX_DETAILS and SEARCH"), + query: z + .string() + .optional() + .describe("Valid ElasticSearch JSON search or aggregation query for SEARCH action"), + start: z.coerce + .number() + .int() + .min(0) + .default(0) + .optional() + .describe( + "The record index from which the query will start. Increase by the size of the query to get the next page of results", + ), + size: z.coerce + .number() + .int() + .min(0) + .max(10) + .default(10) + .optional() + .describe("How many records will be retrieved from the ElasticSearch query. Maximum is 10"), + }); + } + + static { + this.register(); + } + + public constructor(options: ElasticSearchToolOptions) { + super(options); + if (!options.connection.cloud && !options.connection.node && !options.connection.nodes) { + throw new ValidationError([ + { + message: "At least one of the properties must be provided", + propertyName: "connection.cloud, connection.node, connection.nodes", + }, + ]); + } + } + + @Cache() + protected get client(): Client { + try { + return new Client(this.options.connection); + } catch (error) { + throw new ToolError(`Unable to connect to ElasticSearch: ${error}`, [], { + isRetryable: false, + isFatal: true, + }); + } + } + + protected async _run( + input: ToolInput, + _options?: ToolRunOptions, + ): Promise> { + if (input.action === "LIST_INDICES") { + const indices = await this.listIndices(_options?.signal); + return new JSONToolOutput(indices); + } else if (input.action === "GET_INDEX_DETAILS") { + const indexDetails = await this.getIndexDetails(input, _options?.signal); + return new JSONToolOutput(indexDetails); + } else if (input.action === "SEARCH") { + const response = await this.search(input, _options?.signal); + if (response.aggregations) { + return new JSONToolOutput(response.aggregations); + } else { + return new JSONToolOutput(response.hits.hits.map((hit: SearchHit) => hit._source)); + } + } else { + throw new ToolError("Invalid action specified."); + } + } + + private async listIndices(signal?: AbortSignal): Promise { + const response = await this.client.cat.indices( + { + expand_wildcards: "open", + h: "index", + format: "json", + }, + { signal: signal }, + ); + return response + .filter((record) => record.index && !record.index.startsWith(".")) // Exclude system indices + .map((record) => ({ index: record.index })); + } + + private async getIndexDetails( + input: ToolInput, + signal?: AbortSignal, + ): Promise { + if (!input.indexName) { + throw new ToolError("Index name is required for GET_INDEX_DETAILS action."); + } + return await this.client.indices.getMapping( + { + index: input.indexName, + }, + { signal: signal }, + ); + } + + private async search(input: ToolInput, signal?: AbortSignal): Promise { + if (!input.indexName || !input.query) { + throw new ToolError("Both index name and query are required for SEARCH action."); + } + let parsedQuery; + try { + parsedQuery = JSON.parse(input.query); + } catch { + throw new ToolError(`Invalid JSON format for query`); + } + + const searchBody: SearchRequest = { + ...parsedQuery, + from: parsedQuery.from || input.start, + size: parsedQuery.size || input.size, + }; + + return await this.client.search( + { + index: input.indexName, + body: searchBody, + }, + { signal: signal }, + ); + } + + loadSnapshot({ ...snapshot }: ReturnType): void { + super.loadSnapshot(snapshot); + } +} diff --git a/tests/e2e/utils.ts b/tests/e2e/utils.ts index 192129ec..4a03f7b6 100644 --- a/tests/e2e/utils.ts +++ b/tests/e2e/utils.ts @@ -31,6 +31,7 @@ import { OpenAI } from "openai"; import { Groq } from "groq-sdk"; import { customsearch_v1 } from "@googleapis/customsearch"; import { LangChainTool } from "@/adapters/langchain/tools.js"; +import { Client as esClient } from "@elastic/elasticsearch"; interface CallbackOptions { required?: boolean; @@ -127,6 +128,7 @@ verifyDeserialization.ignoredClasses = [ LCBaseLLM, RunContext, Emitter, + esClient, ] as ClassConstructor[]; verifyDeserialization.isIgnored = (key: string, value: unknown, parent?: any) => { if (verifyDeserialization.ignoredKeys.has(key)) { diff --git a/yarn.lock b/yarn.lock index 24e618df..4db5411d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -327,6 +327,31 @@ __metadata: languageName: node linkType: hard +"@elastic/elasticsearch@npm:^8.15.1": + version: 8.15.1 + resolution: "@elastic/elasticsearch@npm:8.15.1" + dependencies: + "@elastic/transport": "npm:^8.8.1" + tslib: "npm:^2.4.0" + checksum: 10c0/6fed56487e0bd5c2e8a54e794cd1ebe58e0ff40319b4b8d10ef3cb45534a572fd4d6d5aff5ca63707aaf4be7abbc073ba5bf414cea129f86facc5b1c576cb815 + languageName: node + linkType: hard + +"@elastic/transport@npm:^8.8.1": + version: 8.9.1 + resolution: "@elastic/transport@npm:8.9.1" + dependencies: + "@opentelemetry/api": "npm:1.x" + debug: "npm:^4.3.4" + hpagent: "npm:^1.0.0" + ms: "npm:^2.1.3" + secure-json-parse: "npm:^2.4.0" + tslib: "npm:^2.4.0" + undici: "npm:^6.12.0" + checksum: 10c0/a79fee3091dd9b9cfa70af5835ac8b362f43e783cbe28112312a3759944066613305764ce19c22941dfaf6e7157c4398eaadd1ac65b22ed8cdcf165a92f553c4 + languageName: node + linkType: hard + "@esbuild/aix-ppc64@npm:0.21.5": version: 0.21.5 resolution: "@esbuild/aix-ppc64@npm:0.21.5" @@ -1587,6 +1612,13 @@ __metadata: languageName: node linkType: hard +"@opentelemetry/api@npm:1.x": + version: 1.9.0 + resolution: "@opentelemetry/api@npm:1.9.0" + checksum: 10c0/9aae2fe6e8a3a3eeb6c1fdef78e1939cf05a0f37f8a4fae4d6bf2e09eb1e06f966ece85805626e01ba5fab48072b94f19b835449e58b6d26720ee19a58298add + languageName: node + linkType: hard + "@pkgjs/parseargs@npm:^0.11.0": version: 0.11.0 resolution: "@pkgjs/parseargs@npm:0.11.0" @@ -2907,6 +2939,7 @@ __metadata: "@commitlint/config-conventional": "npm:^19.5.0" "@connectrpc/connect": "npm:^1.6.1" "@connectrpc/connect-node": "npm:^1.6.1" + "@elastic/elasticsearch": "npm:^8.15.1" "@eslint/js": "npm:^9.13.0" "@eslint/markdown": "npm:^6.2.1" "@googleapis/customsearch": "npm:^3.2.0" @@ -2991,6 +3024,7 @@ __metadata: zod: "npm:^3.23.8" zod-to-json-schema: "npm:^3.23.3" peerDependencies: + "@elastic/elasticsearch": ^8.15.1 "@googleapis/customsearch": ^3.2.0 "@grpc/grpc-js": ^1.11.3 "@grpc/proto-loader": ^0.7.13 @@ -6035,6 +6069,13 @@ __metadata: languageName: node linkType: hard +"hpagent@npm:^1.0.0": + version: 1.2.0 + resolution: "hpagent@npm:1.2.0" + checksum: 10c0/505ef42e5e067dba701ea21e7df9fa73f6f5080e59d53680829827d34cd7040f1ecf7c3c8391abe9df4eb4682ef4a4321608836b5b70a61b88c1b3a03d77510b + languageName: node + linkType: hard + "html-entities@npm:^2.3.3, html-entities@npm:^2.5.2": version: 2.5.2 resolution: "html-entities@npm:2.5.2" @@ -11330,6 +11371,13 @@ __metadata: languageName: node linkType: hard +"undici@npm:^6.12.0": + version: 6.20.1 + resolution: "undici@npm:6.20.1" + checksum: 10c0/b2c8d5adcd226c53d02f9270e4cac277256a7147cf310af319369ec6f87651ca46b2960366cb1339a6dac84d937e01e8cdbec5cb468f1f1ce5e9490e438d7222 + languageName: node + linkType: hard + "unicorn-magic@npm:^0.1.0": version: 0.1.0 resolution: "unicorn-magic@npm:0.1.0"