Skip to content

Commit

Permalink
aws[patch]: fix ChatBedrockConverse init (#200)
Browse files Browse the repository at this point in the history
Make sure supports_tool_choice_values is set whether or not client is
passed in
  • Loading branch information
baskaryan authored Sep 19, 2024
1 parent e38050c commit cf435cf
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,46 +376,6 @@ def set_disable_streaming(cls, values: Dict) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""
if self.client is not None:
return self

try:
if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
session = boto3.Session()
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

self.region_name = (
self.region_name or os.getenv("AWS_DEFAULT_REGION") or session.region_name
)

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name
if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url
if self.config:
client_params["config"] = self.config

try:
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

# As of 08/05/24 only claude-3 and mistral-large models support tool choice:
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
if self.supports_tool_choice_values is None:
Expand All @@ -426,6 +386,46 @@ def validate_environment(self) -> Self:
else:
self.supports_tool_choice_values = ()

if self.client is None:
try:
if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
session = boto3.Session()
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

self.region_name = (
self.region_name
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name
if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url
if self.config:
client_params["config"] = self.config

try:
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

return self

def _generate(
Expand Down

0 comments on commit cf435cf

Please sign in to comment.