feat: Add Mistral Amazon Bedrock support (#632)
* Add mistral model support
vblagoje authored Apr 16, 2024
1 parent b116f4d commit 0307c49
Showing 6 changed files with 422 additions and 3 deletions.
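For orientation, a minimal usage sketch of what this change enables (illustrative only: the import paths are inferred from the module path referenced in the tests below, the model ID is one of the Mistral IDs exercised there, and working AWS credentials are assumed to be available in the environment):

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator import AmazonBedrockChatGenerator

# Any model ID matching the new r"mistral.*" pattern resolves to the MistralChatAdapter added in this commit.
generator = AmazonBedrockChatGenerator(model="mistral.mistral-7b-instruct-v0:2")
response = generator.run([ChatMessage.from_user("What is the capital of France?")])
print(response["replies"][0].content)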
@@ -139,6 +139,55 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
return chunk.get("completion", "")


class MistralAdapter(BedrockModelAdapter):
"""
Adapter for the Mistral models.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Mistral model
:param prompt: The prompt to be sent to the model.
:param inference_kwargs: Additional keyword arguments passed to the handler.
:returns: A dictionary with the following keys:
- `prompt`: The prompt to be sent to the model.
- specified inference parameters.
"""
default_params: Dict[str, Any] = {
"max_tokens": self.max_length,
"stop": [],
"temperature": None,
"top_p": None,
"top_k": None,
}
params = self._get_params(inference_kwargs, default_params)
# Add the instruction tag to the prompt if it's not already there
formatted_prompt = f"<s>[INST] {prompt} [/INST]" if "INST" not in prompt else prompt
return {"prompt": formatted_prompt, **params}

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
"""
Extracts the responses from the Amazon Bedrock response.
:param response_body: The response body from the Amazon Bedrock request.
:returns: A list of string responses.
"""
return [output.get("text", "") for output in response_body.get("outputs", [])]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:returns: A string token.
"""
chunk_list = chunk.get("outputs", [])
if chunk_list:
return chunk_list[0].get("text", "")
return ""


class CohereCommandAdapter(BedrockModelAdapter):
"""
Adapter for the Cohere Command model.
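The instruction-tag wrapping applied in MistralAdapter.prepare_body above can be illustrated in isolation; format_mistral_prompt below is a hypothetical helper that mirrors the adapter's one-liner:

def format_mistral_prompt(prompt: str) -> str:
    # Wrap the prompt in Mistral's [INST] tags unless the caller already supplied them.
    return f"<s>[INST] {prompt} [/INST]" if "INST" not in prompt else prompt

assert format_mistral_prompt("Hello, how are you?") == "<s>[INST] Hello, how are you? [/INST]"
assert format_mistral_prompt("<s>[INST] Hi [/INST]") == "<s>[INST] Hi [/INST]"  # already tagged, left as-is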
@@ -271,6 +271,160 @@ def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}


class MistralChatAdapter(BedrockModelChatAdapter):
"""
Model adapter for the Mistral chat model.
"""

chat_template = """
{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% set system_message = false %}
{% endif %}
{{bos_token}}
{% for message in loop_messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 and system_message != false %}
{% set content = system_message + '\\n' + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ '[INST] ' + content.strip() + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ content.strip() + eos_token }}
{% endif %}
{% endfor %}
"""
chat_template = "".join(line.strip() for line in chat_template.splitlines())

# the above template was designed to match https://docs.mistral.ai/models/#chat-template
# and to support system messages, otherwise we could use the default mistral chat template
# available on HF infrastructure

# Allowed inference parameters for Mistral models, see
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
ALLOWED_PARAMS: ClassVar[List[str]] = [
"max_tokens",
"safe_prompt",
"random_seed",
"temperature",
"top_p",
]

def __init__(self, generation_kwargs: Dict[str, Any]):
"""
Initializes the Mistral chat adapter.
:param generation_kwargs: The generation kwargs.
"""
super().__init__(generation_kwargs)

# We pop the model_max_length as it is not sent to the model
# but used to truncate the prompt if needed
# Mistral models support a context length of at least 32,000 tokens
model_max_length = self.generation_kwargs.pop("model_max_length", 32000)

# Use `mistralai/Mistral-7B-Instruct-v0.1` as the tokenizer; all Mistral models likely use the same tokenizer, so
# a) we should get good estimates for the prompt length
# b) we can use apply_chat_template with the template above to delineate ChatMessages
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
self.prompt_handler = DefaultPromptHandler(
tokenizer=tokenizer,
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_tokens") or 512,
)

def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
"""
Prepares the body for the Mistral request.
:param messages: The chat messages to package into the request.
:param inference_kwargs: Additional inference kwargs to use.
:returns: The prepared body.
"""
default_params = {
"max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required
}
# Replace stop_words from inference_kwargs with stop, the Mistral-specific parameter name
stop_words = inference_kwargs.pop("stop_words", [])
if stop_words:
inference_kwargs["stop"] = stop_words
params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
return body

def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
"""
Prepares the chat messages for the Mistral request.
:param messages: The chat messages to prepare.
:returns: The prepared chat messages as a string.
"""
# Ideally we would use the default Mistral chat template, but it doesn't support system messages;
# the chat_template class variable above is a workaround that does.
# The default template is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json,
# but we use our custom chat template instead.
prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template
)
return self._ensure_token_limit(prepared_prompt)

def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]:
"""
Convert the message to the format expected by OpenAI's Chat API.
See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
:returns: A dictionary with the following keys:
- `role`
- `content`
- `name` (optional)
"""
msg = {"role": m.role.value, "content": m.content}
if m.name:
msg["name"] = m.name
return msg

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
:param prompt: The prompt to check.
:returns: A dictionary containing the resized prompt and additional information.
"""
return self.prompt_handler(prompt)

def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Extracts the messages from the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
messages: List[ChatMessage] = []
responses = response_body.get("outputs", [])
for response in responses:
meta = {k: v for k, v in response.items() if k not in ["text"]}
messages.append(ChatMessage.from_assistant(response["text"], meta=meta))
return messages

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
Extracts the token from a streaming chunk.
:param chunk: The streaming chunk.
:returns: The extracted token.
"""
response_chunk = chunk.get("outputs", [])
if response_chunk:
return response_chunk[0].get("text", "")
return ""


class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
"""
Model adapter for the Meta Llama 2 models.
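To make the custom chat template above concrete, the class attribute can be rendered directly with Jinja2; this is a sketch only (the adapter itself applies the template through the tokenizer's apply_chat_template, the import path is inferred from the module path used in the tests, and the bos/eos tokens are the usual Mistral values assumed here):

from jinja2 import Template

from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import MistralChatAdapter

conversation = [
    {"role": "user", "content": "A"},
    {"role": "assistant", "content": "B"},
    {"role": "user", "content": "C"},
]
# Valid conversations never reach the raise_exception branch, so plain Jinja2 rendering works here.
rendered = Template(MistralChatAdapter.chat_template).render(
    messages=conversation, bos_token="<s>", eos_token="</s>"
)
print(rendered)  # <s>[INST] A [/INST]B</s>[INST] C [/INST]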
@@ -15,7 +15,7 @@
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter
from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter

logger = logging.getLogger(__name__)

@@ -50,6 +50,7 @@ class AmazonBedrockChatGenerator:
SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = {
r"anthropic.claude.*": AnthropicClaudeChatAdapter,
r"meta.llama2.*": MetaLlama2ChatAdapter,
r"mistral.*": MistralChatAdapter,
}

def __init__(
@@ -20,6 +20,7 @@
BedrockModelAdapter,
CohereCommandAdapter,
MetaLlama2ChatAdapter,
MistralAdapter,
)
from .handlers import (
DefaultPromptHandler,
@@ -58,6 +59,7 @@ class AmazonBedrockGenerator:
r"cohere.command.*": CohereCommandAdapter,
r"anthropic.claude.*": AnthropicClaudeAdapter,
r"meta.llama2.*": MetaLlama2ChatAdapter,
r"mistral.*": MistralAdapter,
}

def __init__(
@@ -124,8 +126,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:

# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# It is hard to determine which tokenizer to use for the SageMaker model
# so we use GPT2 tokenizer which will likely provide good token count approximation
# we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
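The prompt-based AmazonBedrockGenerator gains the same routing, so mistral.* model IDs resolve to the new MistralAdapter; a corresponding sketch (import path and model ID assumed from the code and tests in this commit, AWS credentials configured in the environment):

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

generator = AmazonBedrockGenerator(model="mistral.mixtral-8x7b-instruct-v0:1")
result = generator.run("What is the capital of France?")
print(result["replies"][0])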
80 changes: 80 additions & 0 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -9,10 +9,16 @@
AnthropicClaudeChatAdapter,
BedrockModelChatAdapter,
MetaLlama2ChatAdapter,
MistralChatAdapter,
)

KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
MISTRAL_MODELS = [
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
]


def test_to_dict(mock_boto3_session):
@@ -176,6 +182,80 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
assert body == expected_body


class TestMistralAdapter:
def test_prepare_body_with_default_params(self) -> None:
layer = MistralChatAdapter(generation_kwargs={})
prompt = "Hello, how are you?"
expected_body = {
"max_tokens": 512,
"prompt": "<s>[INST] Hello, how are you? [/INST]",
}

body = layer.prepare_body([ChatMessage.from_user(prompt)])

assert body == expected_body

def test_prepare_body_with_custom_inference_params(self) -> None:
layer = MistralChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
prompt = "Hello, how are you?"
expected_body = {
"prompt": "<s>[INST] Hello, how are you? [/INST]",
"max_tokens": 512,
"temperature": 0.7,
"top_p": 0.8,
}

body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69)

assert body == expected_body

def test_mistral_chat_template_correct_order(self):
layer = MistralChatAdapter(generation_kwargs={})
layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")])
layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")])

def test_mistral_chat_template_incorrect_order(self):
layer = MistralChatAdapter(generation_kwargs={})
try:
layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")])
msg = "Expected TemplateError"
raise AssertionError(msg)
except Exception as e:
assert "Conversation roles must alternate user/assistant/" in str(e)

try:
layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_user("B")])
msg = "Expected TemplateError"
raise AssertionError(msg)
except Exception as e:
assert "Conversation roles must alternate user/assistant/" in str(e)

try:
layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")])
msg = "Expected TemplateError"
raise AssertionError(msg)
except Exception as e:
assert "Conversation roles must alternate user/assistant/" in str(e)

@pytest.mark.parametrize("model_name", MISTRAL_MODELS)
@pytest.mark.integration
def test_default_inference_params(self, model_name, chat_messages):
client = AmazonBedrockChatGenerator(model=model_name)
response = client.run(chat_messages)

assert "replies" in response, "Response does not contain 'replies' key"
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"

first_reply = replies[0]
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"


@pytest.fixture
def chat_messages():
messages = [