From 653a8b3cf8e8b104783a6e4215d3bd84cfb82b61 Mon Sep 17 00:00:00 2001
From: Marek Pavelka
Date: Thu, 19 Sep 2024 20:39:05 +0200
Subject: [PATCH] fix: raise botocore exceptions (#193)

Motivation: The old implementation swallows botocore exceptions and
translates them into a generic `ValueError`, which complicates handling
different errors. Some botocore exceptions can simply be retried in
user code, whereas for others the user wants to trigger an alert and
re-raise, or raise some specific exception. Raising `ValueError` for
all botocore exceptions makes this impossible.
---
 .../aws/langchain_aws/chat_models/bedrock_converse.py |  2 +-
 libs/aws/langchain_aws/embeddings/bedrock.py          |  4 +++-
 libs/aws/langchain_aws/llms/bedrock.py                | 11 +++++++++--
 libs/aws/langchain_aws/llms/sagemaker_endpoint.py     | 11 +++++++++--
 4 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
index adbf0b09..9cc7e989 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -705,7 +705,7 @@ def _messages_to_bedrock(
             )
             bedrock_messages.append(curr)
         else:
-            raise ValueError()
+            raise ValueError(f"Unsupported message type {type(msg)}")
     return bedrock_messages, bedrock_system

diff --git a/libs/aws/langchain_aws/embeddings/bedrock.py b/libs/aws/langchain_aws/embeddings/bedrock.py
index fe6fced1..34e57f5b 100644
--- a/libs/aws/langchain_aws/embeddings/bedrock.py
+++ b/libs/aws/langchain_aws/embeddings/bedrock.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import logging
 import os

 from typing import Any, Dict, List, Optional
@@ -154,7 +155,8 @@ def _embedding_func(self, text: str) -> List[float]:
             return response_body.get("embedding")
         except Exception as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
+            logging.error(f"Error raised by inference endpoint: {e}")
+            raise e

     def _normalize_vector(self, embeddings: List[float]) -> List[float]:
         """Normalize the embedding to a unit vector."""

diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 48cceb33..84487002 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import logging
 import os
 import warnings
 from abc import ABC
@@ -788,7 +789,10 @@ def _prepare_input_and_invoke(
             ) = LLMInputOutputAdapter.prepare_output(provider, response).values()

         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}")
+            logging.error(f"Error raised by bedrock service: {e}")
+            if run_manager is not None:
+                run_manager.on_llm_error(e)
+            raise e

         if stop is not None:
             text = enforce_stop_tokens(text, stop)
@@ -908,7 +912,10 @@ def _prepare_input_and_invoke_stream(
             response = self.client.invoke_model_with_response_stream(**request_options)
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}")
+            logging.error(f"Error raised by bedrock service: {e}")
+            if run_manager is not None:
+                run_manager.on_llm_error(e)
+            raise e

         for chunk in LLMInputOutputAdapter.prepare_output_stream(
             provider,

diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
index 5f47d64e..bc0a480d 100644
--- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
+++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -1,6 +1,7 @@
 """Sagemaker InvokeEndpoint API."""

 import io
+import logging
 import re
 from abc import abstractmethod
 from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
@@ -336,7 +337,10 @@ def _stream(
                     run_manager.on_llm_new_token(chunk.text)

         except Exception as e:
-            raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+            logging.error(f"Error raised by streaming inference endpoint: {e}")
+            if run_manager is not None:
+                run_manager.on_llm_error(e)
+            raise e

     def _call(
         self,
@@ -382,7 +386,10 @@ def _call(
                 **_endpoint_kwargs,
             )
         except Exception as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
+            logging.error(f"Error raised by inference endpoint: {e}")
+            if run_manager is not None:
+                run_manager.on_llm_error(e)
+            raise e

         text = self.content_handler.transform_output(response["Body"])
         if stop is not None:
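
Note: with this change, botocore errors such as
`botocore.exceptions.ClientError` reach the caller with their original
type instead of being flattened into `ValueError`. A minimal sketch of
the error handling this enables in user code; the model ID, region,
and retry policy below are illustrative assumptions, not part of this
patch:

import time

import botocore.exceptions
from langchain_aws import BedrockLLM

# Hypothetical model and region, chosen only for illustration.
llm = BedrockLLM(model_id="anthropic.claude-v2", region_name="us-east-1")

for attempt in range(3):
    try:
        result = llm.invoke("Hello, Bedrock!")
        break
    except botocore.exceptions.ClientError as e:
        # Transient throttling can now be retried with backoff ...
        if e.response["Error"]["Code"] == "ThrottlingException":
            time.sleep(2**attempt)
            continue
        # ... while other service errors are re-raised (or alerted on)
        # with their original type preserved.
        raise

Before this patch, both branches above surfaced as the same
`ValueError`, so retry-vs-alert decisions could not be made on the
exception type.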