Merge pull request #37 from bigbernnn/main
Adds support for function calling with Anthropic models on Bedrock
3coins authored May 3, 2024
2 parents 2331bf6 + 219f764 commit 123c720
Showing 4 changed files with 270 additions and 5 deletions.
2 changes: 1 addition & 1 deletion libs/aws/README.md
@@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever(
retriever.get_relevant_documents(query="What is the meaning of life?")
```

- `AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowlege Bases.
The `AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases.

```python
from langchain_aws import AmazonKnowledgeBasesRetriever
```
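The collapsed lines construct the retriever; a minimal usage sketch, assuming an existing knowledge base (the `knowledge_base_id` and retrieval configuration below are illustrative placeholders):

```python
from langchain_aws import AmazonKnowledgeBasesRetriever

# Hypothetical knowledge base ID; substitute your own.
retriever = AmazonKnowledgeBasesRetriever(
    knowledge_base_id="ABCDEFGHIJ",
    retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},
)

docs = retriever.get_relevant_documents(query="What is the meaning of life?")
```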
69 changes: 67 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,11 +1,25 @@
import re
from collections import defaultdict
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
@@ -16,8 +30,11 @@
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.pydantic_v1 import Extra
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
from langchain_aws.llms.bedrock import BedrockBase
from langchain_aws.utils import (
    get_num_tokens_anthropic,
@@ -264,6 +281,8 @@ def format_messages(
class ChatBedrock(BaseChatModel, BedrockBase):
    """A chat model that uses the Bedrock API."""

    system_prompt_with_tools: str = ""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
@@ -307,6 +326,11 @@ def _stream(
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
            if self.system_prompt_with_tools:
                if system:
                    system = self.system_prompt_with_tools + f"\n{system}"
                else:
                    system = self.system_prompt_with_tools
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
@@ -345,6 +369,11 @@ def _generate(
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
            if self.system_prompt_with_tools:
                if system:
                    system = self.system_prompt_with_tools + f"\n{system}"
                else:
                    system = self.system_prompt_with_tools
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
@@ -399,6 +428,42 @@ def get_token_ids(self, text: str) -> List[int]:
        else:
            return super().get_token_ids(text)

    def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
        """Workaround for bind_tools: sets the system prompt that describes the bound tools."""
        self.system_prompt_with_tools = xml_tools_system_prompt

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes the model has a tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool.
                Pydantic models, callables, and BaseTools will be automatically
                converted to their schema dictionary representation.
            tool_choice: Which tool to require the model to call. Must be the
                name of the single provided function, "auto" to automatically
                determine which function to call (if any), or a dict of the
                form {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        provider = self._get_provider()

        if provider == "anthropic":
            formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
            system_formatted_tools = get_system_message(formatted_tools)
            self.set_system_prompt_with_tools(system_formatted_tools)
        return self
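Note that, unlike most `bind_tools` implementations, this one stores the rendered tool prompt on the model and returns the same instance rather than a wrapped runnable. A minimal usage sketch (the `GetWeather` schema and model ID are illustrative):

```python
from langchain_aws import ChatBedrock
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    location: str = Field(..., description="The city and state")


llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
llm_with_tools = llm.bind_tools([GetWeather])

# The tool schema is rendered as XML and prepended to the system prompt,
# so the model can answer with a <function_calls> block.
response = llm_with_tools.invoke("what is the weather like in San Francisco")
```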


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
class BedrockChat(ChatBedrock):
139 changes: 139 additions & 0 deletions libs/aws/langchain_aws/function_calling.py
@@ -0,0 +1,139 @@
"""Methods for creating function specs in the style of Bedrock Functions
for supported model providers"""

import json
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Type,
    Union,
)

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from typing_extensions import TypedDict

PYTHON_TO_JSON_TYPES = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
}

SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
{formatted_tools}
</tools>""" # noqa: E501
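
# For reference (illustrative, not part of the template): a model following
# these instructions would reply with something like
#   <function_calls>
#   <invoke>
#   <tool_name>GetWeather</tool_name>
#   <parameters>
#   <location>San Francisco, CA</location>
#   </parameters>
#   </invoke>
#   </function_calls>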

TOOL_FORMAT = """<tool_description>
<tool_name>{tool_name}</tool_name>
<description>{tool_description}</description>
<parameters>
{formatted_parameters}
</parameters>
</tool_description>"""

TOOL_PARAMETER_FORMAT = """<parameter>
<name>{parameter_name}</name>
<type>{parameter_type}</type>
<description>{parameter_description}</description>
</parameter>"""


class AnthropicTool(TypedDict):
    name: str
    description: str
    input_schema: Dict[str, Any]


def _get_type(parameter: Dict[str, Any]) -> str:
    if "type" in parameter:
        return parameter["type"]
    if "anyOf" in parameter:
        return json.dumps({"anyOf": parameter["anyOf"]})
    if "allOf" in parameter:
        return json.dumps({"allOf": parameter["allOf"]})
    return json.dumps(parameter)
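
# Illustrative behavior: a plain JSON-schema type passes through unchanged,
# e.g. _get_type({"type": "string"}) == "string", while composite schemas
# such as {"anyOf": [...]} are serialized to a JSON string so they can still
# be embedded in the XML <type> tag.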


def get_system_message(tools: List[AnthropicTool]) -> str:
    tools_data: List[Dict] = [
        {
            "tool_name": tool["name"],
            "tool_description": tool["description"],
            "formatted_parameters": "\n".join(
                [
                    TOOL_PARAMETER_FORMAT.format(
                        parameter_name=name,
                        parameter_type=_get_type(parameter),
                        parameter_description=parameter.get("description"),
                    )
                    for name, parameter in tool["input_schema"]["properties"].items()
                ]
            ),
        }
        for tool in tools
    ]
    tools_formatted = "\n".join(
        [
            TOOL_FORMAT.format(
                tool_name=tool["tool_name"],
                tool_description=tool["tool_description"],
                formatted_parameters=tool["formatted_parameters"],
            )
            for tool in tools_data
        ]
    )
    return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)


class FunctionDescription(TypedDict):
    """Representation of a callable function to send to an LLM."""

    name: str
    """The name of the function."""
    description: str
    """A description of the function."""
    parameters: dict
    """The parameters of the function."""


class ToolDescription(TypedDict):
    """Representation of a callable function for the OpenAI API."""

    type: Literal["function"]
    function: FunctionDescription


def convert_to_anthropic_tool(
    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
    # Already in Anthropic tool format.
    if isinstance(tool, dict) and all(
        k in tool for k in ("name", "description", "input_schema")
    ):
        return AnthropicTool(tool)  # type: ignore
    else:
        formatted = convert_to_openai_tool(tool)["function"]
        return AnthropicTool(
            name=formatted["name"],
            description=formatted["description"],
            input_schema=formatted["parameters"],
        )
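
Putting the two helpers together: a sketch of how a pydantic model flows through `convert_to_anthropic_tool` and `get_system_message` (the `GetWeather` schema is illustrative; the commented output is abridged):

```python
from langchain_aws.function_calling import (
    convert_to_anthropic_tool,
    get_system_message,
)
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    location: str = Field(..., description="The city and state")


tool = convert_to_anthropic_tool(GetWeather)
# {"name": "GetWeather", "description": "Get the current weather for a city.",
#  "input_schema": {"properties": {"location": {...}}, ...}}

print(get_system_message([tool]))
# SYSTEM_PROMPT_FORMAT with one <tool_description> block containing a
# <parameter> entry: name "location", type "string", and its description.
```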
65 changes: 63 additions & 2 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -9,6 +9,7 @@
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_aws.chat_models.bedrock import ChatBedrock
from tests.callbacks import FakeCallbackHandler
@@ -156,5 +157,65 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None:
    """Test invoke tokens from BedrockChat."""
    result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)
-   assert "usage" in result.additional_kwargs  # type: ignore[attr-defined]
-   assert result.additional_kwargs["usage"]["prompt_tokens"] == 13  # type: ignore[attr-defined]
    assert "usage" in result.additional_kwargs
    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13


@pytest.mark.scheduled
def test_function_call_invoke_with_system(chat: ChatBedrock) -> None:
    class GetWeather(BaseModel):
        location: str = Field(..., description="The city and state")

    llm_with_tools = chat.bind_tools([GetWeather])

    messages = [
        SystemMessage(content="answer only in French"),
        HumanMessage(content="what is the weather like in San Francisco"),
    ]

    response = llm_with_tools.invoke(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_function_call_invoke_without_system(chat: ChatBedrock) -> None:
    class GetWeather(BaseModel):
        location: str = Field(..., description="The city and state")

    llm_with_tools = chat.bind_tools([GetWeather])

    messages = [HumanMessage(content="what is the weather like in San Francisco")]

    response = llm_with_tools.invoke(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
async def test_function_call_invoke_with_system_astream(chat: ChatBedrock) -> None:
    class GetWeather(BaseModel):
        location: str = Field(..., description="The city and state")

    llm_with_tools = chat.bind_tools([GetWeather])

    messages = [
        SystemMessage(content="answer only in French"),
        HumanMessage(content="what is the weather like in San Francisco"),
    ]

    async for chunk in llm_with_tools.astream(messages):
        assert isinstance(chunk.content, str)


@pytest.mark.scheduled
async def test_function_call_invoke_without_system_astream(chat: ChatBedrock) -> None:
    class GetWeather(BaseModel):
        location: str = Field(..., description="The city and state")

    llm_with_tools = chat.bind_tools([GetWeather])

    messages = [HumanMessage(content="what is the weather like in San Francisco")]

    async for chunk in llm_with_tools.astream(messages):
        assert isinstance(chunk.content, str)
