From 396233d8259a68b8a422c93abd19463079b19d38 Mon Sep 17 00:00:00 2001
From: Nicolas Bernier
Date: Tue, 30 Apr 2024 11:18:44 -0400
Subject: [PATCH 1/3] Support for function calling with Anthropic on Bedrock

---
 libs/aws/README.md                            |   2 +-
 libs/aws/langchain_aws/chat_models/bedrock.py |  70 ++++++++-
 libs/aws/langchain_aws/function_calling.py    | 139 ++++++++++++++++++
 3 files changed, 208 insertions(+), 3 deletions(-)
 create mode 100644 libs/aws/langchain_aws/function_calling.py

diff --git a/libs/aws/README.md b/libs/aws/README.md
index 4e391286..62d0bb59 100644
--- a/libs/aws/README.md
+++ b/libs/aws/README.md
@@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever(
 retriever.get_relevant_documents(query="What is the meaning of life?")
 ```
 
-`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowlege Bases.
+The `AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases.
 
 ```python
 from langchain_aws import AmazonKnowledgeBasesRetriever

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 5fa7182e..2b9d52da 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,11 +1,25 @@
 import re
 from collections import defaultdict
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 from langchain_core._api.deprecation import deprecated
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -16,8 +30,11 @@
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Extra
+from langchain_core.pydantic_v1 import BaseModel, Extra
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 
+from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
 from langchain_aws.llms.bedrock import BedrockBase
 from langchain_aws.utils import (
     get_num_tokens_anthropic,
@@ -264,6 +281,8 @@ def format_messages(
 class ChatBedrock(BaseChatModel, BedrockBase):
     """A chat model that uses the Bedrock API."""
 
+    system_prompt_with_tools: str = ""
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
@@ -307,6 +326,11 @@ def _stream(
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
+            if self.system_prompt_with_tools:
+                if system:
+                    system = self.system_prompt_with_tools + f"\n{system}"
+                else:
+                    system = self.system_prompt_with_tools
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
                 provider=provider, messages=messages
@@ -345,6 +369,11 @@ def _generate(
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
+            if self.system_prompt_with_tools:
+                if system:
+                    system = self.system_prompt_with_tools + f"\n{system}"
+                else:
+                    system = self.system_prompt_with_tools
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
                 provider=provider, messages=messages
@@ -399,6 +428,43 @@ def get_token_ids(self, text: str) -> List[int]:
         else:
             return super().get_token_ids(text)
 
+    def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
+        """Workaround for bind_tools: sets the system prompt to use with tools."""
+        self.system_prompt_with_tools = xml_tools_system_prompt
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        *,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model has a tool calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            tool_choice: Which tool to require the model to call.
+                Must be the name of the single provided function,
+                "auto" to automatically determine which function to call
+                (if any), or a dict of the form:
+                {"type": "function", "function": {"name": <<tool_name>>}}.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        provider = self._get_provider()
+
+        if provider == "anthropic":
+            formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
+            system_formatted_tools = get_system_message(formatted_tools)
+            self.set_system_prompt_with_tools(system_formatted_tools)
+            # llm_with_tools = self.bind(tools=system_formatted_tools, **kwargs)
+        return self
+
 
 @deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
 class BedrockChat(ChatBedrock):
diff --git a/libs/aws/langchain_aws/function_calling.py b/libs/aws/langchain_aws/function_calling.py
new file mode 100644
index 00000000..765332e2
--- /dev/null
+++ b/libs/aws/langchain_aws/function_calling.py
@@ -0,0 +1,139 @@
+"""Methods for creating function specs in the style of Bedrock Functions
+for supported model providers"""
+
+import json
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Type,
+    Union,
+)
+
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from typing_extensions import TypedDict
+
+PYTHON_TO_JSON_TYPES = {
+    "str": "string",
+    "int": "integer",
+    "float": "number",
+    "bool": "boolean",
+}
+
+SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question.
+
+You may call them like this:
+<function_calls>
+<invoke>
+<tool_name>$TOOL_NAME</tool_name>
+<parameters>
+<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+...
+</parameters>
+</invoke>
+</function_calls>
+
+Here are the tools available:
+<tools>
+{formatted_tools}
+</tools>
+"""  # noqa: E501
+
+TOOL_FORMAT = """<tool_description>
+<tool_name>{tool_name}</tool_name>
+<description>{tool_description}</description>
+<parameters>
+{formatted_parameters}
+</parameters>
+</tool_description>"""
+
+TOOL_PARAMETER_FORMAT = """<parameter>
+<name>{parameter_name}</name>
+<type>{parameter_type}</type>
+<description>{parameter_description}</description>
+</parameter>"""
+
+
+class AnthropicTool(TypedDict):
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+
+
+def _get_type(parameter: Dict[str, Any]) -> str:
+    if "type" in parameter:
+        return parameter["type"]
+    if "anyOf" in parameter:
+        return json.dumps({"anyOf": parameter["anyOf"]})
+    if "allOf" in parameter:
+        return json.dumps({"allOf": parameter["allOf"]})
+    return json.dumps(parameter)
+
+
+def get_system_message(tools: List[AnthropicTool]) -> str:
+    tools_data: List[Dict] = [
+        {
+            "tool_name": tool["name"],
+            "tool_description": tool["description"],
+            "formatted_parameters": "\n".join(
+                [
+                    TOOL_PARAMETER_FORMAT.format(
+                        parameter_name=name,
+                        parameter_type=_get_type(parameter),
+                        parameter_description=parameter.get("description"),
+                    )
+                    for name, parameter in tool["input_schema"]["properties"].items()
+                ]
+            ),
+        }
+        for tool in tools
+    ]
+    tools_formatted = "\n".join(
+        [
+            TOOL_FORMAT.format(
+                tool_name=tool["tool_name"],
+                tool_description=tool["tool_description"],
+                formatted_parameters=tool["formatted_parameters"],
+            )
+            for tool in tools_data
+        ]
+    )
+    return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)
+
+
+class FunctionDescription(TypedDict):
+    """Representation of a callable function to send to an LLM."""
+
+    name: str
+    """The name of the function."""
+    description: str
+    """A description of the function."""
+    parameters: dict
+    """The parameters of the function."""
+
+
+class ToolDescription(TypedDict):
+    """Representation of a callable function to the OpenAI API."""
+
+    type: Literal["function"]
+    function: FunctionDescription
+
+
+def convert_to_anthropic_tool(
+    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+) -> AnthropicTool:
+    # already in Anthropic tool format
+    if isinstance(tool, dict) and all(
+        k in tool for k in ("name", "description", "input_schema")
+    ):
+        return AnthropicTool(tool)  # type: ignore
+    else:
+        formatted = convert_to_openai_tool(tool)["function"]
+        return AnthropicTool(
+            name=formatted["name"],
+            description=formatted["description"],
+            input_schema=formatted["parameters"],
+        )
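To make the mechanics of patch 1 concrete, here is a minimal usage sketch. It is an illustration, not part of the PR: the model ID is a placeholder, `GetWeather` is a hypothetical tool, and AWS credentials and region are assumed to be configured in the environment.

```python
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_aws import ChatBedrock


class GetWeather(BaseModel):
    """Hypothetical tool: get the current weather in a given location."""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


# Placeholder model id; assumes AWS credentials/region are configured.
chat = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")

# bind_tools converts the tools to Anthropic format, renders them into an XML
# system prompt via get_system_message, and stores that prompt on the model;
# note it returns the model itself rather than a new Runnable.
chat_with_tools = chat.bind_tools([GetWeather])

# _generate/_stream prepend the stored tool prompt to any user-supplied system
# prompt, so the reply may contain a <function_calls> block instead of prose.
response = chat_with_tools.invoke([HumanMessage(content="What is the weather in SF?")])
print(response.content)
```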
From f6ae870a7a33109ab5587b8913ef5e3efead720b Mon Sep 17 00:00:00 2001
From: Nicolas Bernier
Date: Fri, 3 May 2024 11:28:49 -0400
Subject: [PATCH 2/3] Adding tests for function calling

---
 libs/aws/langchain_aws/chat_models/bedrock.py |  1 -
 .../chat_models/test_bedrock.py               | 80 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 2b9d52da..9dc7e5b7 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -462,7 +462,6 @@ def bind_tools(
             formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
             system_formatted_tools = get_system_message(formatted_tools)
             self.set_system_prompt_with_tools(system_formatted_tools)
-            # llm_with_tools = self.bind(tools=system_formatted_tools, **kwargs)
         return self
 
 
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 31f00caa..ced4f3b4 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -12,6 +12,7 @@
 
 from langchain_aws.chat_models.bedrock import ChatBedrock
 from tests.callbacks import FakeCallbackHandler
+from langchain_core.pydantic_v1 import BaseModel, Field
 
 
 @pytest.fixture
@@ -156,5 +157,80 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
-    assert "usage" in result.additional_kwargs  # type: ignore[attr-defined]
-    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13  # type: ignore[attr-defined]
+    assert "usage" in result.additional_kwargs
+    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13
+
+@pytest.mark.scheduled
+def test_bedrock_anthropic_function_call_invoke_with_system(chat: ChatBedrock) -> None:
+    class GetWeather(BaseModel):
+        location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+    llm_with_tools = chat.bind_tools([GetWeather])
+
+    messages = [
+        SystemMessage(
+            content="answer only in French"
+        ),
+        HumanMessage(
+            content="what is the weather like in San Francisco"
+        )
+    ]
+
+    response = llm_with_tools.invoke(messages)
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+@pytest.mark.scheduled
+def test_bedrock_anthropic_function_call_invoke_without_system(chat: ChatBedrock) -> None:
+    class GetWeather(BaseModel):
+        location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+    llm_with_tools = chat.bind_tools([GetWeather])
+
+    messages = [
+        HumanMessage(
+            content="what is the weather like in San Francisco"
+        )
+    ]
+
+    response = llm_with_tools.invoke(messages)
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+@pytest.mark.scheduled
+async def test_bedrock_anthropic_function_call_invoke_with_system_astream(chat: ChatBedrock) -> None:
+    class GetWeather(BaseModel):
+        location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+    llm_with_tools = chat.bind_tools([GetWeather])
+
+    messages = [
+        SystemMessage(
+            content="answer only in French"
+        ),
+        HumanMessage(
+            content="what is the weather like in San Francisco"
+        )
+    ]
+
+    for chunk in llm_with_tools.stream(messages):
+        assert isinstance(chunk.content, str)
+
+@pytest.mark.scheduled
+async def test_bedrock_anthropic_function_call_invoke_without_system_astream(chat: ChatBedrock) -> None:
+    class GetWeather(BaseModel):
+        location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+    llm_with_tools = chat.bind_tools([GetWeather])
+
+    messages = [
+        HumanMessage(
+            content="what is the weather like in San Francisco"
+        )
+    ]
+
+    full = None
+    for token in llm_with_tools.stream("I'm Pickle Rick"):
+        full = token if full is None else full + token
+        assert isinstance(token.content, str)
+    assert isinstance(cast(AIMessageChunk, full).content, str)
\ No newline at end of file
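For reference, given the `GetWeather` tool used in these tests, `convert_to_anthropic_tool` plus `get_system_message` render a system prompt along the following lines. This is a reconstruction from the templates in patch 1, not captured output; the tool description renders empty here because the test's `GetWeather` model has no docstring.

```
In this environment you have access to a set of tools you can use to answer the user's question.

You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>

Here are the tools available:
<tools>
<tool_description>
<tool_name>GetWeather</tool_name>
<description></description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>The city and state, e.g. San Francisco, CA</description>
</parameter>
</parameters>
</tool_description>
</tools>
```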
From 219f764acb5522670bea330733634394eca0663b Mon Sep 17 00:00:00 2001
From: Nicolas Bernier
Date: Fri, 3 May 2024 12:02:55 -0400
Subject: [PATCH 3/3] fix linting errors after adding tests

---
 .../chat_models/test_bedrock.py               | 63 +++++++------------
 1 file changed, 24 insertions(+), 39 deletions(-)

diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index ced4f3b4..42882149 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -9,10 +9,10 @@
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, Field
 
 from langchain_aws.chat_models.bedrock import ChatBedrock
 from tests.callbacks import FakeCallbackHandler
-from langchain_core.pydantic_v1 import BaseModel, Field
 
 
 @pytest.fixture
@@ -157,80 +157,65 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
-    assert "usage" in result.additional_kwargs
-    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13
+    assert "usage" in result.additional_kwargs
+    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13
+
 
 @pytest.mark.scheduled
-def test_bedrock_anthropic_function_call_invoke_with_system(chat: ChatBedrock) -> None:
+def test_function_call_invoke_with_system(chat: ChatBedrock) -> None:
     class GetWeather(BaseModel):
-        location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+        location: str = Field(..., description="The city and state")
 
     llm_with_tools = chat.bind_tools([GetWeather])
 
     messages = [
-        SystemMessage(
-            content="answer only in French"
-        ),
-        HumanMessage(
-            content="what is the weather like in San Francisco"
-        )
+        SystemMessage(content="answer only in French"),
+        HumanMessage(content="what is the weather like in San Francisco"),
     ]
 
     response = llm_with_tools.invoke(messages)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
-
+
+
San Francisco, CA") + + llm_with_tools = chat.bind_tools([GetWeather]) + + messages = [ + HumanMessage( + content="what is the weather like in San Francisco" + ) + ] + + full = None + for token in llm_with_tools.stream("I'm Pickle Rick"): + full = token if full is None else full + token + assert isinstance(token.content, str) + assert isinstance(cast(AIMessageChunk, full).content, str) \ No newline at end of file From 219f764acb5522670bea330733634394eca0663b Mon Sep 17 00:00:00 2001 From: Nicolas Bernier Date: Fri, 3 May 2024 12:02:55 -0400 Subject: [PATCH 3/3] fix linting errors after adding tests --- .../chat_models/test_bedrock.py | 63 +++++++------------ 1 file changed, 24 insertions(+), 39 deletions(-) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index ced4f3b4..42882149 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -9,10 +9,10 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_aws.chat_models.bedrock import ChatBedrock from tests.callbacks import FakeCallbackHandler -from langchain_core.pydantic_v1 import BaseModel, Field @pytest.fixture @@ -157,80 +157,65 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None: """Test invoke tokens from BedrockChat.""" result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) - assert "usage" in result.additional_kwargs - assert result.additional_kwargs["usage"]["prompt_tokens"] == 13 + assert "usage" in result.additional_kwargs + assert result.additional_kwargs["usage"]["prompt_tokens"] == 13 + @pytest.mark.scheduled -def test_bedrock_anthropic_function_call_invoke_with_system(chat: ChatBedrock) -> None: +def test_function_call_invoke_with_system(chat: ChatBedrock) -> None: class GetWeather(BaseModel): - location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + location: str = Field(..., description="The city and state") llm_with_tools = chat.bind_tools([GetWeather]) messages = [ - SystemMessage( - content="anwser only in french" - ), - HumanMessage( - content="what is the weather like in San Francisco" - ) + SystemMessage(content="anwser only in french"), + HumanMessage(content="what is the weather like in San Francisco"), ] response = llm_with_tools.invoke(messages) assert isinstance(response, BaseMessage) assert isinstance(response.content, str) - + + @pytest.mark.scheduled -def test_bedrock_anthropic_function_call_invoke_without_system(chat: ChatBedrock) -> None: +def test_function_call_invoke_without_system(chat: ChatBedrock) -> None: class GetWeather(BaseModel): - location: str = Field(..., description="The city and state, e.g. 
San Francisco, CA") + location: str = Field(..., description="The city and state") llm_with_tools = chat.bind_tools([GetWeather]) - messages = [ - HumanMessage( - content="what is the weather like in San Francisco" - ) - ] + messages = [HumanMessage(content="what is the weather like in San Francisco")] response = llm_with_tools.invoke(messages) assert isinstance(response, BaseMessage) assert isinstance(response.content, str) + @pytest.mark.scheduled -async def test_bedrock_anthropic_function_call_invoke_with_system_astream(chat: ChatBedrock) -> None: +async def test_function_call_invoke_with_system_astream(chat: ChatBedrock) -> None: class GetWeather(BaseModel): - location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + location: str = Field(..., description="The city and state") llm_with_tools = chat.bind_tools([GetWeather]) messages = [ - SystemMessage( - content="anwser only in french" - ), - HumanMessage( - content="what is the weather like in San Francisco" - ) + SystemMessage(content="anwser only in french"), + HumanMessage(content="what is the weather like in San Francisco"), ] for chunk in llm_with_tools.stream(messages): assert isinstance(chunk.content, str) + @pytest.mark.scheduled -async def test_bedrock_anthropic_function_call_invoke_without_system_astream(chat: ChatBedrock) -> None: +async def test_function_call_invoke_without_system_astream(chat: ChatBedrock) -> None: class GetWeather(BaseModel): - location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + location: str = Field(..., description="The city and state") llm_with_tools = chat.bind_tools([GetWeather]) - messages = [ - HumanMessage( - content="what is the weather like in San Francisco" - ) - ] + messages = [HumanMessage(content="what is the weather like in San Francisco")] - full = None - for token in llm_with_tools.stream("I'm Pickle Rick"): - full = token if full is None else full + token - assert isinstance(token.content, str) - assert isinstance(cast(AIMessageChunk, full).content, str) \ No newline at end of file + for chunk in llm_with_tools.stream(messages): + assert isinstance(chunk.content, str)