Skip to content

Commit

Permalink
improved error messagin, linting and formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-farmer committed Dec 3, 2024
1 parent 322d7a7 commit 9a27307
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 21 deletions.
12 changes: 11 additions & 1 deletion libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
get_token_ids_anthropic,
)


def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
Expand Down Expand Up @@ -834,7 +835,16 @@ def _as_converse(self) -> ChatBedrockConverse:
kwargs = {
k: v
for k, v in (self.model_kwargs or {}).items()
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p", "additional_model_request_fields", "additional_model_response_field_paths")
if k
in (
"stop",
"stop_sequences",
"max_tokens",
"temperature",
"top_p",
"additional_model_request_fields",
"additional_model_response_field_paths",
)
}
if self.max_tokens:
kwargs["max_tokens"] = self.max_tokens
Expand Down
23 changes: 11 additions & 12 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ def set_disable_streaming(cls, values: Dict) -> Any:
if "disable_streaming" not in values:
values["disable_streaming"] = (
False
if values["provider"] in ["anthropic", "cohere"] or
(values["provider"] == "amazon" and "nova" in model_id)
if values["provider"] in ["anthropic", "cohere"]
or (values["provider"] == "amazon" and "nova" in model_id)
else "tool_calling"
)
return values
Expand Down Expand Up @@ -741,9 +741,7 @@ def _extract_response_metadata(response: Dict[str, Any]) -> Dict[str, Any]:


def _parse_response(response: Dict[str, Any]) -> AIMessage:
lc_content = _bedrock_to_lc(
response.pop("output")["message"]["content"]
)
lc_content = _bedrock_to_lc(response.pop("output")["message"]["content"])
tool_calls = _extract_tool_calls(lc_content)
usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc]
return AIMessage(
Expand Down Expand Up @@ -871,9 +869,7 @@ def _lc_content_to_bedrock(
{
"video": {
"format": block["source"]["mediaType"].split("/")[1],
"source": {
"s3Location": block["source"]["data"]
},
"source": {"s3Location": block["source"]["data"]},
}
}
)
Expand Down Expand Up @@ -943,7 +939,9 @@ def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"source": {
"media_type": f"video/{block['video']['format']}",
"type": "base64",
"data": _bytes_to_b64_str(block["video"]["source"]["bytes"]),
"data": _bytes_to_b64_str(
block["video"]["source"]["bytes"]
),
},
}
)
Expand All @@ -960,7 +958,7 @@ def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
)
elif "document" in block:
# Request syntax assumes bedrock format; returning in same bedrock format
lc_content.append({"type": "document",**block})
lc_content.append({"type": "document", **block})
elif "tool_result" in block:
lc_content.append(
{
Expand Down Expand Up @@ -1135,14 +1133,15 @@ def _format_openai_image_url(image_url: str) -> Dict:
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Bedrock does not currently support OpenAI-format image URLs, only "
"The image URL provided is not supported. Expected image URL format is "
"base64-encoded images. Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"format": match.group("media_type"),
"source": {"bytes": _b64str_to_bytes(match.group("data"))},
}


def _format_openai_video_url(video_url: str) -> Dict:
"""
Formats a video of format data:video/mp4;base64,{b64_string}
Expand All @@ -1154,7 +1153,7 @@ def _format_openai_video_url(video_url: str) -> Dict:
match = re.match(regex, video_url)
if match is None:
raise ValueError(
"Bedrock does not currently support OpenAI-format video URLs, only "
"The video URL provided is not supported. Expected video URL format is "
"base64-encoded video. Example: data:video/mp4;base64,'/9j/4AAQSk'..."
)
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def standard_chat_model_params(self) -> dict:
def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
super().test_tool_message_histories_list_content(model)


class TestBedrockNovaStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
Expand All @@ -72,6 +73,7 @@ def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
super().test_tool_message_histories_list_content(model)


class TestBedrockCohereStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
Expand Down
5 changes: 4 additions & 1 deletion libs/aws/tests/unit_tests/chat_models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,9 @@ def test_beta_use_converse_api() -> None:
llm = ChatBedrock(model_id="nova.foo", region_name="us-west-2") # type: ignore[call-arg]
assert llm.beta_use_converse_api

llm = ChatBedrock(model="nova.foo", region_name="us-west-2", beta_use_converse_api=False)
llm = ChatBedrock(
model="nova.foo", region_name="us-west-2", beta_use_converse_api=False
)
assert not llm.beta_use_converse_api

llm = ChatBedrock(model="foo", region_name="us-west-2", beta_use_converse_api=True)
Expand All @@ -441,6 +443,7 @@ def test_beta_use_converse_api() -> None:
llm = ChatBedrock(model="foo", region_name="us-west-2", beta_use_converse_api=False)
assert not llm.beta_use_converse_api


@pytest.mark.parametrize(
"model_id, provider, expected_provider, expectation",
[
Expand Down
23 changes: 16 additions & 7 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,19 @@ def test_amazon_bind_tools_tool_choice() -> None:
chat_model.bind_tools(
[GetWeather], tool_choice={"tool": {"name": "GetWeather"}}
)

with pytest.raises(ValueError):
chat_model.bind_tools([GetWeather], tool_choice="GetWeather")

with pytest.raises(ValueError):
chat_model.bind_tools([GetWeather], tool_choice="any")

chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice="auto")
assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == {
"auto": {}
}


def test__messages_to_bedrock() -> None:
messages = [
SystemMessage(content="sys1"),
Expand Down Expand Up @@ -224,7 +225,7 @@ def test__messages_to_bedrock() -> None:
}
]
),
HumanMessage(
HumanMessage(
content=[
{
"type": "video",
Expand All @@ -236,7 +237,10 @@ def test__messages_to_bedrock() -> None:
content=[
{
"type": "video",
"video": {"format": "mp4", "source": { "s3Location": { "uri": "s3_url"}}},
"video": {
"format": "mp4",
"source": {"s3Location": {"uri": "s3_url"}},
},
}
]
),
Expand Down Expand Up @@ -317,7 +321,12 @@ def test__messages_to_bedrock() -> None:
},
{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}},
{"video": {"format": "mp4", "source": {"bytes": b"video_data"}}},
{"video": {"format": "mp4", "source": { "s3Location": { "uri": "s3_url"}}}},
{
"video": {
"format": "mp4",
"source": {"s3Location": {"uri": "s3_url"}},
}
},
],
},
]
Expand Down Expand Up @@ -356,7 +365,7 @@ def test__bedrock_to_lc() -> None:
}
},
{"video": {"format": "mp4", "source": {"bytes": b"video_data"}}},
{"video": {"format": "mp4", "source": {"s3Location": { "uri": "video_data"}}}},
{"video": {"format": "mp4", "source": {"s3Location": {"uri": "video_data"}}}},
]
expected = [
{"type": "text", "text": "text1"},
Expand Down Expand Up @@ -404,7 +413,7 @@ def test__bedrock_to_lc() -> None:
"source": {
"type": "s3Location",
"media_type": "video/mp4",
"data": {"uri": "video_data"},
"data": {"uri": "video_data"},
},
},
]
Expand Down

0 comments on commit 9a27307

Please sign in to comment.