Skip to content

Commit

Permalink
feat(tool): add elasticsearch tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahmoud Abughali committed Nov 4, 2024
1 parent aabfe04 commit 0e3b32c
Show file tree
Hide file tree
Showing 8 changed files with 457 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ BEE_FRAMEWORK_LOG_SINGLE_LINE="false"
# For Google Search Tool
# GOOGLE_API_KEY=your-google-api-key
# GOOGLE_CSE_ID=your-custom-search-engine-id

# For Elasticsearch Tool
# ELASTICSEARCH_NODE=
# ELASTICSEARCH_API_KEY=
29 changes: 15 additions & 14 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,21 @@ These tools extend the agent's abilities, allowing it to interact with external

## Built-in tools

| Name | Description |
| ------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------- |
| `PythonTool` | Run arbitrary Python code in the remote environment. |
| `WikipediaTool` | Search for data on Wikipedia. |
| `GoogleSearchTool` | Search for data on Google using Custom Search Engine. |
| `DuckDuckGoTool` | Search for data on DuckDuckGo. |
| [`SQLTool`](./sql-tool.md) | Execute SQL queries against relational databases. Instructions can be found [here](./sql-tool.md). |
| `CustomTool` | Run your own Python function in the remote environment. |
| `LLMTool` | Use an LLM to process input data. |
| `DynamicTool` | Construct to create dynamic tools. |
| `ArXivTool` | Retrieve research articles published on arXiv. |
| `WebCrawlerTool` | Retrieve content of an arbitrary website. |
| `OpenMeteoTool` | Retrieve current, previous, or upcoming weather for a given destination. |
|[Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | |
| Name | Description |
| ------------------------------------------------------------------------- | ------------------------------------------------------------------------ |
| `PythonTool` | Run arbitrary Python code in the remote environment. |
| `WikipediaTool` | Search for data on Wikipedia. |
| `GoogleSearchTool` | Search for data on Google using Custom Search Engine. |
| `DuckDuckGoTool` | Search for data on DuckDuckGo. |
| [`SQLTool`](./sql-tool.md) | Execute SQL queries against relational databases. |
| `ElasticSearchTool` | Perform search or aggregation queries against an ElasticSearch database. |
| `CustomTool` | Run your own Python function in the remote environment. |
| `LLMTool` | Use an LLM to process input data. |
| `DynamicTool` | Construct to create dynamic tools. |
| `ArXivTool` | Retrieve research articles published on arXiv. |
| `WebCrawlerTool` | Retrieve content of an arbitrary website. |
| `OpenMeteoTool` | Retrieve current, previous, or upcoming weather for a given destination. |
|[Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | |

All examples can be found [here](/examples/tools).

Expand Down
60 changes: 60 additions & 0 deletions examples/agents/elasticsearch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import "dotenv/config.js";
import { BeeAgent } from "bee-agent-framework/agents/bee/agent";
import { OpenAIChatLLM } from "bee-agent-framework/adapters/openai/chat";
import { ElasticSearchTool } from "bee-agent-framework/tools/database/elasticsearch";
import { FrameworkError } from "bee-agent-framework/errors";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";

const llm = new OpenAIChatLLM({
parameters: {
temperature: 0,
},
});

const elasticSearchTool = new ElasticSearchTool({
connection: {
node: process.env.ELASTICSEARCH_NODE,
auth: {
apiKey: process.env.ELASTICSEARCH_API_KEY || "",
},
},
});

const agent = new BeeAgent({
llm,
memory: new UnconstrainedMemory(),
tools: [elasticSearchTool],
});

const question = "what is the average ticket price of all flights from Cape Town to Venice";

try {
const response = await agent
.run(
{ prompt: `${question}` },
{
execution: {
maxRetriesPerStep: 5,
totalMaxRetries: 10,
maxIterations: 15,
},
},
)
.observe((emitter) => {
emitter.on("error", ({ error }) => {
console.log(`Agent 🤖 : `, FrameworkError.ensure(error).dump());
});
emitter.on("retry", () => {
console.log(`Agent 🤖 : `, "retrying the action...");
});
emitter.on("update", async ({ data, update, meta }) => {
console.log(`Agent (${update.key}) 🤖 : `, update.value);
});
});

console.log(`Agent 🤖 : `, response.result.text);
} catch (error) {
console.error(FrameworkError.ensure(error).dump());
} finally {
process.exit(0);
}
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
"zod-to-json-schema": "^3.23.3"
},
"peerDependencies": {
"@elastic/elasticsearch": "^8.15.1",
"@googleapis/customsearch": "^3.2.0",
"@grpc/grpc-js": "^1.11.3",
"@grpc/proto-loader": "^0.7.13",
Expand All @@ -187,6 +188,7 @@
"devDependencies": {
"@commitlint/cli": "^19.5.0",
"@commitlint/config-conventional": "^19.5.0",
"@elastic/elasticsearch": "^8.15.1",
"@eslint/js": "^9.13.0",
"@eslint/markdown": "^6.2.1",
"@googleapis/customsearch": "^3.2.0",
Expand Down
128 changes: 128 additions & 0 deletions src/tools/database/elasticsearch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { describe, it, expect, beforeEach, vi } from "vitest";
import { ElasticSearchTool, ElasticSearchToolOptions } from "@/tools/database/elasticsearch.js";
import { verifyDeserialization } from "@tests/e2e/utils.js";
import { JSONToolOutput } from "@/tools/base.js";
import { SlidingCache } from "@/cache/slidingCache.js";
import { Task } from "promise-based-task";

vi.mock("@elastic/elasticsearch");

describe("ElasticSearchTool", () => {
let elasticSearchTool: ElasticSearchTool;
const mockClient = {
cat: { indices: vi.fn() },
indices: { getMapping: vi.fn() },
search: vi.fn(),
};

beforeEach(() => {
vi.clearAllMocks();
elasticSearchTool = new ElasticSearchTool({
connection: { node: "http://localhost:9200" },
} as ElasticSearchToolOptions);

Object.defineProperty(elasticSearchTool, "client", {
get: () => mockClient,
});
});

it("lists indices correctly", async () => {
const mockIndices = [{ index: "index1" }, { index: "index2" }];
mockClient.cat.indices.mockResolvedValueOnce(mockIndices);

const response = await elasticSearchTool.run({ action: "LIST_INDICES" });
expect(response.result).toEqual([{ index: "index1" }, { index: "index2" }]);
});

it("gets index details", async () => {
const indexName = "index1";
const mockIndexDetails = {
[indexName]: { mappings: { properties: { field1: { type: "text" } } } },
};
mockClient.indices.getMapping.mockResolvedValueOnce(mockIndexDetails);

const response = await elasticSearchTool.run({ action: "GET_INDEX_DETAILS", indexName });
expect(response.result).toEqual(mockIndexDetails);
expect(mockClient.indices.getMapping).toHaveBeenCalledWith(
{ index: indexName },
{ signal: undefined },
);
});

it("performs a search", async () => {
const indexName = "index1";
const query = JSON.stringify({ query: { match_all: {} } });
const mockSearchResponse = { hits: { hits: [{ _source: { field1: "value1" } }] } };
mockClient.search.mockResolvedValueOnce(mockSearchResponse);

const response = await elasticSearchTool.run({
action: "SEARCH",
indexName,
query,
start: 0,
size: 1,
});
expect(response.result).toEqual([{ field1: "value1" }]);
});

it("throws invalid JSON format error", async () => {
await expect(async () => {
await elasticSearchTool.run({ action: "SEARCH", indexName: "index1", query: "invalid" });
}).rejects.toThrowError(
expect.objectContaining({
message: expect.stringContaining("Invalid JSON format for query"),
}),
);
});

it("throws missing index name error", async () => {
await expect(elasticSearchTool.run({ action: "GET_INDEX_DETAILS" })).rejects.toThrow(
"Index name is required for GET_INDEX_DETAILS action.",
);
});

it("throws missing index and query error", async () => {
await expect(elasticSearchTool.run({ action: "SEARCH" })).rejects.toThrow(
"Both index name and query are required for SEARCH action.",
);
});

it("serializes", async () => {
const elasticSearchTool = new ElasticSearchTool({
connection: { node: "http://localhost:9200" },
cache: new SlidingCache({
size: 10,
ttl: 1000,
}),
});

await elasticSearchTool.cache!.set(
"connection",
Task.resolve(new JSONToolOutput([{ index: "index1", detail: "sample" }])),
);

const serialized = elasticSearchTool.serialize();
const deserializedTool = ElasticSearchTool.fromSerialized(serialized);

expect(await deserializedTool.cache.get("connection")).toStrictEqual(
await elasticSearchTool.cache.get("connection"),
);
verifyDeserialization(elasticSearchTool, deserializedTool);
});
});
Loading

0 comments on commit 0e3b32c

Please sign in to comment.