From 6fee3db19b0927b44ecaf270e6c7ba881148811a Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Wed, 8 May 2024 23:38:28 +0900 Subject: [PATCH 01/15] add cohere model --- libs/aws/langchain_aws/chat_models/bedrock.py | 48 ++++++++++++++++++- libs/aws/langchain_aws/llms/bedrock.py | 7 +++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 9dc7e5b7..ebe73c8d 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -131,7 +131,6 @@ def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str: [_convert_one_message_to_text_mistral(message) for message in messages] ) - def _format_image(image_url: str) -> Dict: """ Formats an image of format data:image/jpeg;base64,{b64_string} @@ -235,6 +234,46 @@ def _format_anthropic_messages( ) return system, formatted_messages +def _format_cohere_messages( + messages: List[BaseMessage], +) -> Tuple[Optional[str], List[Dict]]: + """Format messages for cohere.""" + + """ + {'message': content} + """ + + system: Optional[str] = None + chat_history: List = [] + formatted_messages: Dict = {} + for i, message in enumerate(messages): + if message.type == "system": + if i != 0: + raise ValueError("System message must be at beginning of message list.") + if not isinstance(message.content, str): + raise ValueError( + "System message must be a string, " + f"instead was: {type(message.content)}" + ) + chat_history.append({'role':'CHATBOT', 'message': message.content}) + continue + + if not isinstance(message.content, str): + # populate content + content = [] + for item in message.content: + if isinstance(item, str): + content.append( + { + "type": "text", + "text": item, + } + ) + else: + content = message.content + + formatted_messages = {'message': content, 'chat_history': chat_history} + return system, formatted_messages class ChatPromptAdapter: """Adapter class to prepare the inputs from Langchain to prompt format @@ -269,7 +308,8 @@ def format_messages( ) -> Tuple[Optional[str], List[Dict]]: if provider == "anthropic": return _format_anthropic_messages(messages) - + elif provider == "cohere": + return _format_cohere_messages(messages) raise NotImplementedError( f"Provider {provider} not supported for format_messages" ) @@ -374,6 +414,10 @@ def _generate( system = self.system_prompt_with_tools + f"\n{system}" else: system = self.system_prompt_with_tools + elif provider == "cohere": + system, formatted_messages = ChatPromptAdapter.format_messages( + provider, messages + ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( provider=provider, messages=messages diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 90b860ab..6fc301cb 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -130,6 +130,10 @@ def prepare_input( input_body["prompt"] = _human_assistant_format(prompt) if "max_tokens_to_sample" not in input_body: input_body["max_tokens_to_sample"] = 1024 + elif provider == "cohere": + if messages: + input_body["chat_history"] = messages['chat_history'] + input_body["message"] = messages['message'] elif provider in ("ai21", "cohere", "meta", "mistral"): input_body["prompt"] = prompt elif provider == "amazon": @@ -151,6 +155,9 @@ def prepare_output(cls, provider: str, response: Any) -> dict: elif "content" in response_body: content = response_body.get("content") text = content[0].get("text") + elif provider == "cohere": + response_body = json.loads(response.get("body").read().decode()) + text = response_body.get("text") else: response_body = json.loads(response.get("body").read()) From c1fcbad198673e348ff0f87731c7643f1f16b965 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Wed, 8 May 2024 23:49:18 +0900 Subject: [PATCH 02/15] fix chat_history role --- libs/aws/langchain_aws/chat_models/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index ebe73c8d..c7177e8e 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -255,7 +255,7 @@ def _format_cohere_messages( "System message must be a string, " f"instead was: {type(message.content)}" ) - chat_history.append({'role':'CHATBOT', 'message': message.content}) + chat_history.append({'role':'USER', 'message': message.content}) continue if not isinstance(message.content, str): From 21a3fd069ba41d5236d2af9e9b6e8ddac7550291 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Thu, 9 May 2024 08:09:25 +0900 Subject: [PATCH 03/15] fix chat_history --- libs/aws/langchain_aws/chat_models/bedrock.py | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index c7177e8e..e4e44978 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -240,7 +240,7 @@ def _format_cohere_messages( """Format messages for cohere.""" """ - {'message': content} + {'message': content, 'chat_history': [{'role': 'USER or CHATBOT', 'message': message.content}]} """ system: Optional[str] = None @@ -256,23 +256,12 @@ def _format_cohere_messages( f"instead was: {type(message.content)}" ) chat_history.append({'role':'USER', 'message': message.content}) - continue - - if not isinstance(message.content, str): - # populate content - content = [] - for item in message.content: - if isinstance(item, str): - content.append( - { - "type": "text", - "text": item, - } - ) - else: - content = message.content - - formatted_messages = {'message': content, 'chat_history': chat_history} + elif message.type == "assistant": + chat_history.append({'role':'CHATBT', 'message': message.content}) + elif message.type == "user": + chat_history.append({'role':'USER', 'message': message.content}) + content = messages[-1].content + formatted_messages = {'message': content, 'chat_history': chat_history} return system, formatted_messages class ChatPromptAdapter: From 85db9e523ce6ca495b7e1a5d58af4ba25fcb623d Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Thu, 9 May 2024 08:10:19 +0900 Subject: [PATCH 04/15] fix chat_history --- libs/aws/langchain_aws/chat_models/bedrock.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index e4e44978..f35830dc 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -256,9 +256,10 @@ def _format_cohere_messages( f"instead was: {type(message.content)}" ) chat_history.append({'role':'USER', 'message': message.content}) - elif message.type == "assistant": - chat_history.append({'role':'CHATBT', 'message': message.content}) - elif message.type == "user": + continue + elif message.type == "ai": + chat_history.append({'role':'CHATBOT', 'message': message.content}) + elif message.type == "human": chat_history.append({'role':'USER', 'message': message.content}) content = messages[-1].content formatted_messages = {'message': content, 'chat_history': chat_history} From bb7f69a5cc557f2b34c2058f04a9f8fd0831e57a Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sat, 11 May 2024 16:18:19 +0900 Subject: [PATCH 05/15] fix use command model case --- libs/aws/langchain_aws/chat_models/bedrock.py | 15 +++++++++++++-- libs/aws/langchain_aws/llms/bedrock.py | 9 ++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index f35830dc..c87216ef 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -286,6 +286,12 @@ def convert_messages_to_prompt( human_prompt="\n\nUser:", ai_prompt="\n\nBot:", ) + elif provider == "cohere": + prompt = convert_messages_to_prompt_anthropic( + messages=messages, + human_prompt="\n\nUser:", + ai_prompt="\n\nBot:", + ) else: raise NotImplementedError( f"Provider {provider} model does not support chat." @@ -405,8 +411,13 @@ def _generate( else: system = self.system_prompt_with_tools elif provider == "cohere": - system, formatted_messages = ChatPromptAdapter.format_messages( - provider, messages + if 'command-r' in self.model_id: + system, formatted_messages = ChatPromptAdapter.format_messages( + provider, messages + ) + else: + prompt = ChatPromptAdapter.convert_messages_to_prompt( + provider=provider, messages=messages ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 6fc301cb..8c7508a3 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -157,14 +157,17 @@ def prepare_output(cls, provider: str, response: Any) -> dict: text = content[0].get("text") elif provider == "cohere": response_body = json.loads(response.get("body").read().decode()) - text = response_body.get("text") + if 'text' in response_body.keys(): + # Command-R + text = response_body.get("text") + else: + # Command + text = response_body.get("generations")[0].get("text") else: response_body = json.loads(response.get("body").read()) if provider == "ai21": text = response_body.get("completions")[0].get("data").get("text") - elif provider == "cohere": - text = response_body.get("generations")[0].get("text") elif provider == "meta": text = response_body.get("generation") elif provider == "mistral": From b6d0660251e65edb3640aaa55118f63fe03403cf Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sat, 11 May 2024 16:23:38 +0900 Subject: [PATCH 06/15] fix use command model case --- libs/aws/langchain_aws/llms/bedrock.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 8c7508a3..502e2a11 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -131,9 +131,13 @@ def prepare_input( if "max_tokens_to_sample" not in input_body: input_body["max_tokens_to_sample"] = 1024 elif provider == "cohere": + # Command-R if messages: input_body["chat_history"] = messages['chat_history'] input_body["message"] = messages['message'] + # Command + else: + input_body["prompt"] = prompt elif provider in ("ai21", "cohere", "meta", "mistral"): input_body["prompt"] = prompt elif provider == "amazon": From 597608cc66257775c5c3015d095283631f1dc13b Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sat, 11 May 2024 16:24:09 +0900 Subject: [PATCH 07/15] fix use command model case --- libs/aws/langchain_aws/llms/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 502e2a11..7413c16a 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -138,7 +138,7 @@ def prepare_input( # Command else: input_body["prompt"] = prompt - elif provider in ("ai21", "cohere", "meta", "mistral"): + elif provider in ("ai21", "meta", "mistral"): input_body["prompt"] = prompt elif provider == "amazon": input_body = dict() From e7604a00e312893d57e934693c2deda0044e6a36 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sat, 11 May 2024 16:51:05 +0900 Subject: [PATCH 08/15] add msg: cohere.command in BedrockBase is not supported --- libs/aws/langchain_aws/llms/bedrock.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 7413c16a..aab085dc 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -760,6 +760,12 @@ def validate_environment(cls, values: Dict) -> Dict: "Please use `from langchain_community.chat_models import BedrockChat` " "instead." ) + if model_id.startswith("cohere.command-r"): + raise ValueError( + "Command R models are not supported by this LLM." + "Please use `from langchain_community.chat_models import BedrockChat` " + "instead." + ) return super().validate_environment(values) @property From 858839fbd19d15aa56d78972335ace13b9c68bdc Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Fri, 17 May 2024 21:40:28 +0900 Subject: [PATCH 09/15] fix lint E501 --- libs/aws/langchain_aws/chat_models/bedrock.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index c87216ef..65857d55 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -240,7 +240,12 @@ def _format_cohere_messages( """Format messages for cohere.""" """ - {'message': content, 'chat_history': [{'role': 'USER or CHATBOT', 'message': message.content}]} + { + 'message': content, + 'chat_history': [ + {'role': 'USER or CHATBOT', 'message': message.content} + ] + } """ system: Optional[str] = None From 26d2e6965ba56dfeda72e1951fd212db4dca246b Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Fri, 17 May 2024 23:06:31 +0900 Subject: [PATCH 10/15] fix test error --- libs/aws/langchain_aws/chat_models/bedrock.py | 47 +++++++++++-------- libs/aws/langchain_aws/function_calling.py | 2 +- libs/aws/langchain_aws/llms/bedrock.py | 9 ++-- .../langchain_aws/llms/sagemaker_endpoint.py | 1 + libs/aws/tests/callbacks.py | 1 + .../chat_models/test_bedrock.py | 1 + libs/aws/tests/unit_tests/__init__.py | 1 + 7 files changed, 38 insertions(+), 24 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 65857d55..698d6d70 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -131,6 +131,7 @@ def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str: [_convert_one_message_to_text_mistral(message) for message in messages] ) + def _format_image(image_url: str) -> Dict: """ Formats an image of format data:image/jpeg;base64,{b64_string} @@ -234,6 +235,7 @@ def _format_anthropic_messages( ) return system, formatted_messages + def _format_cohere_messages( messages: List[BaseMessage], ) -> Tuple[Optional[str], List[Dict]]: @@ -241,16 +243,14 @@ def _format_cohere_messages( """ { - 'message': content, - 'chat_history': [ - {'role': 'USER or CHATBOT', 'message': message.content} + "message": content, + "chat_history": [ + {"role": "USER or CHATBOT", "message": message.content} ] } """ - - system: Optional[str] = None - chat_history: List = [] - formatted_messages: Dict = {} + content: Optional[str] = None + chat_history: List[Dict] = [] for i, message in enumerate(messages): if message.type == "system": if i != 0: @@ -260,15 +260,15 @@ def _format_cohere_messages( "System message must be a string, " f"instead was: {type(message.content)}" ) - chat_history.append({'role':'USER', 'message': message.content}) + chat_history.append({"role": "USER", "message": message.content}) continue elif message.type == "ai": - chat_history.append({'role':'CHATBOT', 'message': message.content}) + chat_history.append({"role": "CHATBOT", "message": message.content}) elif message.type == "human": - chat_history.append({'role':'USER', 'message': message.content}) - content = messages[-1].content - formatted_messages = {'message': content, 'chat_history': chat_history} - return system, formatted_messages + chat_history.append({"role": "USER", "message": message.content}) + content = str(messages[-1].content) + return content, chat_history + class ChatPromptAdapter: """Adapter class to prepare the inputs from Langchain to prompt format @@ -309,7 +309,15 @@ def format_messages( ) -> Tuple[Optional[str], List[Dict]]: if provider == "anthropic": return _format_anthropic_messages(messages) - elif provider == "cohere": + raise NotImplementedError( + f"Provider {provider} not supported for format_messages" + ) + + @classmethod + def format_cohere_message( + cls, provider: str, messages: List[BaseMessage] + ) -> Tuple[Optional[str], List[Dict]]: + if provider == "cohere": return _format_cohere_messages(messages) raise NotImplementedError( f"Provider {provider} not supported for format_messages" @@ -403,7 +411,7 @@ def _generate( completion += chunk.text else: provider = self._get_provider() - prompt, system, formatted_messages = None, None, None + prompt, system, formatted_messages, chat_history = None, None, None, None params: Dict[str, Any] = {**kwargs} if provider == "anthropic": @@ -416,14 +424,14 @@ def _generate( else: system = self.system_prompt_with_tools elif provider == "cohere": - if 'command-r' in self.model_id: - system, formatted_messages = ChatPromptAdapter.format_messages( + if "command-r" in self.model_id: + prompt, chat_history = ChatPromptAdapter.format_cohere_message( provider, messages ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( - provider=provider, messages=messages - ) + provider=provider, messages=messages + ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( provider=provider, messages=messages @@ -438,6 +446,7 @@ def _generate( run_manager=run_manager, system=system, messages=formatted_messages, + chat_history=chat_history, **params, ) diff --git a/libs/aws/langchain_aws/function_calling.py b/libs/aws/langchain_aws/function_calling.py index 765332e2..e986dc75 100644 --- a/libs/aws/langchain_aws/function_calling.py +++ b/libs/aws/langchain_aws/function_calling.py @@ -1,4 +1,4 @@ -"""Methods for creating function specs in the style of Bedrock Functions +"""Methods for creating function specs in the style of Bedrock Functions for supported model providers""" import json diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index aab085dc..bb0b9344 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -115,6 +115,7 @@ def prepare_input( model_kwargs: Dict[str, Any], prompt: Optional[str] = None, system: Optional[str] = None, + chat_history: Optional[str] = None, messages: Optional[List[Dict]] = None, ) -> Dict[str, Any]: input_body = {**model_kwargs} @@ -132,9 +133,9 @@ def prepare_input( input_body["max_tokens_to_sample"] = 1024 elif provider == "cohere": # Command-R - if messages: - input_body["chat_history"] = messages['chat_history'] - input_body["message"] = messages['message'] + if chat_history: + input_body["chat_history"] = chat_history + input_body["message"] = prompt # Command else: input_body["prompt"] = prompt @@ -161,7 +162,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict: text = content[0].get("text") elif provider == "cohere": response_body = json.loads(response.get("body").read().decode()) - if 'text' in response_body.keys(): + if "text" in response_body.keys(): # Command-R text = response_body.get("text") else: diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index 27879f19..aa2d511b 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -1,4 +1,5 @@ """Sagemaker InvokeEndpoint API.""" + import io import re from abc import abstractmethod diff --git a/libs/aws/tests/callbacks.py b/libs/aws/tests/callbacks.py index 66b54256..5a06a5e8 100644 --- a/libs/aws/tests/callbacks.py +++ b/libs/aws/tests/callbacks.py @@ -1,4 +1,5 @@ """A fake callback handler for testing purposes.""" + from itertools import chain from typing import Any, Dict, List, Optional, Union from uuid import UUID diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index 42882149..acb9e767 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -1,4 +1,5 @@ """Test Bedrock chat model.""" + from typing import Any, cast import pytest diff --git a/libs/aws/tests/unit_tests/__init__.py b/libs/aws/tests/unit_tests/__init__.py index 800bc7f3..ec8f4e60 100644 --- a/libs/aws/tests/unit_tests/__init__.py +++ b/libs/aws/tests/unit_tests/__init__.py @@ -1,4 +1,5 @@ """All unit tests (lightweight tests).""" + from typing import Any From f5e22f971b0d633e535d738758d4edaf3603e4b8 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Fri, 17 May 2024 23:17:36 +0900 Subject: [PATCH 11/15] tweak --- libs/aws/langchain_aws/llms/bedrock.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index bb0b9344..a6296876 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -512,6 +512,7 @@ def _prepare_input_and_invoke( prompt: Optional[str] = None, system: Optional[str] = None, messages: Optional[List[Dict]] = None, + chat_history: Optional[List[Dict]] = None, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, @@ -526,6 +527,7 @@ def _prepare_input_and_invoke( model_kwargs=params, prompt=prompt, system=system, + chat_history=chat_history, messages=messages, ) body = json.dumps(input_body) From c7440a85faa2fd036baf51bda948dae202fa02f5 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Fri, 17 May 2024 23:22:59 +0900 Subject: [PATCH 12/15] tweak --- libs/aws/langchain_aws/llms/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index a6296876..997335ed 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -115,7 +115,7 @@ def prepare_input( model_kwargs: Dict[str, Any], prompt: Optional[str] = None, system: Optional[str] = None, - chat_history: Optional[str] = None, + chat_history: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None, ) -> Dict[str, Any]: input_body = {**model_kwargs} From 840ea482f3c074aff00f804ba64d0c60632fce04 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sun, 26 May 2024 19:28:49 +0900 Subject: [PATCH 13/15] fix code --- libs/aws/langchain_aws/chat_models/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index fdfc06a5..dc28f61f 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -484,7 +484,7 @@ def _generate( ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( - provider=provider, messages=messages + provider=provider, messages=messages, model=self._get_model() ) else: prompt = ChatPromptAdapter.convert_messages_to_prompt( From dcb5b020fc61e74147ab48f0b25458cbeb26ffec Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Sat, 10 Aug 2024 10:37:36 +0900 Subject: [PATCH 14/15] Delete unnecessary lines --- libs/aws/langchain_aws/llms/bedrock.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 2737ebc7..4b15632f 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -325,7 +325,6 @@ def prepare_output(cls, provider: str, response: Any) -> dict: elif any(block["type"] == "tool_use" for block in content): tool_calls = extract_tool_calls(content) elif provider == "cohere": - response_body = json.loads(response.get("body").read().decode()) if "text" in response_body.keys(): # Command-R text = response_body.get("text") From c001826a9c3d3a478db2b56a6637d8013ac03ba4 Mon Sep 17 00:00:00 2001 From: ksaegusa Date: Fri, 23 Aug 2024 22:09:55 +0900 Subject: [PATCH 15/15] add command-r test --- .../chat_models/test_bedrock.py | 19 +++++++++++++++++++ .../unit_tests/chat_models/test_bedrock.py | 18 ++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index d5740f7a..7237d5db 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -100,6 +100,25 @@ def test_chat_bedrock_streaming_llama3() -> None: assert response.usage_metadata +@pytest.mark.scheduled +@pytest.mark.parametrize( + "model_id", + [ + "cohere.command-text-v14", + "cohere.command-r-plus-v1:0", + ],) +def test_chat_bedrock_invoke_cohere(model_id: str) -> None: + """Test that streaming correctly streams message chunks""" + chat = ChatBedrock( # type: ignore[call-arg] + model_id=model_id + ) + system = SystemMessage(content="You are a helpful assistant.") + human = HumanMessage(content="Hello") + response = chat.invoke([system, human]) + + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + @pytest.mark.scheduled def test_chat_bedrock_streaming_generation_info() -> None: """Test that generation info is preserved when streaming.""" diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py index fd2b7cf6..5c987cb8 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -11,6 +11,7 @@ from langchain_aws import ChatBedrock from langchain_aws.chat_models.bedrock import ( _format_anthropic_messages, + _format_cohere_messages, _merge_messages, ) from langchain_aws.function_calling import convert_to_anthropic_tool @@ -256,6 +257,23 @@ def test__format_anthropic_messages_with_tool_use_blocks_and_tool_calls() -> Non actual = _format_anthropic_messages(messages) assert expected == actual +def test__format_cohere_messages() -> None: + system = SystemMessage("fuzz") # type: ignore[misc] + human = HumanMessage("foo") # type: ignore[misc] + ai = AIMessage("bar") # type: ignore[misc] + + messages = [system, human, ai] + expected = ( + "bar", + [ + {"role": "USER", "message": "fuzz"}, + {"role": "USER", "message": "foo"}, + {"role": "CHATBOT", "message": "bar"}, + ], + ) + actual = _format_cohere_messages(messages) + assert expected == actual + @pytest.fixture() def pydantic() -> Type[BaseModel]: