diff --git a/docs/custom_tool_guides/tool_guide.md b/docs/custom_tool_guides/tool_guide.md index 5a6b968162..d05663a3ff 100644 --- a/docs/custom_tool_guides/tool_guide.md +++ b/docs/custom_tool_guides/tool_guide.md @@ -1,4 +1,4 @@ -# Custom tools and retrieval sources +# Custom Tools Follow these instructions to create your own custom tools. Custom tools will need to be built in the `community` folder. Make sure you've enabled the `INSTALL_COMMUNITY_DEPS` build arg in the `docker-compose.yml` file by setting it to `true`. @@ -27,115 +27,49 @@ There are three types of tools: ## Step 3: Implement the Tool -Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link might change). The specific subfolder used will depend on the type of tool you're implementing. +Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link is subject to change). -If you need to install a new module to run your tool, execute the following command and run `make dev` again. +If you need to install a new library to run your tool, execute the following command and run `make dev` again. ```bash poetry add --group community ``` +### Implementing a Tool -If you're working on a File or Data Loader, follow the steps outlined in [Implementing a Retriever](#implementing-a-retriever). +Add the implementation inside a tool class that inherits from `BaseTool`. This class will need to implement the `call()` method, which should return a list of dictionary results. -If you're implementing a Function Tool, refer to the steps in [Implementing a Function Tool](#implementing-a-function-tool). +Note: To enable citations, each result in the list should contain a "text" field. 
-### Implementing a Retriever - -Add the implementation inside a tool class that inherits `BaseRetrieval` and needs to implement the function `def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:` - -You can define custom configurations for your tool within the `__init__` function. Set the exact values for these variables during [Step 4](#step-4-making-your-tool-available). - -You can also develop a tool that requires a token or authentication. To do this, simply set your variable in the .env file. - -For example, for Wikipedia we have a custom configuration: - -```python -class LangChainWikiRetriever(BaseRetrieval): - """ - This class retrieves documents from Wikipedia using the langchain package. - This requires wikipedia package to be installed. - """ - - def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: - wiki_retriever = WikipediaRetriever() - docs = wiki_retriever.get_relevant_documents(query) - text_splitter = CharacterTextSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap - ) - documents = text_splitter.split_documents(docs) - return [ - { - "text": doc.page_content, - "title": doc.metadata.get("title", None), - "url": doc.metadata.get("source", None), - } - for doc in documents - ] -``` - -And for internet search, we need an API key +For example, let's look at the community-implemented `ArxivRetriever`: ```python -class TavilyInternetSearch(BaseRetrieval): - def __init__(self): - if "TAVILY_API_KEY" not in os.environ: - raise ValueError("Please set the TAVILY_API_KEY environment variable.") +from typing import Any, Dict, List - self.api_key = os.environ["TAVILY_API_KEY"] - self.client = TavilyClient(api_key=self.api_key) +from langchain_community.utilities import ArxivAPIWrapper - def retrieve_documents(self, query: str, **kwargs: Any) -> 
List[Dict[str, Any]]: - content = self.client.search(query=query, search_depth="advanced") - - if "results" not in content: - return [] - - return [ - { - "url": result["url"], - "text": result["content"], - } - for result in content["results"] -``` +from community.tools import BaseTool -Note that all Retrievers should return a list of Dicts, and each Dict should contain at least a `text` key. -### Implementing a Function Tool - -Add the implementation inside a tool class that inherits `BaseFunctionTool` and needs to implement the function `def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:` - -For example, for calculator - -```python -from typing import Any -from py_expression_eval import Parser -from typing import List, Dict - -from backend.tools.function_tools.base import BaseFunctionTool +class ArxivRetriever(BaseTool): + def __init__(self): + self.client = ArxivAPIWrapper() -class CalculatorFunctionTool(BaseFunctionTool): - """ - Function Tool that evaluates mathematical expressions. - """ + @classmethod + # If your tool requires any environment variables such as API keys, + # you will need to assert that they're not None here + def is_available(cls) -> bool: + return True + # Your tool needs to implement this call() method def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: - math_parser = Parser() - to_evaluate = parameters.get("code", "").replace("pi", "PI").replace("e", "E") - result = [] - try: - result = {"result": math_parser.parse(to_evaluate).evaluate({})} - except Exception: - result = {"result": "Parsing error - syntax not allowed."} - return result + result = self.client.run(parameters.get("query", "")) + + return [{"text": result}] # <- Return list of results, in this case there is only one ``` ## Step 4: Making Your Tool Available -To make your tool available, add its definition to the tools config [here](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py).
+To make your tool available, add its definition to the community tools config in [tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py). Start by adding the tool name to the `ToolName` enum found at the top of the file. @@ -151,27 +85,6 @@ Next, include the tool configurations in the `AVAILABLE_TOOLS` list. The definit - Description: A brief description of the tool. - Env_vars: A list of secrets required by the tool. -Function tool with custom parameter definitions: - -```python -ToolName.Python_Interpreter: ManagedTool( - name=ToolName.Python_Interpreter, - implementation=PythonInterpreterFunctionTool, - parameter_definitions={ - "code": { - "description": "Python code to execute using an interpreter", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=PythonInterpreterFunctionTool.is_available(), - error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.", - category=Category.Function, - description="Runs python code in a sandbox.", -) -``` - ## Step 5: Test Your Tool! Now, when you run the toolkit, all the visible tools, including the one you just added, should be available! @@ -209,4 +122,4 @@ curl --location 'http://localhost:8000/chat-stream' \ ## Step 6 (extra): Add Unit tests -If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few cases. +If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few test cases.
diff --git a/src/backend/tools/retrieval/collate.py b/src/backend/chat/collate.py similarity index 100% rename from src/backend/tools/retrieval/collate.py rename to src/backend/chat/collate.py diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 1b97f926de..1a2ab0fddf 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -4,13 +4,13 @@ from fastapi import HTTPException from backend.chat.base import BaseChat +from backend.chat.collate import combine_documents from backend.chat.custom.utils import get_deployment from backend.config.tools import AVAILABLE_TOOLS, ToolName from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.tool import Category, Tool from backend.services.logger import get_logger -from backend.tools.retrieval.collate import combine_documents class CustomChat(BaseChat): @@ -84,10 +84,12 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: all_documents = {} # TODO: call in parallel and error handling + # TODO: merge with regular function tools after multihop implemented for retriever in retrievers: for query in queries: + parameters = {"query": query} all_documents.setdefault(query, []).extend( - retriever.retrieve_documents(query) + retriever.call(parameters) ) # Collate Documents diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 9e17606d51..4541d77e39 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -4,13 +4,11 @@ from enum import StrEnum from backend.schemas.tool import Category, ManagedTool -from backend.tools.function_tools import ( - CalculatorFunctionTool, - PythonInterpreterFunctionTool, -) -from backend.tools.retrieval import ( +from backend.tools import ( + Calculator, LangChainVectorDBRetriever, LangChainWikiRetriever, + PythonInterpreter, TavilyInternetSearch, ) @@ -38,6 +36,13 @@ class ToolName(StrEnum): 
ToolName.Wiki_Retriever_LangChain: ManagedTool( name=ToolName.Wiki_Retriever_LangChain, implementation=LangChainWikiRetriever, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, kwargs={"chunk_size": 300, "chunk_overlap": 0}, is_visible=True, is_available=LangChainWikiRetriever.is_available(), @@ -48,6 +53,13 @@ class ToolName(StrEnum): ToolName.File_Upload_Langchain: ManagedTool( name=ToolName.File_Upload_Langchain, implementation=LangChainVectorDBRetriever, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, is_visible=True, is_available=LangChainVectorDBRetriever.is_available(), error_message="LangChainVectorDBRetriever not available, please make sure to set the COHERE_API_KEY environment variable.", @@ -56,7 +68,7 @@ class ToolName(StrEnum): ), ToolName.Python_Interpreter: ManagedTool( name=ToolName.Python_Interpreter, - implementation=PythonInterpreterFunctionTool, + implementation=PythonInterpreter, parameter_definitions={ "code": { "description": "Python code to execute using an interpreter", @@ -65,14 +77,14 @@ class ToolName(StrEnum): } }, is_visible=True, - is_available=PythonInterpreterFunctionTool.is_available(), + is_available=PythonInterpreter.is_available(), error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.", category=Category.Function, description="Runs python code in a sandbox.", ), ToolName.Calculator: ManagedTool( name=ToolName.Calculator, - implementation=CalculatorFunctionTool, + implementation=Calculator, parameter_definitions={ "code": { "description": "Arithmetic expression to evaluate", @@ -81,14 +93,21 @@ class ToolName(StrEnum): } }, is_visible=True, - is_available=CalculatorFunctionTool.is_available(), - error_message="CalculatorFunctionTool not available.", + is_available=Calculator.is_available(), + 
error_message="Calculator tool not available.", category=Category.Function, description="Evaluate arithmetic expressions.", ), ToolName.Tavily_Internet_Search: ManagedTool( name=ToolName.Tavily_Internet_Search, implementation=TavilyInternetSearch, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, is_visible=True, is_available=TavilyInternetSearch.is_available(), error_message="TavilyInternetSearch not available, please make sure to set the TAVILY_API_KEY environment variable.", diff --git a/src/backend/tests/tools/function_tools/test_calculator.py b/src/backend/tests/tools/test_calculator.py similarity index 65% rename from src/backend/tests/tools/function_tools/test_calculator.py rename to src/backend/tests/tools/test_calculator.py index 69fb6432ab..8fdc706e13 100644 --- a/src/backend/tests/tools/function_tools/test_calculator.py +++ b/src/backend/tests/tools/test_calculator.py @@ -1,13 +1,13 @@ -from backend.tools.function_tools import CalculatorFunctionTool +from backend.tools import Calculator def test_calculator() -> None: - calculator = CalculatorFunctionTool() + calculator = Calculator() result = calculator.call({"code": "2+2"}) assert result == {"result": 4} def test_calculator_invalid_syntax() -> None: - calculator = CalculatorFunctionTool() + calculator = Calculator() result = calculator.call({"code": "2+"}) assert result == {"result": "Parsing error - syntax not allowed."} diff --git a/src/backend/tests/tools/retrieval/test_collate.py b/src/backend/tests/tools/test_collate.py similarity index 96% rename from src/backend/tests/tools/retrieval/test_collate.py rename to src/backend/tests/tools/test_collate.py index 479a6f082a..c79d0a6bfe 100644 --- a/src/backend/tests/tools/retrieval/test_collate.py +++ b/src/backend/tests/tools/test_collate.py @@ -2,8 +2,8 @@ import pytest +from backend.chat import collate from backend.model_deployments import CohereDeployment -from 
backend.tools.retrieval import collate is_cohere_env_set = ( os.environ.get("COHERE_API_KEY") is not None diff --git a/src/backend/tests/tools/retrieval/test_lang_chain.py b/src/backend/tests/tools/test_lang_chain.py similarity index 91% rename from src/backend/tests/tools/retrieval/test_lang_chain.py rename to src/backend/tests/tools/test_lang_chain.py index 1f3dc93952..80a37992c5 100644 --- a/src/backend/tests/tools/retrieval/test_lang_chain.py +++ b/src/backend/tests/tools/test_lang_chain.py @@ -4,10 +4,7 @@ import pytest from langchain_core.documents.base import Document -from backend.tools.retrieval.lang_chain import ( - LangChainVectorDBRetriever, - LangChainWikiRetriever, -) +from backend.tools import LangChainVectorDBRetriever, LangChainWikiRetriever is_cohere_env_set = ( os.environ.get("COHERE_API_KEY") is not None @@ -53,10 +50,10 @@ def test_wiki_retriever() -> None: wiki_retriever_mock.get_relevant_documents.return_value = mock_docs with patch( - "backend.tools.retrieval.lang_chain.WikipediaRetriever", + "backend.tools.lang_chain.WikipediaRetriever", return_value=wiki_retriever_mock, ): - result = retriever.retrieve_documents(query) + result = retriever.call({"query": query}) assert result == expected_docs @@ -71,10 +68,10 @@ def test_wiki_retriever_no_docs() -> None: wiki_retriever_mock.get_relevant_documents.return_value = mock_docs with patch( - "backend.tools.retrieval.lang_chain.WikipediaRetriever", + "backend.tools.lang_chain.WikipediaRetriever", return_value=wiki_retriever_mock, ): - result = retriever.retrieve_documents(query) + result = retriever.call({"query": query}) assert result == [] @@ -134,7 +131,7 @@ def test_vector_db_retriever() -> None: mock_db = MagicMock() mock_from_documents.return_value = mock_db mock_db.as_retriever().get_relevant_documents.return_value = mock_docs - result = retriever.retrieve_documents(query) + result = retriever.call({"query": query}) assert result == expected_docs @@ -155,6 +152,6 @@ def 
test_vector_db_retriever_no_docs() -> None: mock_db = MagicMock() mock_from_documents.return_value = mock_db mock_db.as_retriever().get_relevant_documents.return_value = mock_docs - result = retriever.retrieve_documents(query) + result = retriever.call({"query": query}) assert result == [] diff --git a/src/backend/tools/__init__.py b/src/backend/tools/__init__.py index e69de29bb2..d3b83617e1 100644 --- a/src/backend/tools/__init__.py +++ b/src/backend/tools/__init__.py @@ -0,0 +1,12 @@ +from backend.tools.calculator import Calculator +from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever +from backend.tools.python_interpreter import PythonInterpreter +from backend.tools.tavily import TavilyInternetSearch + +__all__ = [ + "Calculator", + "PythonInterpreter", + "LangChainVectorDBRetriever", + "LangChainWikiRetriever", + "TavilyInternetSearch", +] diff --git a/src/backend/tools/function_tools/base.py b/src/backend/tools/base.py similarity index 52% rename from src/backend/tools/function_tools/base.py rename to src/backend/tools/base.py index 0d475b23dc..b8aa6e33e8 100644 --- a/src/backend/tools/function_tools/base.py +++ b/src/backend/tools/base.py @@ -2,12 +2,14 @@ from typing import Any, Dict, List -class BaseFunctionTool: - """Base for all retrieval options.""" +class BaseTool: + """ + Abstract base class for all Tools. + """ @classmethod @abstractmethod def is_available(cls) -> bool: ... @abstractmethod - def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]: ... + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: ... 
diff --git a/src/backend/tools/function_tools/calculator.py b/src/backend/tools/calculator.py similarity index 75% rename from src/backend/tools/function_tools/calculator.py rename to src/backend/tools/calculator.py index 7e490566fa..b006091be0 100644 --- a/src/backend/tools/function_tools/calculator.py +++ b/src/backend/tools/calculator.py @@ -2,10 +2,10 @@ from py_expression_eval import Parser -from backend.tools.function_tools.base import BaseFunctionTool +from backend.tools.base import BaseTool -class CalculatorFunctionTool(BaseFunctionTool): +class Calculator(BaseTool): """ Function Tool that evaluates mathematical expressions. """ @@ -14,9 +14,10 @@ class CalculatorFunctionTool(BaseFunctionTool): def is_available(cls) -> bool: return True - def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: math_parser = Parser() to_evaluate = parameters.get("code", "").replace("pi", "PI").replace("e", "E") + result = [] try: result = {"result": math_parser.parse(to_evaluate).evaluate({})} diff --git a/src/backend/tools/function_tools/__init__.py b/src/backend/tools/function_tools/__init__.py deleted file mode 100644 index 7ea7a88ad6..0000000000 --- a/src/backend/tools/function_tools/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from backend.tools.function_tools.calculator import CalculatorFunctionTool -from backend.tools.function_tools.python_interpreter import ( - PythonInterpreterFunctionTool, -) - -__all__ = [ - "CalculatorFunctionTool", - "PythonInterpreterFunctionTool", -] diff --git a/src/backend/tools/retrieval/lang_chain.py b/src/backend/tools/lang_chain.py similarity index 86% rename from src/backend/tools/retrieval/lang_chain.py rename to src/backend/tools/lang_chain.py index 21fbde882e..ff0c229f34 100644 --- a/src/backend/tools/retrieval/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -7,7 +7,7 @@ from langchain_community.retrievers import WikipediaRetriever from 
langchain_community.vectorstores import Chroma -from backend.tools.retrieval.base import BaseRetrieval +from backend.tools.base import BaseTool """ Plug in your lang chain retrieval implementation here. @@ -17,7 +17,7 @@ """ -class LangChainWikiRetriever(BaseRetrieval): +class LangChainWikiRetriever(BaseTool): """ This class retrieves documents from Wikipedia using the langchain package. This requires wikipedia package to be installed. @@ -31,13 +31,15 @@ def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): def is_available(cls) -> bool: return True - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: wiki_retriever = WikipediaRetriever() + query = parameters.get("query", "") docs = wiki_retriever.get_relevant_documents(query) text_splitter = CharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap ) documents = text_splitter.split_documents(docs) + return [ { "text": doc.page_content, @@ -48,7 +50,7 @@ def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: ] -class LangChainVectorDBRetriever(BaseRetrieval): +class LangChainVectorDBRetriever(BaseTool): """ This class retrieves documents from a vector database using the langchain package. 
""" @@ -62,13 +64,17 @@ def __init__(self, filepath: str): def is_available(cls) -> bool: return cls.cohere_api_key is not None - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: cohere_embeddings = CohereEmbeddings(cohere_api_key=self.cohere_api_key) + # Load text files and split into chunks loader = PyPDFLoader(self.filepath) text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0) pages = loader.load_and_split(text_splitter) + # Create a vector store from the documents db = Chroma.from_documents(documents=pages, embedding=cohere_embeddings) + query = parameters.get("query", "") input_docs = db.as_retriever().get_relevant_documents(query) + return [dict({"text": doc.page_content}) for doc in input_docs] diff --git a/src/backend/tools/function_tools/python_interpreter.py b/src/backend/tools/python_interpreter.py similarity index 96% rename from src/backend/tools/function_tools/python_interpreter.py rename to src/backend/tools/python_interpreter.py index d00b43705a..3fc2e0514f 100644 --- a/src/backend/tools/function_tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -6,14 +6,14 @@ from langchain_core.tools import Tool as LangchainTool from pydantic.v1 import BaseModel, Field -from backend.tools.function_tools.base import BaseFunctionTool +from backend.tools.base import BaseTool class LangchainPythonInterpreterToolInput(BaseModel): code: str = Field(description="Python code to execute.") -class PythonInterpreterFunctionTool(BaseFunctionTool): +class PythonInterpreter(BaseTool): """ This class calls arbitrary code against a Python interpreter. 
It requires a URL at which the interpreter lives diff --git a/src/backend/tools/retrieval/__init__.py b/src/backend/tools/retrieval/__init__.py deleted file mode 100644 index 5643374089..0000000000 --- a/src/backend/tools/retrieval/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from backend.tools.retrieval.lang_chain import ( - LangChainVectorDBRetriever, - LangChainWikiRetriever, -) -from backend.tools.retrieval.tavily import TavilyInternetSearch - -__all__ = [ - "LangChainVectorDBRetriever", - "LangChainWikiRetriever", - "TavilyInternetSearch", -] diff --git a/src/backend/tools/retrieval/base.py b/src/backend/tools/retrieval/base.py deleted file mode 100644 index fc4cd7adaa..0000000000 --- a/src/backend/tools/retrieval/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import abstractmethod -from typing import Any, Dict, List - - -class BaseRetrieval: - """Base for all retrieval options.""" - - @classmethod - @abstractmethod - def is_available(cls) -> bool: ... - - @abstractmethod - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: ... - - def validate_documents(self, documents: List[Dict[str, Any]]) -> bool: - """Validate the documents retrieved. 
A valid document should have a text field.""" - for document in documents: - if "text" not in document: - return False - return True diff --git a/src/backend/tools/retrieval/tavily.py b/src/backend/tools/tavily.py similarity index 86% rename from src/backend/tools/retrieval/tavily.py rename to src/backend/tools/tavily.py index a4b33dbe74..2914ecc5a0 100644 --- a/src/backend/tools/retrieval/tavily.py +++ b/src/backend/tools/tavily.py @@ -4,10 +4,10 @@ from langchain_community.tools.tavily_search import TavilySearchResults from tavily import TavilyClient -from backend.tools.retrieval.base import BaseRetrieval +from backend.tools.base import BaseTool -class TavilyInternetSearch(BaseRetrieval): +class TavilyInternetSearch(BaseTool): tavily_api_key = os.environ.get("TAVILY_API_KEY") def __init__(self): @@ -17,7 +17,8 @@ def __init__(self): def is_available(cls) -> bool: return cls.tavily_api_key is not None - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + query = parameters.get("query", "") content = self.client.search(query=query, search_depth="advanced") if "results" not in content: diff --git a/src/community/config/tools.py b/src/community/config/tools.py index 70adb964a5..15f05e5920 100644 --- a/src/community/config/tools.py +++ b/src/community/config/tools.py @@ -1,12 +1,14 @@ from enum import StrEnum -from community.tools import Category, ManagedTool -from community.tools.function_tools import ClinicalTrialsTool, WolframAlphaFunctionTool -from community.tools.retrieval import ( +from community.tools import ( ArxivRetriever, + Category, + ClinicalTrials, ConnectorRetriever, LlamaIndexUploadPDFRetriever, + ManagedTool, PubMedRetriever, + WolframAlpha, ) @@ -23,6 +25,13 @@ class CommunityToolName(StrEnum): CommunityToolName.Arxiv: ManagedTool( name=CommunityToolName.Arxiv, implementation=ArxivRetriever, + parameter_definitions={ + "query": { + "description": 
"Query for retrieval.", + "type": "str", + "required": True, + } + }, is_visible=True, is_available=ArxivRetriever.is_available(), error_message="ArxivRetriever is not available.", @@ -41,6 +50,13 @@ class CommunityToolName(StrEnum): CommunityToolName.Pub_Med: ManagedTool( name=CommunityToolName.Pub_Med, implementation=PubMedRetriever, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, is_visible=True, is_available=PubMedRetriever.is_available(), error_message="PubMedRetriever is not available.", @@ -58,18 +74,18 @@ class CommunityToolName(StrEnum): ), CommunityToolName.Wolfram_Alpha: ManagedTool( name=CommunityToolName.Wolfram_Alpha, - implementation=WolframAlphaFunctionTool, + implementation=WolframAlpha, is_visible=False, - is_available=WolframAlphaFunctionTool.is_available(), + is_available=WolframAlpha.is_available(), error_message="WolframAlphaFunctionTool is not available, please set the WOLFRAM_APP_ID environment variable.", category=Category.Function, description="Evaluate arithmetic expressions.", ), CommunityToolName.ClinicalTrials: ManagedTool( name=CommunityToolName.ClinicalTrials, - implementation=ClinicalTrialsTool, + implementation=ClinicalTrials, is_visible=True, - is_available=ClinicalTrialsTool.is_available(), + is_available=ClinicalTrials.is_available(), error_message="ClinicalTrialsTool is not available.", category=Category.Function, description="Retrieves clinical studies from ClinicalTrials.gov.", diff --git a/src/community/tests/tools/retrieval/test_pub_med.py b/src/community/tests/tools/retrieval/test_pub_med.py deleted file mode 100644 index b88cc985a2..0000000000 --- a/src/community/tests/tools/retrieval/test_pub_med.py +++ /dev/null @@ -1,8 +0,0 @@ -from community.tools.retrieval.pub_med import PubMedRetriever - - -def test_pub_med_retriever(): - retriever = PubMedRetriever() - result = retriever.retrieve_documents("What causes lung cancer?") - assert len(result) 
> 0 - assert "text" in result[0] diff --git a/src/community/tests/tools/retrieval/test_arxiv.py b/src/community/tests/tools/test_arxiv.py similarity index 60% rename from src/community/tests/tools/retrieval/test_arxiv.py rename to src/community/tests/tools/test_arxiv.py index 3913ec2a29..0c44a2d1e6 100644 --- a/src/community/tests/tools/retrieval/test_arxiv.py +++ b/src/community/tests/tools/test_arxiv.py @@ -1,9 +1,9 @@ -from community.tools.retrieval.arxiv import ArxivRetriever +from community.tools import ArxivRetriever def test_arxiv_retriever(): retriever = ArxivRetriever() - result = retriever.retrieve_documents("quantum") + result = retriever.call({"query": "quantum"}) assert len(result) > 0 assert "text" in result[0] assert "quantum" in result[0]["text"].lower() diff --git a/src/community/tests/tools/functions/test_clinicaltrials.py b/src/community/tests/tools/test_clinicaltrials.py similarity index 69% rename from src/community/tests/tools/functions/test_clinicaltrials.py rename to src/community/tests/tools/test_clinicaltrials.py index d6b6cddd4f..802878d8c3 100644 --- a/src/community/tests/tools/functions/test_clinicaltrials.py +++ b/src/community/tests/tools/test_clinicaltrials.py @@ -1,8 +1,8 @@ -from community.tools.function_tools.clinicaltrials import ClinicalTrialsTool +from community.tools import ClinicalTrials def test_clinicaltrials_tool(): - retriever = ClinicalTrialsTool() + retriever = ClinicalTrials() result = retriever.call( parameters={ "condition": "lung cancer", diff --git a/src/community/tests/tools/retrieval/test_llama_index.py b/src/community/tests/tools/test_llama_index.py similarity index 99% rename from src/community/tests/tools/retrieval/test_llama_index.py rename to src/community/tests/tools/test_llama_index.py index cef017b2f0..c328b2ffc2 100644 --- a/src/community/tests/tools/retrieval/test_llama_index.py +++ b/src/community/tests/tools/test_llama_index.py @@ -1,4 +1,4 @@ -from community.tools.retrieval.llama_index import 
LlamaIndexUploadPDFRetriever +from community.tools import LlamaIndexUploadPDFRetriever def test_pdf_retriever() -> None: @@ -34,6 +34,6 @@ def test_pdf_retriever() -> None: "text": '49. Robbins, Gary (5 September 2019). "UCSD discovers surge in plastics pollution off Santa\nBarbara" (https://www.latimes.com/california/story/2019-09-04/uc-san-diego-discovers-explo\nsion-in-plastics-products-in-seafloor-off-santa-barbara). Los Angeles Times. Retrieved\n5 September 2019.\n50. Street, Francesca (13 May 2019). "Deepest ocean dive recorded: How Victor Vescovo did it"\n(https://www.cnn.com/travel/article/victor-vescovo-deepest-dive-pacific/index.html). CNN\nTravel. CNN. Retrieved 13 May 2019.\n51. Levy, Adam (15 May 2019). ""Bomb Carbon" Has Been Found in Deep-Ocean Creatures" (h\nttps://www.scientificamerican.com/article/bomb-carbon-has-been-found-in-deep-ocean-creat\nures/). Scientific American.\n52. Hafemeister, David W. (2007). Physics of societal issues: calculations on national security,\nenvironment, and energy (https://books.google.com/books?id=LT4MSqv9QUIC&pg=PA187).\nBerlin: Springer. p. 187. ISBN 978-0-387-95560-5.\n53. Kingsley, Marvin G.; Rogers, Kenneth H. (2007). Calculated risks: highly radioactive waste\nand homeland security (https://books.google.com/books?id=bOP4-BpYXrEC&pg=PA75).\nAldershot, Hants, England: Ashgate. pp. 75–76. ISBN 978-0-7546-7133-6.\n54. "Dumping and Loss overview" (https://web.archive.org/web/20110605190619/http://www.la\nw.berkeley.edu/centers/ilr/ona/pages/dumping2.htm). Oceans in the Nuclear Age. Archived\nfrom the original (http://www.law.berkeley.edu/centers/ilr/ona/pages/dumping2.htm) on 5\nJune 2011. 
Retrieved 18 September 2010.\nMariana Trench Dive (25 March 2012) (https://web.archive.org/web/20140625050833/http://d\neepseachallenge.com/) – Deepsea Challenger\nMariana Trench Dive (23 January 1960) (http://www.britishpathe.com/video/they-dived-7-mil\nes/query/mariana+trench) – Trieste (Newsreel)\nMariana Trench Dive (50th Anniv) (http://www.vvdailypress.com/articles/walsh-18116-regret-\nmiles.html) Archived (https://web.archive.org/web/20130603064615/http://www.vvdailypress.\ncom/articles/walsh-18116-regret-miles.html) 3 June 2013 at the Wayback Machine – Trieste\n– Capt Don Walsh\nMariana Trench – Maps (Google) (https://maps.google.com/maps?q=11.317,+142.25(Marian\na+Trench)&z=6)\nNOAA – Ocean Explorer (http://oceanexplorer.noaa.gov) (Ofc Ocean Exploration & Rsch)\nNOAA – Ocean Explorer – Multimedia (http://oceanexplorer.noaa.gov/explorations/06fire/bac\nkground/marianaarc/marianaarc.html) – Mariana Arc (podcast (http://oceanexplorer.noaa.go\nv/explorations/podcast/oceanexplorer_podcast.xml))\nNOAA – Ocean Explorer – Video Playlist (https://www.youtube.com/view_play_list?p=94B79\n5FD631011E0) – Ring of Fire (2004–2006)\nRetrieved from "https://en.wikipedia.org/w/index.php?title=Mariana_Trench&oldid=1187694887"External links\n' }, ] - result = retriever.retrieve_documents(query) + result = retriever.call({"query": query}) assert expected_docs == result diff --git a/src/community/tests/tools/test_pub_med.py b/src/community/tests/tools/test_pub_med.py new file mode 100644 index 0000000000..b4aea9210b --- /dev/null +++ b/src/community/tests/tools/test_pub_med.py @@ -0,0 +1,8 @@ +from community.tools import PubMedRetriever + + +def test_pub_med_retriever(): + retriever = PubMedRetriever() + result = retriever.call({"query": "What causes lung cancer?"}) + assert len(result) > 0 + assert "text" in result[0] diff --git a/src/community/tools/__init__.py b/src/community/tools/__init__.py index f2a3eac724..d120aefeca 100644 --- a/src/community/tools/__init__.py +++ 
b/src/community/tools/__init__.py @@ -1,3 +1,17 @@ from backend.schemas.tool import Category, ManagedTool -from backend.tools.function_tools.base import BaseFunctionTool -from backend.tools.retrieval.base import BaseRetrieval +from backend.tools.base import BaseTool +from community.tools.arxiv import ArxivRetriever +from community.tools.clinicaltrials import ClinicalTrials +from community.tools.connector import ConnectorRetriever +from community.tools.llama_index import LlamaIndexUploadPDFRetriever +from community.tools.pub_med import PubMedRetriever +from community.tools.wolfram import WolframAlpha + +__all__ = [ + "WolframAlpha", + "ClinicalTrials", + "ArxivRetriever", + "ConnectorRetriever", + "LlamaIndexUploadPDFRetriever", + "PubMedRetriever", +] diff --git a/src/community/tools/retrieval/arxiv.py b/src/community/tools/arxiv.py similarity index 61% rename from src/community/tools/retrieval/arxiv.py rename to src/community/tools/arxiv.py index e9699b4c97..c5ae0cb3bb 100644 --- a/src/community/tools/retrieval/arxiv.py +++ b/src/community/tools/arxiv.py @@ -2,10 +2,10 @@ from langchain_community.utilities import ArxivAPIWrapper -from community.tools import BaseRetrieval +from community.tools import BaseTool -class ArxivRetriever(BaseRetrieval): +class ArxivRetriever(BaseTool): def __init__(self): self.client = ArxivAPIWrapper() @@ -13,6 +13,7 @@ def __init__(self): def is_available(cls) -> bool: return True - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + query = parameters.get("query", "") result = self.client.run(query) return [{"text": result}] diff --git a/src/community/tools/function_tools/clinicaltrials.py b/src/community/tools/clinicaltrials.py similarity index 95% rename from src/community/tools/function_tools/clinicaltrials.py rename to src/community/tools/clinicaltrials.py index 20270a0b9c..e217d17340 100644 --- 
a/src/community/tools/function_tools/clinicaltrials.py +++ b/src/community/tools/clinicaltrials.py @@ -1,14 +1,16 @@ -# https://clinicaltrials.gov/data-api/api - from typing import Any, Dict, List import requests -from community.tools import BaseFunctionTool +from community.tools import BaseTool + +class ClinicalTrials(BaseTool): + """ + Retrieves clinical studies from ClinicalTrials.gov. -class ClinicalTrialsTool(BaseFunctionTool): - """Retrieves clinical studies from ClinicalTrials.gov""" + See: https://clinicaltrials.gov/data-api/api + """ def __init__(self, url="https://clinicaltrials.gov/api/v2/studies"): self._url = url diff --git a/src/community/tools/retrieval/connector.py b/src/community/tools/connector.py similarity index 70% rename from src/community/tools/retrieval/connector.py rename to src/community/tools/connector.py index eecb37b041..b1ff81b89f 100644 --- a/src/community/tools/retrieval/connector.py +++ b/src/community/tools/connector.py @@ -2,10 +2,10 @@ import requests -from community.tools import BaseRetrieval +from community.tools import BaseTool """ -Plug in your connector configuration here. For example: +Plug in your Connector configuration here. 
For example: Url: http://example_connector.com/search Auth: Bearer token for the connector @@ -14,7 +14,8 @@ """ -class ConnectorRetriever(BaseRetrieval): +class ConnectorRetriever(BaseTool): + def __init__(self, url: str, auth: str): self.url = url self.auth = auth @@ -23,8 +24,8 @@ def __init__(self, url: str, auth: str): def is_available(cls) -> bool: return True - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: - body = {"query": query} + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + body = {"query": parameters} headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.auth}", diff --git a/src/community/tools/function_tools/__init__.py b/src/community/tools/function_tools/__init__.py deleted file mode 100644 index 8d812bd26a..0000000000 --- a/src/community/tools/function_tools/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from community.tools.function_tools.clinicaltrials import ClinicalTrialsTool -from community.tools.function_tools.wolfram import WolframAlphaFunctionTool - -__all__ = [ - "WolframAlphaFunctionTool", - "ClinicalTrialsTool", -] diff --git a/src/community/tools/retrieval/llama_index.py b/src/community/tools/llama_index.py similarity index 80% rename from src/community/tools/retrieval/llama_index.py rename to src/community/tools/llama_index.py index 7d5496554e..0b6163092c 100644 --- a/src/community/tools/retrieval/llama_index.py +++ b/src/community/tools/llama_index.py @@ -2,7 +2,7 @@ from llama_index.core import SimpleDirectoryReader -from community.tools import BaseRetrieval +from community.tools import BaseTool """ Plug in your llama index retrieval implementation here. @@ -14,7 +14,7 @@ """ -class LlamaIndexUploadPDFRetriever(BaseRetrieval): +class LlamaIndexUploadPDFRetriever(BaseTool): """ This class retrieves documents from a PDF using the llama_index package. This requires llama_index package to be installed. 
@@ -27,6 +27,6 @@ def __init__(self, filepath: str): def is_available(cls) -> bool: return True - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: docs = SimpleDirectoryReader(input_files=[self.filepath]).load_data() return [dict({"text": doc.text}) for doc in docs] diff --git a/src/community/tools/retrieval/pub_med.py b/src/community/tools/pub_med.py similarity index 62% rename from src/community/tools/retrieval/pub_med.py rename to src/community/tools/pub_med.py index 96e6aec3df..6b962ee485 100644 --- a/src/community/tools/retrieval/pub_med.py +++ b/src/community/tools/pub_med.py @@ -2,10 +2,10 @@ from langchain_community.tools.pubmed.tool import PubmedQueryRun -from community.tools import BaseRetrieval +from community.tools import BaseTool -class PubMedRetriever(BaseRetrieval): +class PubMedRetriever(BaseTool): def __init__(self): self.client = PubmedQueryRun() @@ -13,6 +13,7 @@ def __init__(self): def is_available(cls) -> bool: return True - def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + query = parameters.get("query", "") result = self.client.invoke(query) return [{"text": result}] diff --git a/src/community/tools/retrieval/__init__.py b/src/community/tools/retrieval/__init__.py deleted file mode 100644 index 5efec48fec..0000000000 --- a/src/community/tools/retrieval/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from community.tools.retrieval.arxiv import ArxivRetriever -from community.tools.retrieval.connector import ConnectorRetriever -from community.tools.retrieval.llama_index import LlamaIndexUploadPDFRetriever -from community.tools.retrieval.pub_med import PubMedRetriever - -__all__ = [ - "ArxivRetriever", - "ConnectorRetriever", - "LlamaIndexUploadPDFRetriever", - "PubMedRetriever", -] diff --git 
a/src/community/tools/function_tools/wolfram.py b/src/community/tools/wolfram.py similarity index 68% rename from src/community/tools/function_tools/wolfram.py rename to src/community/tools/wolfram.py index 7468884215..61e06ad8df 100644 --- a/src/community/tools/function_tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -1,14 +1,18 @@ -# https://python.langchain.com/docs/integrations/tools/wolfram_alpha/ - import os from typing import Any, Dict, List from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper -from community.tools import BaseFunctionTool +from community.tools import BaseTool + + +class WolframAlpha(BaseTool): + """ + Wolfram Alpha tool. + See: https://python.langchain.com/docs/integrations/tools/wolfram_alpha/ + """ -class WolframAlphaFunctionTool(BaseFunctionTool): wolfram_app_id = os.environ.get("WOLFRAM_APP_ID") def __init__(self): @@ -19,7 +23,7 @@ def __init__(self): def is_available(cls) -> bool: return cls.wolfram_app_id is not None - def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]: + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: to_evaluate = parameters.get("expression", "") result = self.tool.run(to_evaluate) return {"result": result, "text": result}