From a430cd630d15b9589c554d0e16d1cb7fac027a91 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Sat, 2 Mar 2024 01:12:59 +0100
Subject: [PATCH] Docs updates + two additional unit tests (#513)

---
 .../generators/amazon_bedrock/adapters.py    | 125 ++++++++++++++++--
 .../amazon_bedrock/chat/adapters.py          | 121 ++++++++++++++++-
 .../amazon_bedrock/chat/chat_generator.py    |  53 ++++++--
 .../generators/amazon_bedrock/generator.py   |  76 ++++++++---
 .../tests/test_chat_generator.py             |  16 +++
 5 files changed, 349 insertions(+), 42 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
index 40ba0bc67..a1704ef13 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
@@ -8,6 +8,9 @@ class BedrockModelAdapter(ABC):
     """
     Base class for Amazon Bedrock model adapters.
+
+    Each subclass adapts a particular LLM, focusing on preparing the requests for, and extracting the
+    responses from, the LLMs hosted on Amazon Bedrock.
     """
 
     def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
@@ -16,15 +19,34 @@ def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> N
 
     @abstractmethod
     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
-        """Prepares the body for the Amazon Bedrock request."""
+        """
+        Prepares the body for the Amazon Bedrock request.
+        Each subclass should implement this method to prepare the request body for the specific model.
+
+        :param prompt: The prompt to be sent to the model.
+        :param inference_kwargs: Additional keyword arguments passed to the handler.
+        :return: A dictionary containing the body for the request.
+        """
 
     def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
-        """Extracts the responses from the Amazon Bedrock response."""
+        """
+        Extracts the responses from the Amazon Bedrock response.
+
+        :param response_body: The response body from the Amazon Bedrock request.
+        :return: A list of responses.
+        """
         completions = self._extract_completions_from_response(response_body)
         responses = [completion.lstrip() for completion in completions]
         return responses
 
     def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
+        """
+        Extracts the responses from the Amazon Bedrock streaming response.
+
+        :param stream: The streaming response from the Amazon Bedrock request.
+        :param stream_handler: The handler for the streaming response.
+        :return: A list of string responses.
+        """
         tokens: List[str] = []
         for event in stream:
             chunk = event.get("chunk")
@@ -40,6 +62,9 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
         Merges the default params with the inference kwargs and model kwargs.
 
         Includes param if it's in kwargs or its default is not None (i.e. it is actually defined).
+
+        :param inference_kwargs: The inference kwargs.
+        :param default_params: The default params.
+        :return: A dictionary containing the merged params.
""" kwargs = self.model_kwargs.copy() kwargs.update(inference_kwargs) @@ -51,19 +76,34 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str @abstractmethod def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: - """Extracts the responses from the Amazon Bedrock response.""" + """ + Extracts the responses from the Amazon Bedrock response. + + :param response_body: The response body from the Amazon Bedrock request. + :return: A list of string responses. + """ @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: - """Extracts the token from a streaming chunk.""" + """ + Extracts the token from a streaming chunk. + + :param chunk: The streaming chunk. + :return: A string token. + """ class AnthropicClaudeAdapter(BedrockModelAdapter): """ - Model adapter for the Anthropic's Claude model. + Adapter for the Anthropic Claude models. """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Claude model + + :param prompt: The prompt to be sent to the model. + """ default_params = { "max_tokens_to_sample": self.max_length, "stop_sequences": ["\n\nHuman:"], @@ -77,18 +117,37 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: return body def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + """ + Extracts the responses from the Amazon Bedrock response. + + :param response_body: The response body from the Amazon Bedrock request. + :return: A list of string responses. + """ return [response_body["completion"]] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """ + Extracts the token from a streaming chunk. + + :param chunk: The streaming chunk. + :return: A string token. + """ return chunk.get("completion", "") class CohereCommandAdapter(BedrockModelAdapter): """ - Model adapter for the Cohere's Command model. + Adapter for the Cohere Command model. """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Command model + + :param prompt: The prompt to be sent to the model. + :param inference_kwargs: Additional keyword arguments passed to the handler. + :return: A dictionary containing the body for the request. + """ default_params = { "max_tokens": self.max_length, "stop_sequences": None, @@ -107,10 +166,22 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: return body def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + """ + Extracts the responses from the Cohere Command model response. + + :param response_body: The response body from the Amazon Bedrock request. + :return: A list of string responses. + """ responses = [generation["text"] for generation in response_body["generations"]] return responses def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """ + Extracts the token from a streaming chunk. + + :param chunk: The streaming chunk. + :return: A string token. + """ return chunk.get("text", "") @@ -146,10 +217,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: class AmazonTitanAdapter(BedrockModelAdapter): """ - Model adapter for Amazon's Titan models. + Adapter for Amazon's Titan models. """ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Titan model + + :param prompt: The prompt to be sent to the model. 
+        :param inference_kwargs: Additional keyword arguments passed to the handler.
+        :return: A dictionary containing the body for the request.
+        """
         default_params = {
             "maxTokenCount": self.max_length,
             "stopSequences": None,
@@ -162,19 +240,38 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         return body
 
     def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
+        """
+        Extracts the responses from the Titan model response.
+
+        :param response_body: The response body from the Titan model response.
+        :return: A list of string responses.
+        """
         responses = [result["outputText"] for result in response_body["results"]]
         return responses
 
    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :return: A string token.
+        """
         return chunk.get("outputText", "")
 
 
 class MetaLlama2ChatAdapter(BedrockModelAdapter):
     """
-    Model adapter for Meta's Llama 2 Chat models.
+    Adapter for Meta's Llama2 models.
     """
 
     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
+        """
+        Prepares the body for the Llama2 model.
+
+        :param prompt: The prompt to be sent to the model.
+        :param inference_kwargs: Additional keyword arguments passed to the handler.
+        :return: A dictionary containing the body for the request.
+        """
         default_params = {
             "max_gen_len": self.max_length,
             "temperature": None,
@@ -186,7 +283,19 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         return body
 
     def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
+        """
+        Extracts the responses from the Llama2 model response.
+
+        :param response_body: The response body from the Llama2 model request.
+        :return: A list of string responses.
+        """
         return [response_body["generation"]]
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :return: A string token.
+        """
         return chunk.get("generation", "")
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index a4eefe321..d5dc100f9 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -15,17 +15,35 @@ class BedrockModelChatAdapter(ABC):
     """
     Base class for Amazon Bedrock chat model adapters.
+
+    Each subclass adapts a particular chat LLM, focusing on preparing the requests for, and extracting the
+    responses from, the chat LLMs hosted on Amazon Bedrock.
     """
 
     def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+        """
+        Initializes the chat adapter with the generation kwargs.
+        """
         self.generation_kwargs = generation_kwargs
 
     @abstractmethod
     def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
-        """Prepares the body for the Amazon Bedrock request."""
+        """
+        Prepares the body for the Amazon Bedrock request.
+        Subclasses should override this method to package the chat messages into the request.
+
+        :param messages: The chat messages to package into the request.
+        :param inference_kwargs: Additional inference kwargs to use.
+        :return: The prepared body.
+        """
 
     def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
-        """Extracts the responses from the Amazon Bedrock response."""
+        """
+        Extracts the responses from the Amazon Bedrock response.
+
+        :param response_body: The response body.
+        :return: The extracted responses.
+        """
         return self._extract_messages_from_response(self.response_body_message_key(), response_body)
 
     def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
@@ -79,6 +97,11 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
         return kwargs
 
     def _ensure_token_limit(self, prompt: str) -> str:
+        """
+        Ensures that the prompt is within the token limit for the model.
+
+        :param prompt: The prompt to check.
+        :return: The resized prompt.
+        """
         resize_info = self.check_prompt(prompt)
         if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
             logger.warning(
@@ -95,34 +118,56 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
-        Checks the prompt length and resizes it if necessary.
+        Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
 
         :param prompt: The prompt to check.
         :return: A dictionary containing the resized prompt and additional information.
         """
 
     def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
+        """
+        Extracts the messages from the response body.
+
+        :param message_tag: The key for the message in the response body.
+        :param response_body: The response body.
+        :return: The extracted ChatMessage list.
+        """
         metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
         return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
 
     @abstractmethod
     def response_body_message_key(self) -> str:
-        """Returns the key for the message in the response body."""
+        """
+        Returns the key for the message in the response body.
+        Subclasses should override this method to return the correct message key, i.e. where the response is located.
+
+        :return: The key for the message in the response body.
+        """
 
     @abstractmethod
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
-        """Extracts the token from a streaming chunk."""
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :return: The extracted token.
+        """
 
 
 class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
     """
-    Model adapter for the Anthropic Claude model.
+    Model adapter for the Anthropic Claude chat model.
     """
 
     ANTHROPIC_USER_TOKEN = "\n\nHuman:"
     ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:"
 
     def __init__(self, generation_kwargs: Dict[str, Any]):
+        """
+        Initializes the Anthropic Claude chat adapter.
+
+        :param generation_kwargs: The generation kwargs.
+        """
         super().__init__(generation_kwargs)
 
         # We pop the model_max_length as it is not sent to the model
@@ -142,6 +187,13 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         )
 
     def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
+        """
+        Prepares the body for the Anthropic Claude request.
+
+        :param messages: The chat messages to package into the request.
+        :param inference_kwargs: Additional inference kwargs to use.
+        :return: The prepared body.
+        """
         default_params = {
             "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512,
             "stop_sequences": ["\n\nHuman:"],
@@ -156,6 +208,12 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
         return body
 
     def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
+        """
+        Prepares the chat messages for the Anthropic Claude request.
+
+        :param messages: The chat messages to prepare.
+        :return: The prepared chat messages as a string.
+        """
         conversation = []
         for index, message in enumerate(messages):
             if message.is_from(ChatRole.USER):
@@ -179,12 +237,29 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         return self._ensure_token_limit(prepared_prompt)
 
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        """
+        Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
+
+        :param prompt: The prompt to check.
+        :return: A dictionary containing the resized prompt and additional information.
+        """
         return self.prompt_handler(prompt)
 
     def response_body_message_key(self) -> str:
+        """
+        Returns the key for the message in the response body for Anthropic Claude, i.e. "completion".
+
+        :return: The key for the message in the response body.
+        """
         return "completion"
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :return: The extracted token.
+        """
         return chunk.get("completion", "")
 
 
@@ -219,6 +294,10 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
     )
 
     def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+        """
+        Initializes the Meta Llama 2 chat adapter.
+
+        :param generation_kwargs: The generation kwargs.
+        """
         super().__init__(generation_kwargs)
         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
@@ -240,6 +319,12 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         )
 
     def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
+        """
+        Prepares the body for the Meta Llama 2 request.
+
+        :param messages: The chat messages to package into the request.
+        :param inference_kwargs: Additional inference kwargs to use.
+        :return: The prepared body.
+        """
         default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512}
 
         # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it
@@ -251,16 +336,40 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
         return body
 
     def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
+        """
+        Prepares the chat messages for the Meta Llama 2 request.
+
+        :param messages: The chat messages to prepare.
+        :return: The prepared chat messages as a string ready for the model.
+        """
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
             conversation=messages, tokenize=False, chat_template=self.chat_template
         )
         return self._ensure_token_limit(prepared_prompt)
 
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        """
+        Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
+
+        :param prompt: The prompt to check.
+        :return: A dictionary containing the resized prompt and additional information.
+        """
         return self.prompt_handler(prompt)
 
     def response_body_message_key(self) -> str:
+        """
+        Returns the key for the message in the response body for Meta Llama 2, i.e. "generation".
+
+        :return: The key for the message in the response body.
+        """
         return "generation"
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :return: The extracted token.
+        """
         return chunk.get("generation", "")
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index e21393b60..3b5a8f6cc 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -23,9 +23,10 @@
 @component
 class AmazonBedrockChatGenerator:
     """
-    AmazonBedrockChatGenerator enables text generation via Amazon Bedrock chat hosted models. For example, to use
-    the Anthropic Claude model, simply initialize the AmazonBedrockChatGenerator with the 'anthropic.claude-v2'
-    model name.
+    `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs.
+
+    For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the
+    'anthropic.claude-v2' model name.
 
     ```python
     from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
@@ -42,7 +43,7 @@ class AmazonBedrockChatGenerator:
     ```
 
     If you prefer non-streaming mode, simply remove the `streaming_callback` parameter, capture the return value of the
-    component's run method and the AmazonBedrockChatGenerator will return the response in a non-streaming mode.
+    component's `run` method, and the `AmazonBedrockChatGenerator` will return the response in non-streaming mode.
     """
 
     SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = {
@@ -65,14 +66,14 @@ def __init__(
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
-        Initializes the AmazonBedrockChatGenerator with the provided parameters. The parameters are passed to the
+        Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
         Amazon Bedrock client.
 
         Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
         automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
         the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the
         constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
-        and `aws_region_name`. 
+        and `aws_region_name`.
 
         :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to
         be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).
@@ -133,6 +134,15 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         self.streaming_callback = streaming_callback
 
     def invoke(self, *args, **kwargs):
+        """
+        Invokes the Amazon Bedrock LLM with the given parameters. The parameters are passed to the Amazon Bedrock
+        client.
+
+        :param args: The positional arguments passed to the generator.
+        :param kwargs: The keyword arguments passed to the generator.
+        :return: List of `ChatMessage` generated by the LLM.
+        """
+
         kwargs = kwargs.copy()
         messages: List[ChatMessage] = kwargs.pop("messages", [])
         # check if the prompt is a list of ChatMessage objects
@@ -166,12 +176,26 @@ def invoke(self, *args, **kwargs):
 
         return responses
 
-    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
+    @component.output_types(replies=List[ChatMessage])
     def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
+
+        :param messages: The messages to generate a response to.
+        :param generation_kwargs: Additional generation keyword arguments passed to the model.
+        :return: A dictionary with the following keys:
+            - `replies`: The generated list of `ChatMessage` objects.
+        """
         return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))}
 
     @classmethod
     def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]:
+        """
+        Returns the model adapter for the given model.
+
+        :param model: The model to get the adapter for.
+        :return: The model adapter for the given model, or None if the model is not supported.
+        """
         for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
             if re.fullmatch(pattern, model):
                 return adapter
@@ -179,8 +203,10 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Serialize this component to a dictionary.
-        :return: The serialized component as a dictionary.
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
         """
         return default_to_dict(
             self,
@@ -198,9 +224,12 @@ def to_dict(self) -> Dict[str, Any]:
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
         """
-        Deserialize this component from a dictionary.
-        :param data: The dictionary representation of this component.
-        :return: The deserialized component instance.
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
         """
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index 9c3e157cb..706d29c98 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -33,21 +33,19 @@
 @component
 class AmazonBedrockGenerator:
     """
-    Generator based on a Hugging Face model.
-    This component provides an interface to generate text using a Hugging Face model that runs locally.
+    `AmazonBedrockGenerator` enables text generation via Amazon Bedrock hosted LLMs.
+
+    For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockGenerator` with the
+    'anthropic.claude-v2' model name. Provide AWS credentials either via the local AWS profile or directly via the
+    `aws_access_key_id`, `aws_secret_access_key`, `aws_session_token`, and `aws_region_name` parameters.
 
     Usage example:
     ```python
-    from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator
+    from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
 
     generator = AmazonBedrockGenerator(
-        model="anthropic.claude-v2",
-        max_length=99,
-        aws_access_key_id="...",
-        aws_secret_access_key="...",
-        aws_session_token="...",
-        aws_profile_name="...",
-        aws_region_name="..."
+        model="anthropic.claude-v2",
+        max_length=99
     )
 
     print(generator.run("Who is the best American actor?"))
@@ -75,6 +73,19 @@ def __init__(
         max_length: Optional[int] = 100,
         **kwargs,
     ):
+        """
+        Create a new `AmazonBedrockGenerator` instance.
+
+        :param model: The name of the model to use.
+        :param aws_access_key_id: The AWS access key ID.
+        :param aws_secret_access_key: The AWS secret access key.
+        :param aws_session_token: The AWS session token.
+        :param aws_region_name: The AWS region name.
+        :param aws_profile_name: The AWS profile name.
+        :param max_length: The maximum length of the generated text.
+        :param kwargs: Additional keyword arguments to be passed to the model.
+        """
         if not model:
             msg = "'model' cannot be None or empty string"
             raise ValueError(msg)
@@ -126,6 +137,13 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)
 
     def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
+        """
+        Ensures that the prompt and answer token lengths together are within the model_max_length specified during
+        the initialization of the component.
+
+        :param prompt: The prompt to be sent to the model.
+        :return: The resized prompt.
+        """
         # the prompt for this model will be of the type str
         if isinstance(prompt, List):
             msg = (
@@ -148,6 +166,13 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
         return str(resize_info["resized_prompt"])
 
     def invoke(self, *args, **kwargs):
+        """
+        Invokes the model with the given prompt.
+
+        :param args: Additional positional arguments passed to the generator.
+        :param kwargs: Additional keyword arguments passed to the generator.
+        :return: A list of generated responses (strings).
+        """
         kwargs = kwargs.copy()
         prompt: str = kwargs.pop("prompt", None)
         stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False))
@@ -193,12 +218,26 @@ def invoke(self, *args, **kwargs):
 
         return responses
 
-    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
+    @component.output_types(replies=List[str])
     def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Generates a list of string responses to the given prompt.
+
+        :param prompt: The prompt to generate a response for.
+        :param generation_kwargs: Additional keyword arguments passed to the generator.
+        :return: A dictionary with the following keys:
+            - `replies`: A list of generated responses (strings).
+        """
         return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))}
 
     @classmethod
     def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
+        """
+        Gets the model adapter for the given model.
+
+        :param model: The model name.
+        :return: The model adapter class, or None if no adapter is found.
+        """
         for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
             if re.fullmatch(pattern, model):
                 return adapter
@@ -206,8 +245,10 @@ def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Serialize this component to a dictionary.
-        :return: The serialized component as a dictionary.
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
         """
         return default_to_dict(
             self,
@@ -223,9 +264,12 @@ def to_dict(self) -> Dict[str, Any]:
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator":
         """
-        Deserialize this component from a dictionary.
-        :param data: The dictionary representation of this component.
-        :return: The deserialized component instance.
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
         """
         deserialize_secrets_inplace(
             data["init_parameters"],
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 0f7bced89..9ba4d5534 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -173,6 +173,18 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
 
         assert body == expected_body
 
+    @pytest.mark.integration
+    def test_get_responses(self) -> None:
+        adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
+        response_body = {"completion": "This is a single response."}
+        expected_response = "This is a single response."
+        response_message = adapter.get_responses(response_body)
+        # assert that the type of each item in the list is a ChatMessage
+        for message in response_message:
+            assert isinstance(message, ChatMessage)
+
+        assert response_message == [ChatMessage.from_assistant(expected_response)]
+
 
 class TestMetaLlama2ChatAdapter:
     @pytest.mark.integration
@@ -221,4 +233,8 @@ def test_get_responses(self) -> None:
         response_body = {"generation": "This is a single response."}
         expected_response = "This is a single response."
         response_message = adapter.get_responses(response_body)
+        # assert that the type of each item in the list is a ChatMessage
+        for message in response_message:
+            assert isinstance(message, ChatMessage)
+
         assert response_message == [ChatMessage.from_assistant(expected_response)]
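For reviewers who want to see the adapter contract documented in this patch in action, here is a minimal sketch of a custom `BedrockModelAdapter` subclass. The model's request and response shapes and the `HypotheticalModelAdapter` class itself are illustrative assumptions for this sketch only; they are not part of the patch or of the adapters shipped by the integration.

```python
from typing import Any, Dict, List

from haystack_integrations.components.generators.amazon_bedrock.adapters import BedrockModelAdapter


class HypotheticalModelAdapter(BedrockModelAdapter):
    """Illustrative adapter for a fictional Bedrock-hosted model."""

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        # Merge per-call inference kwargs with the defaults and model_kwargs,
        # as described in the _get_params docstring above.
        default_params = {"max_tokens": self.max_length, "temperature": None}
        params = self._get_params(inference_kwargs, default_params)
        return {"prompt": prompt, **params}

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        # Assumed response shape: {"outputs": [{"text": "..."}, ...]}.
        return [output["text"] for output in response_body["outputs"]]

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        # Assumed streaming chunk shape: {"text": "..."}.
        return chunk.get("text", "")
```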
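The two new unit tests lean on the behavior documented for `get_responses` and `_extract_messages_from_response`: every key in the response body other than the message key ends up as `ChatMessage` metadata. A small sketch of that flow for the Claude chat adapter follows; the `stop_reason` field is a hypothetical example value, and constructing the adapter loads a tokenizer for its prompt handler, which is presumably why the new tests carry the `integration` mark.

```python
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import AnthropicClaudeChatAdapter

adapter = AnthropicClaudeChatAdapter(generation_kwargs={})

# "completion" is the documented message key for Claude; any other key
# ("stop_reason" here is a hypothetical example) becomes ChatMessage metadata.
response_body = {"completion": "Paris is the capital of France.", "stop_reason": "stop_sequence"}
messages = adapter.get_responses(response_body)

assert isinstance(messages[0], ChatMessage)
assert messages[0].content == "Paris is the capital of France."
assert messages[0].meta == {"stop_reason": "stop_sequence"}
```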
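Finally, a sketch of the documented `run` and serialization APIs end to end. It assumes AWS credentials and a default region are already configured in the environment (per the class docstring, explicit keys are then optional); the prompt and the `max_tokens_to_sample` value are illustrative.

```python
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# Credentials are resolved from the environment / AWS config, as documented.
generator = AmazonBedrockChatGenerator(model="anthropic.claude-v2")

result = generator.run(
    messages=[ChatMessage.from_user("Briefly explain what Amazon Bedrock is.")],
    generation_kwargs={"max_tokens_to_sample": 256},
)
print(result["replies"][0].content)  # "replies" is a list of ChatMessage objects

# Round-trip the component through its dictionary representation.
data = generator.to_dict()
restored = AmazonBedrockChatGenerator.from_dict(data)
```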