From cf435cf0d1bf3b6b6f592ad199e2dc16355a7dab Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 19 Sep 2024 10:12:47 -0700 Subject: [PATCH] aws[patch]: fix ChatBedrockConverse init (#200) Make sure supports_tool_choice_values is set whether or not client is passed in --- .../chat_models/bedrock_converse.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index c5e80bae..322bbb8b 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -376,46 +376,6 @@ def set_disable_streaming(cls, values: Dict) -> Any: @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that AWS credentials to and python package exists in environment.""" - if self.client is not None: - return self - - try: - if self.credentials_profile_name is not None: - session = boto3.Session(profile_name=self.credentials_profile_name) - else: - session = boto3.Session() - except ValueError as e: - raise ValueError(f"Error raised by bedrock service: {e}") - except Exception as e: - raise ValueError( - "Could not load credentials to authenticate with AWS client. " - "Please check that credentials in the specified " - f"profile name are valid. Bedrock error: {e}" - ) from e - - self.region_name = ( - self.region_name or os.getenv("AWS_DEFAULT_REGION") or session.region_name - ) - - client_params = {} - if self.region_name: - client_params["region_name"] = self.region_name - if self.endpoint_url: - client_params["endpoint_url"] = self.endpoint_url - if self.config: - client_params["config"] = self.config - - try: - self.client = session.client("bedrock-runtime", **client_params) - except ValueError as e: - raise ValueError(f"Error raised by bedrock service: {e}") - except Exception as e: - raise ValueError( - "Could not load credentials to authenticate with AWS client. " - "Please check that credentials in the specified " - f"profile name are valid. Bedrock error: {e}" - ) from e - # As of 08/05/24 only claude-3 and mistral-large models support tool choice: # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html if self.supports_tool_choice_values is None: @@ -426,6 +386,46 @@ def validate_environment(self) -> Self: else: self.supports_tool_choice_values = () + if self.client is None: + try: + if self.credentials_profile_name is not None: + session = boto3.Session(profile_name=self.credentials_profile_name) + else: + session = boto3.Session() + except ValueError as e: + raise ValueError(f"Error raised by bedrock service: {e}") + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + f"profile name are valid. Bedrock error: {e}" + ) from e + + self.region_name = ( + self.region_name + or os.getenv("AWS_DEFAULT_REGION") + or session.region_name + ) + + client_params = {} + if self.region_name: + client_params["region_name"] = self.region_name + if self.endpoint_url: + client_params["endpoint_url"] = self.endpoint_url + if self.config: + client_params["config"] = self.config + + try: + self.client = session.client("bedrock-runtime", **client_params) + except ValueError as e: + raise ValueError(f"Error raised by bedrock service: {e}") + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + f"profile name are valid. Bedrock error: {e}" + ) from e + return self def _generate(