Skip to content

Commit

Permalink
Docs updates + two additional unit tests (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje authored Mar 2, 2024
1 parent b347b30 commit a430cd6
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
class BedrockModelAdapter(ABC):
"""
Base class for Amazon Bedrock model adapters.
Each subclass of this class is designed to address the unique specificities of a particular LLM it adapts,
focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted LLMs.
"""

def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
Expand All @@ -16,15 +19,34 @@ def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> N

@abstractmethod
def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
    """
    Build the JSON-serializable request body for a Bedrock invocation.

    Concrete adapters implement this to assemble the payload in the shape
    their particular model expects.

    :param prompt: Prompt text to send to the model.
    :param inference_kwargs: Additional inference parameters for the request.
    :return: The request body as a dictionary.
    """

def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
    """
    Turn a Bedrock response body into a list of cleaned-up responses.

    Delegates extraction to the model-specific hook, then strips leading
    whitespace from each returned completion.

    :param response_body: Parsed JSON body of the Bedrock response.
    :return: The extracted responses with leading whitespace removed.
    """
    return [text.lstrip() for text in self._extract_completions_from_response(response_body)]

def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
"""
Extracts the responses from the Amazon Bedrock streaming response.
:param stream: The streaming response from the Amazon Bedrock request.
:param stream_handler: The handler for the streaming response.
:return: A list of string responses.
"""
tokens: List[str] = []
for event in stream:
chunk = event.get("chunk")
Expand All @@ -40,6 +62,9 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
Merges the default params with the inference kwargs and model kwargs.
Includes param if it's in kwargs or its default is not None (i.e. it is actually defined).
:param inference_kwargs: The inference kwargs.
:param default_params: The default params.
:return: A dictionary containing the merged params.
"""
kwargs = self.model_kwargs.copy()
kwargs.update(inference_kwargs)
Expand All @@ -51,19 +76,34 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str

@abstractmethod
def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""Extracts the responses from the Amazon Bedrock response."""
"""
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
:return: A list of string responses.
"""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""Extracts the token from a streaming chunk."""
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:return: A string token.
"""


class AnthropicClaudeAdapter(BedrockModelAdapter):
"""
Model adapter for the Anthropic's Claude model.
Adapter for the Anthropic Claude models.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Claude model
:param prompt: The prompt to be sent to the model.
"""
default_params = {
"max_tokens_to_sample": self.max_length,
"stop_sequences": ["\n\nHuman:"],
Expand All @@ -77,18 +117,37 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
return body

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
:return: A list of string responses.
"""
return [response_body["completion"]]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:return: A string token.
"""
return chunk.get("completion", "")


class CohereCommandAdapter(BedrockModelAdapter):
"""
Model adapter for the Cohere's Command model.
Adapter for the Cohere Command model.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Command model
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
:return: A dictionary containing the body for the request.
"""
default_params = {
"max_tokens": self.max_length,
"stop_sequences": None,
Expand All @@ -107,10 +166,22 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
return body

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Cohere Command model response.
:param response_body: The response body from the Amazon Bedrock request.
:return: A list of string responses.
"""
responses = [generation["text"] for generation in response_body["generations"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:return: A string token.
"""
return chunk.get("text", "")


Expand Down Expand Up @@ -146,10 +217,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:

class AmazonTitanAdapter(BedrockModelAdapter):
"""
Model adapter for Amazon's Titan models.
Adapter for Amazon's Titan models.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Titan model
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
:return: A dictionary containing the body for the request.
"""
default_params = {
"maxTokenCount": self.max_length,
"stopSequences": None,
Expand All @@ -162,19 +240,38 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
return body

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Titan model response.
:param response_body: The response body for Titan model response.
:return: A list of string responses.
"""
responses = [result["outputText"] for result in response_body["results"]]
return responses

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:return: A string token.
"""
return chunk.get("outputText", "")


class MetaLlama2ChatAdapter(BedrockModelAdapter):
"""
Model adapter for Meta's Llama 2 Chat models.
Adapter for Meta's Llama2 models.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Llama2 model
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
:return: A dictionary containing the body for the request.
"""
default_params = {
"max_gen_len": self.max_length,
"temperature": None,
Expand All @@ -186,7 +283,19 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
return body

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Llama2 model response.
:param response_body: The response body from the Llama2 model request.
:return: A list of string responses.
"""
return [response_body["generation"]]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:return: A string token.
"""
return chunk.get("generation", "")
Loading

0 comments on commit a430cd6

Please sign in to comment.