Adding support for Cohere models in BedrockChat #42

Closed
wants to merge 18 commits
62 changes: 61 additions & 1 deletion libs/aws/langchain_aws/chat_models/bedrock.py
@@ -331,6 +331,40 @@ def _format_anthropic_messages(
return system, formatted_messages


def _format_cohere_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Format messages for Cohere Command R.

    Returns the content for the top-level "message" field plus a chat
    history shaped like:

    {
        "message": content,
        "chat_history": [
            {"role": "USER or CHATBOT", "message": message.content}
        ]
    }
    """
    content: Optional[str] = None
    chat_history: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            chat_history.append({"role": "USER", "message": message.content})
            continue
        elif message.type == "ai":
            chat_history.append({"role": "CHATBOT", "message": message.content})
        elif message.type == "human":
            chat_history.append({"role": "USER", "message": message.content})
        # The last message's content doubles as the top-level "message" field.
        content = str(messages[-1].content)
    return content, chat_history
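For illustration, here is what this helper returns for a short conversation (this matches the unit test added at the bottom of this PR; note that the last message's content, here the AI turn, becomes the top-level message):

    from langchain_aws.chat_models.bedrock import _format_cohere_messages
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    messages = [SystemMessage("fuzz"), HumanMessage("foo"), AIMessage("bar")]
    content, chat_history = _format_cohere_messages(messages)
    # content == "bar"
    # chat_history == [
    #     {"role": "USER", "message": "fuzz"},
    #     {"role": "USER", "message": "foo"},
    #     {"role": "CHATBOT", "message": "bar"},
    # ]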


class ChatPromptAdapter:
"""Adapter class to prepare the inputs from Langchain to prompt format
that Chat model expects.
@@ -355,6 +389,12 @@ def convert_messages_to_prompt(
human_prompt="\n\nUser:",
ai_prompt="\n\nBot:",
)
elif provider == "cohere":
prompt = convert_messages_to_prompt_anthropic(
messages=messages,
human_prompt="\n\nUser:",
ai_prompt="\n\nBot:",
)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."
@@ -367,7 +407,16 @@ def format_messages(
) -> Tuple[Optional[str], List[Dict]]:
if provider == "anthropic":
return _format_anthropic_messages(messages)
raise NotImplementedError(
f"Provider {provider} not supported for format_messages"
)

    @classmethod
    def format_cohere_message(
        cls, provider: str, messages: List[BaseMessage]
    ) -> Tuple[Optional[str], List[Dict]]:
        if provider == "cohere":
            return _format_cohere_messages(messages)
        raise NotImplementedError(
            f"Provider {provider} not supported for format_cohere_message"
        )
@@ -506,7 +555,8 @@ def _generate(
response_metadata, provider_stop_reason_code
)
else:
-            prompt, system, formatted_messages = None, None, None
+            provider = self._get_provider()
+            prompt, system, formatted_messages, chat_history = None, None, None, None
params: Dict[str, Any] = {**kwargs}

if provider == "anthropic":
@@ -519,6 +569,15 @@
system = self.system_prompt_with_tools + f"\n{system}"
else:
system = self.system_prompt_with_tools
elif provider == "cohere":
if "command-r" in self.model_id:
prompt, chat_history = ChatPromptAdapter.format_cohere_message(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages, model=self._get_model()
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages, model=self._get_model()
@@ -533,6 +592,7 @@
run_manager=run_manager,
system=system,
messages=formatted_messages,
chat_history=chat_history,
**params,
)
# usage metadata
31 changes: 27 additions & 4 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -271,6 +271,7 @@ def prepare_input(
model_kwargs: Dict[str, Any],
prompt: Optional[str] = None,
system: Optional[str] = None,
chat_history: Optional[List[Dict]] = None,
messages: Optional[List[Dict]] = None,
tools: Optional[List[AnthropicTool]] = None,
) -> Dict[str, Any]:
@@ -289,7 +290,15 @@
input_body["prompt"] = _human_assistant_format(prompt)
if "max_tokens_to_sample" not in input_body:
input_body["max_tokens_to_sample"] = 1024
elif provider in ("ai21", "cohere", "meta", "mistral"):
elif provider == "cohere":
# Command-R
if chat_history:
input_body["chat_history"] = chat_history
input_body["message"] = prompt
# Command
else:
input_body["prompt"] = prompt
elif provider in ("ai21", "meta", "mistral"):
input_body["prompt"] = prompt
elif provider == "amazon":
input_body = dict()
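For reference, the two Cohere request shapes produced by the branch above look roughly like this (a sketch with illustrative values; any model_kwargs are merged into the same body):

    # Command R: structured chat input
    input_body = {
        "chat_history": [{"role": "USER", "message": "You are a helpful assistant."}],
        "message": "Hello",
    }

    # Classic Command: plain text prompt
    input_body = {"prompt": "Hello"}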
@@ -315,12 +324,18 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
text = content[0]["text"]
elif any(block["type"] == "tool_use" for block in content):
tool_calls = extract_tool_calls(content)

elif provider == "cohere":
if "text" in response_body.keys():
# Command-R
text = response_body.get("text")
else:
# Command
text = response_body.get("generations")[0].get("text")
else:
response_body = json.loads(response.get("body").read())

if provider == "ai21":
text = response_body.get("completions")[0].get("data").get("text")
elif provider == "cohere":
text = response_body.get("generations")[0].get("text")
elif provider == "meta":
text = response_body.get("generation")
elif provider == "mistral":
@@ -674,6 +689,7 @@ def _prepare_input_and_invoke(
prompt: Optional[str] = None,
system: Optional[str] = None,
messages: Optional[List[Dict]] = None,
chat_history: Optional[List[Dict]] = None,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
@@ -691,6 +707,7 @@
model_kwargs=params,
prompt=prompt,
system=system,
chat_history=chat_history,
messages=messages,
)
if "claude-3" in self._get_model():
@@ -969,6 +986,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please use `from langchain_community.chat_models import BedrockChat` "
"instead."
)
if model_id.startswith("cohere.command-r"):

Reviewer: May I check why it's not supported here?

Author: command-r has been excluded from BedrockLLM for the same reason as claude-3. Specifically, command-r differs from the standard command model in terms of output format and behavior. Due to these differences, it requires specialized handling, which is why it has been excluded from the general BedrockLLM class.
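For example, the supported path would look roughly like this (a sketch mirroring the integration test added in this PR; AWS region and credentials configuration are assumed):

    from langchain_aws import ChatBedrock

    chat = ChatBedrock(model_id="cohere.command-r-plus-v1:0")
    response = chat.invoke("Hello")
    print(response.content)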

        raise ValueError(
            "Command R models are not supported by this LLM. "
            "Please use `from langchain_community.chat_models import BedrockChat` "
            "instead."
        )
return super().validate_environment(values)

@property
Expand Down
19 changes: 19 additions & 0 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -100,6 +100,25 @@ def test_chat_bedrock_streaming_llama3() -> None:
assert response.usage_metadata


@pytest.mark.scheduled
@pytest.mark.parametrize(
    "model_id",
    [
        "cohere.command-text-v14",
        "cohere.command-r-plus-v1:0",
    ],
)
def test_chat_bedrock_invoke_cohere(model_id: str) -> None:
    """Test invoking Cohere Command and Command R chat models."""
    chat = ChatBedrock(model_id=model_id)  # type: ignore[call-arg]
    system = SystemMessage(content="You are a helpful assistant.")
    human = HumanMessage(content="Hello")
    response = chat.invoke([system, human])

    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_chat_bedrock_streaming_generation_info() -> None:
"""Test that generation info is preserved when streaming."""
18 changes: 18 additions & 0 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock.py
@@ -11,6 +11,7 @@
from langchain_aws import ChatBedrock
from langchain_aws.chat_models.bedrock import (
_format_anthropic_messages,
_format_cohere_messages,
_merge_messages,
)
from langchain_aws.function_calling import convert_to_anthropic_tool
@@ -256,6 +257,23 @@ def test__format_anthropic_messages_with_tool_use_blocks_and_tool_calls() -> None
actual = _format_anthropic_messages(messages)
assert expected == actual


def test__format_cohere_messages() -> None:
    system = SystemMessage("fuzz")  # type: ignore[misc]
    human = HumanMessage("foo")  # type: ignore[misc]
    ai = AIMessage("bar")  # type: ignore[misc]

    messages = [system, human, ai]
    expected = (
        "bar",
        [
            {"role": "USER", "message": "fuzz"},
            {"role": "USER", "message": "foo"},
            {"role": "CHATBOT", "message": "bar"},
        ],
    )
    actual = _format_cohere_messages(messages)
    assert expected == actual


@pytest.fixture()
def pydantic() -> Type[BaseModel]: