backend: Refactor Retriever/Function tools into single Tool concept #154

Merged 6 commits on May 29, 2024
135 changes: 24 additions & 111 deletions docs/custom_tool_guides/tool_guide.md
@@ -1,4 +1,4 @@
# Custom tools and retrieval sources
# Custom Tools
Follow these instructions to create your own custom tools.

Custom tools will need to be built in the `community` folder. Make sure you've enabled the `INSTALL_COMMUNITY_DEPS` build arg in the `docker-compose.yml` file by setting it to `true`.
@@ -27,115 +27,49 @@ There are three types of tools:

## Step 3: Implement the Tool

Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link might change). The specific subfolder used will depend on the type of tool you're implementing.
Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link is subject to change).

If you need to install a new module to run your tool, execute the following command and run `make dev` again.
If you need to install a new library to run your tool, execute the following command and run `make dev` again.

```bash
poetry add <MODULE> --group community
```
### Implementing a Tool

If you're working on a File or Data Loader, follow the steps outlined in [Implementing a Retriever](#implementing-a-retriever).
Add the implementation inside a tool class that inherits from `BaseTool`. This class will need to implement the `call()` method, which should return a list of dictionary results.

If you're implementing a Function Tool, refer to the steps in [Implementing a Function Tool](#implementing-a-function-tool).
Note: To enable citations, each result in the list should contain a "text" field.

### Implementing a Retriever

Add the implementation inside a tool class that inherits `BaseRetrieval` and needs to implement the function `def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:`

You can define custom configurations for your tool within the `__init__` function. Set the exact values for these variables during [Step 4](#step-4-making-your-tool-available).

You can also develop a tool that requires a token or authentication. To do this, simply set your variable in the .env file.

For example, for Wikipedia we have a custom configuration:

```python
class LangChainWikiRetriever(BaseRetrieval):
"""
This class retrieves documents from Wikipedia using the langchain package.
This requires wikipedia package to be installed.
"""

def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap

def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
wiki_retriever = WikipediaRetriever()
docs = wiki_retriever.get_relevant_documents(query)
text_splitter = CharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
documents = text_splitter.split_documents(docs)
return [
{
"text": doc.page_content,
"title": doc.metadata.get("title", None),
"url": doc.metadata.get("source", None),
}
for doc in documents
]
```

And for internet search, we need an API key
For example, let's look at the community-implemented `ArxivRetriever`:

```python
class TavilyInternetSearch(BaseRetrieval):
def __init__(self):
if "TAVILY_API_KEY" not in os.environ:
raise ValueError("Please set the TAVILY_API_KEY environment variable.")
from typing import Any, Dict, List

self.api_key = os.environ["TAVILY_API_KEY"]
self.client = TavilyClient(api_key=self.api_key)
from langchain_community.utilities import ArxivAPIWrapper

def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
content = self.client.search(query=query, search_depth="advanced")

if "results" not in content:
return []

return [
{
"url": result["url"],
"text": result["content"],
}
for result in content["results"]
```
from community.tools import BaseTool

Note that all Retrievers should return a list of Dicts, and each Dict should contain at least a `text` key.

### Implementing a Function Tool

Add the implementation inside a tool class that inherits `BaseFunctionTool` and needs to implement the function `def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:`

For example, for calculator

```python
from typing import Any
from py_expression_eval import Parser
from typing import List, Dict

from backend.tools.function_tools.base import BaseFunctionTool
class ArxivRetriever(BaseTool):
def __init__(self):
self.client = ArxivAPIWrapper()

class CalculatorFunctionTool(BaseFunctionTool):
"""
Function Tool that evaluates mathematical expressions.
"""
@classmethod
# If your tool requires any environment variables such as API keys,
# you will need to assert that they're not None here
def is_available(cls) -> bool:
return True

# Your tool needs to implement this call() method
def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:
math_parser = Parser()
to_evaluate = parameters.get("code", "").replace("pi", "PI").replace("e", "E")
result = []
try:
result = {"result": math_parser.parse(to_evaluate).evaluate({})}
except Exception:
result = {"result": "Parsing error - syntax not allowed."}
return result
result = self.client.run(parameters)

return [{"text": result}] # <- Return list of results, in this case there is only one
```
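
As a rough illustration of the single-tool pattern described above, here is a minimal, self-contained sketch. The `BaseTool` class below is only a stand-in that mirrors the `is_available()` / `call()` shape shown in this guide (the real base class lives in the toolkit and may differ), and `WordCountTool` is a hypothetical tool invented for this example.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseTool(ABC):
    """Stand-in for the toolkit's base class (assumed shape, not the real one)."""

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool: ...

    @abstractmethod
    def call(self, parameters: Dict[str, Any], **kwargs: Any) -> List[Dict[str, Any]]: ...


class WordCountTool(BaseTool):
    """Hypothetical tool that counts the words in the `query` parameter."""

    @classmethod
    def is_available(cls) -> bool:
        # A tool that needs a secret would check its environment variable here.
        return True

    def call(self, parameters: Dict[str, Any], **kwargs: Any) -> List[Dict[str, Any]]:
        query = parameters.get("query", "")
        count = len(query.split())
        # Each result carries a "text" field so that citations can work.
        return [{"text": f"The query contains {count} words."}]


results = WordCountTool().call({"query": "refactor tools into one concept"})
```

The sketch passes `parameters` as a dict, matching how the tests in this PR invoke `call({"query": query})`.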

## Step 4: Making Your Tool Available

To make your tool available, add its definition to the tools config [here](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py).
To make your tool available, add its definition to the community tools [config.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py).

Start by adding the tool name to the `ToolName` enum found at the top of the file.

@@ -151,27 +85,6 @@ Next, include the tool configurations in the `AVAILABLE_TOOLS` list. The definit
- Description: A brief description of the tool.
- Env_vars: A list of secrets required by the tool.

Function tool with custom parameter definitions:

```python
ToolName.Python_Interpreter: ManagedTool(
name=ToolName.Python_Interpreter,
implementation=PythonInterpreterFunctionTool,
parameter_definitions={
"code": {
"description": "Python code to execute using an interpreter",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=PythonInterpreterFunctionTool.is_available(),
error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.",
category=Category.Function,
description="Runs python code in a sandbox.",
)
```
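
With this PR, retrieval-style tools declare a `query` parameter in the same way function tools declare theirs. The sketch below shows what such an entry might look like. `ManagedTool` and `ToolName` here are simplified stand-ins (the toolkit's real schema has more fields), and `Example_Search` and `missing_required` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List


class ToolName(str, Enum):
    Example_Search = "Example_Search"  # hypothetical tool name


@dataclass
class ManagedTool:
    """Simplified stand-in for the toolkit's tool schema."""

    name: str
    description: str
    parameter_definitions: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    is_visible: bool = True
    is_available: bool = False
    error_message: str = ""


AVAILABLE_TOOLS = {
    ToolName.Example_Search: ManagedTool(
        name=ToolName.Example_Search,
        description="Searches a hypothetical example index.",
        parameter_definitions={
            "query": {
                "description": "Query for retrieval.",
                "type": "str",
                "required": True,
            }
        },
        is_available=True,
    ),
}


def missing_required(tool: ManagedTool, parameters: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: names of required parameters absent from `parameters`."""
    return [
        name
        for name, spec in tool.parameter_definitions.items()
        if spec.get("required") and name not in parameters
    ]
```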

## Step 5: Test Your Tool!

Now, when you run the toolkit, all the visible tools, including the one you just added, should be available!
@@ -209,4 +122,4 @@ curl --location 'http://localhost:8000/chat-stream' \

## Step 6 (extra): Add Unit tests

If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few cases.
If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few test cases.
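
A test file for a custom tool could look roughly like the following sketch. `EchoTool` is a hypothetical tool defined inline so the example stands alone; in practice you would import your own tool class instead. The test functions follow the pytest naming convention used by the existing tests in this PR.

```python
from typing import Any, Dict, List


class EchoTool:
    """Hypothetical tool used only to illustrate the shape of a test."""

    def call(self, parameters: Dict[str, Any], **kwargs: Any) -> List[Dict[str, Any]]:
        query = parameters.get("query", "")
        if not query:
            return []
        return [{"text": query}]


def test_echo_tool_returns_text_field() -> None:
    result = EchoTool().call({"query": "hello"})
    assert result == [{"text": "hello"}]


def test_echo_tool_empty_query() -> None:
    assert EchoTool().call({"query": ""}) == []
```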
File renamed without changes.
6 changes: 4 additions & 2 deletions src/backend/chat/custom/custom.py
@@ -4,13 +4,13 @@
from fastapi import HTTPException

from backend.chat.base import BaseChat
from backend.chat.collate import combine_documents
from backend.chat.custom.utils import get_deployment
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.model_deployments.base import BaseDeployment
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger
from backend.tools.retrieval.collate import combine_documents


class CustomChat(BaseChat):
@@ -84,10 +84,12 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:

all_documents = {}
# TODO: call in parallel and error handling
# TODO: merge with regular function tools after multihop implemented
for retriever in retrievers:
for query in queries:
parameters = {"query": query}
all_documents.setdefault(query, []).extend(
retriever.retrieve_documents(query)
retriever.call(parameters)
)

# Collate Documents
Expand Down
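
The hunk above switches the retrieval loop from `retrieve_documents(query)` to `call({"query": query})`. A standalone sketch of that dispatch pattern follows. `StaticRetriever` and `gather_documents` are hypothetical names, not part of the toolkit; only the loop body mirrors the diff.

```python
from typing import Any, Dict, List


class StaticRetriever:
    """Hypothetical retriever that returns one fixed document per query."""

    def __init__(self, label: str) -> None:
        self.label = label

    def call(self, parameters: Dict[str, Any], **kwargs: Any) -> List[Dict[str, Any]]:
        return [{"text": f"{self.label}: {parameters['query']}"}]


def gather_documents(
    retrievers: List[StaticRetriever], queries: List[str]
) -> Dict[str, List[Dict[str, Any]]]:
    """Collect results per query, calling every retriever through call()."""
    all_documents: Dict[str, List[Dict[str, Any]]] = {}
    for retriever in retrievers:
        for query in queries:
            parameters = {"query": query}
            all_documents.setdefault(query, []).extend(retriever.call(parameters))
    return all_documents


docs = gather_documents([StaticRetriever("a"), StaticRetriever("b")], ["cats"])
```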
39 changes: 29 additions & 10 deletions src/backend/config/tools.py
@@ -4,13 +4,11 @@
from enum import StrEnum

from backend.schemas.tool import Category, ManagedTool
from backend.tools.function_tools import (
CalculatorFunctionTool,
PythonInterpreterFunctionTool,
)
from backend.tools.retrieval import (
from backend.tools import (
Calculator,
LangChainVectorDBRetriever,
LangChainWikiRetriever,
PythonInterpreter,
TavilyInternetSearch,
)

@@ -38,6 +36,13 @@ class ToolName(StrEnum):
ToolName.Wiki_Retriever_LangChain: ManagedTool(
name=ToolName.Wiki_Retriever_LangChain,
implementation=LangChainWikiRetriever,
parameter_definitions={
"query": {
"description": "Query for retrieval.",
"type": "str",
"required": True,
}
},
kwargs={"chunk_size": 300, "chunk_overlap": 0},
is_visible=True,
is_available=LangChainWikiRetriever.is_available(),
@@ -48,6 +53,13 @@ class ToolName(StrEnum):
ToolName.File_Upload_Langchain: ManagedTool(
name=ToolName.File_Upload_Langchain,
implementation=LangChainVectorDBRetriever,
parameter_definitions={
"query": {
"description": "Query for retrieval.",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=LangChainVectorDBRetriever.is_available(),
error_message="LangChainVectorDBRetriever not available, please make sure to set the COHERE_API_KEY environment variable.",
@@ -56,7 +68,7 @@ class ToolName(StrEnum):
),
ToolName.Python_Interpreter: ManagedTool(
name=ToolName.Python_Interpreter,
implementation=PythonInterpreterFunctionTool,
implementation=PythonInterpreter,
parameter_definitions={
"code": {
"description": "Python code to execute using an interpreter",
@@ -65,14 +77,14 @@ class ToolName(StrEnum):
}
},
is_visible=True,
is_available=PythonInterpreterFunctionTool.is_available(),
is_available=PythonInterpreter.is_available(),
error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.",
category=Category.Function,
description="Runs python code in a sandbox.",
),
ToolName.Calculator: ManagedTool(
name=ToolName.Calculator,
implementation=CalculatorFunctionTool,
implementation=Calculator,
parameter_definitions={
"code": {
"description": "Arithmetic expression to evaluate",
@@ -81,14 +93,21 @@ class ToolName(StrEnum):
}
},
is_visible=True,
is_available=CalculatorFunctionTool.is_available(),
error_message="CalculatorFunctionTool not available.",
is_available=Calculator.is_available(),
error_message="Calculator tool not available.",
category=Category.Function,
description="Evaluate arithmetic expressions.",
),
ToolName.Tavily_Internet_Search: ManagedTool(
name=ToolName.Tavily_Internet_Search,
implementation=TavilyInternetSearch,
parameter_definitions={
"query": {
"description": "Query for retrieval.",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=TavilyInternetSearch.is_available(),
error_message="TavilyInternetSearch not available, please make sure to set the TAVILY_API_KEY environment variable.",
Expand Down
@@ -1,13 +1,13 @@
from backend.tools.function_tools import CalculatorFunctionTool
from backend.tools import Calculator


def test_calculator() -> None:
calculator = CalculatorFunctionTool()
calculator = Calculator()
result = calculator.call({"code": "2+2"})
assert result == {"result": 4}


def test_calculator_invalid_syntax() -> None:
calculator = CalculatorFunctionTool()
calculator = Calculator()
result = calculator.call({"code": "2+"})
assert result == {"result": "Parsing error - syntax not allowed."}
@@ -2,8 +2,8 @@

import pytest

from backend.chat import collate
from backend.model_deployments import CohereDeployment
from backend.tools.retrieval import collate

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
Expand Down
@@ -4,10 +4,7 @@
import pytest
from langchain_core.documents.base import Document

from backend.tools.retrieval.lang_chain import (
LangChainVectorDBRetriever,
LangChainWikiRetriever,
)
from backend.tools import LangChainVectorDBRetriever, LangChainWikiRetriever

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
@@ -53,10 +50,10 @@ def test_wiki_retriever() -> None:
wiki_retriever_mock.get_relevant_documents.return_value = mock_docs

with patch(
"backend.tools.retrieval.lang_chain.WikipediaRetriever",
"backend.tools.lang_chain.WikipediaRetriever",
return_value=wiki_retriever_mock,
):
result = retriever.retrieve_documents(query)
result = retriever.call({"query": query})

assert result == expected_docs

@@ -71,10 +68,10 @@ def test_wiki_retriever_no_docs() -> None:
wiki_retriever_mock.get_relevant_documents.return_value = mock_docs

with patch(
"backend.tools.retrieval.lang_chain.WikipediaRetriever",
"backend.tools.lang_chain.WikipediaRetriever",
return_value=wiki_retriever_mock,
):
result = retriever.retrieve_documents(query)
result = retriever.call({"query": query})

assert result == []

@@ -134,7 +131,7 @@ def test_vector_db_retriever() -> None:
mock_db = MagicMock()
mock_from_documents.return_value = mock_db
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = retriever.retrieve_documents(query)
result = retriever.call({"query": query})

assert result == expected_docs

@@ -155,6 +152,6 @@ def test_vector_db_retriever_no_docs() -> None:
mock_db = MagicMock()
mock_from_documents.return_value = mock_db
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = retriever.retrieve_documents(query)
result = retriever.call({"query": query})

assert result == []
12 changes: 12 additions & 0 deletions src/backend/tools/__init__.py
@@ -0,0 +1,12 @@
from backend.tools.calculator import Calculator
from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever
from backend.tools.python_interpreter import PythonInterpreter
from backend.tools.tavily import TavilyInternetSearch

__all__ = [
"Calculator",
"PythonInterpreter",
"LangChainVectorDBRetriever",
"LangChainWikiRetriever",
"TavilyInternetSearch",
]