From afda911cb981540f07b95cc1a4d5728387bc40a4 Mon Sep 17 00:00:00 2001 From: NAPTlME Date: Mon, 22 Apr 2024 01:45:41 -0500 Subject: [PATCH] Linting/Formatting/codespell changes --- libs/aws/README.md | 2 +- libs/aws/langchain_aws/chat_models/bedrock.py | 23 +++-- libs/aws/langchain_aws/llms/bedrock.py | 84 ++++++++++++------- 3 files changed, 71 insertions(+), 38 deletions(-) diff --git a/libs/aws/README.md b/libs/aws/README.md index 4e391286..baaee44c 100644 --- a/libs/aws/README.md +++ b/libs/aws/README.md @@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever( retriever.get_relevant_documents(query="What is the meaning of life?") ``` -`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowlege Bases. +`AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases. ```python from langchain_aws import AmazonKnowledgeBasesRetriever diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 5c82481c..ff66665b 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -18,7 +18,10 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import Extra -from langchain_aws.llms.bedrock import BedrockBase, _combine_generation_info_for_llm_result +from langchain_aws.llms.bedrock import ( + BedrockBase, + _combine_generation_info_for_llm_result, +) from langchain_aws.utils import ( get_num_tokens_anthropic, get_token_ids_anthropic, @@ -321,7 +324,11 @@ def _stream( **kwargs, ): delta = chunk.text - yield ChatGenerationChunk(message=AIMessageChunk(content=delta, response_metadata=chunk.generation_info)) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=delta, response_metadata=chunk.generation_info + ) + ) def _generate( self, @@ -332,13 +339,17 @@ def _generate( ) -> ChatResult: completion = "" llm_output: Dict[str, Any] = {} - provider_stop_reason_code = self.provider_stop_reason_key_map.get(self._get_provider(), "stop_reason") + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + self._get_provider(), "stop_reason" + ) if self.streaming: response_metadata: list[Dict[str, Any]] = [] for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text response_metadata.append(chunk.message.response_metadata) - llm_output = _combine_generation_info_for_llm_result(response_metadata, provider_stop_reason_code) + llm_output = _combine_generation_info_for_llm_result( + response_metadata, provider_stop_reason_code + ) else: provider = self._get_provider() prompt, system, formatted_messages = None, None, None @@ -369,9 +380,7 @@ def _generate( return ChatResult( generations=[ ChatGeneration( - message=AIMessage( - content=completion, additional_kwargs=llm_output - ) + message=AIMessage(content=completion, additional_kwargs=llm_output) ) ], llm_output=llm_output, diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index d0a13216..0f5e493e 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -80,10 +80,7 @@ def _human_assistant_format(input_text: str) -> str: def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any], - provider, - output_key, - messages_api + stream_response: Dict[str, Any], provider, output_key, messages_api ) -> GenerationChunk: """Convert a stream response to a generation chunk.""" if messages_api: @@ 
-92,7 +89,7 @@ def _stream_response_to_generation_chunk( usage_info = stream_response.get("message", {}).get("usage", None) generation_info = {"usage": usage_info} return GenerationChunk(text="", generation_info=generation_info) - case "content_block_delta": + case "content_block_delta": if not stream_response["delta"]: return GenerationChunk(text="") return GenerationChunk( @@ -110,7 +107,7 @@ def _stream_response_to_generation_chunk( return None else: # chunk obj format varies with provider - generation_info = {k:v for k, v in stream_response.items() if k != output_key} + generation_info = {k: v for k, v in stream_response.items() if k != output_key} return GenerationChunk( text=( stream_response[output_key] @@ -119,8 +116,11 @@ def _stream_response_to_generation_chunk( ), generation_info=generation_info, ) - -def _combine_generation_info_for_llm_result(chunks_generation_info: list[Dict[str, Any]], provider_stop_code) -> Dict[str, Any]: + + +def _combine_generation_info_for_llm_result( + chunks_generation_info: list[Dict[str, Any]], provider_stop_code +) -> Dict[str, Any]: """ Returns usage and stop reason information with the intent to pack into an LLMResult Takes a list of GenerationChunks @@ -147,7 +147,9 @@ def _combine_generation_info_for_llm_result(chunks_generation_info: list[Dict[st # uses the last stop reason stop_reason = generation_info[provider_stop_code] - total_usage_info["total_tokens"] = total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"] + total_usage_info["total_tokens"] = ( + total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"] + ) return {"usage": total_usage_info, "stop_reason": stop_reason} @@ -235,7 +237,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict: "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, - "stop_reason": response_body["stop_reason"] + "stop_reason": response_body["stop_reason"], } @classmethod @@ -282,8 +284,12 @@ def prepare_output_stream( elif messages_api and (chunk_obj.get("type") == "message_stop"): return - - generation_chunk = _stream_response_to_generation_chunk(chunk_obj, provider=provider, output_key=output_key, messages_api=messages_api) + generation_chunk = _stream_response_to_generation_chunk( + chunk_obj, + provider=provider, + output_key=output_key, + messages_api=messages_api, + ) if generation_chunk: yield generation_chunk else: @@ -291,7 +297,11 @@ def prepare_output_stream( @classmethod async def aprepare_output_stream( - cls, provider: str, response: Any, stop: Optional[List[str]] = None, messages_api: bool = False + cls, + provider: str, + response: Any, + stop: Optional[List[str]] = None, + messages_api: bool = False, ) -> AsyncIterator[GenerationChunk]: stream = response.get("body") @@ -323,7 +333,12 @@ async def aprepare_output_stream( ): return - generation_chunk = _stream_response_to_generation_chunk(chunk_obj, provider=provider, output_key=output_key, messages_api=messages_api) + generation_chunk = _stream_response_to_generation_chunk( + chunk_obj, + provider=provider, + output_key=output_key, + messages_api=messages_api, + ) if generation_chunk: yield generation_chunk else: @@ -385,7 +400,7 @@ class BedrockBase(BaseLanguageModel, ABC): "amazon": "completionReason", "ai21": "finishReason", "cohere": "finish_reason", - "mistral": "stop_reason" + "mistral": "stop_reason", } guardrails: Optional[Mapping[str, Any]] = { @@ -603,10 +618,7 @@ def _prepare_input_and_invoke( if stop is not None: text = enforce_stop_tokens(text, 
stop) - llm_output = { - "usage": usage_info, - "stop_reason": stop_reason - } + llm_output = {"usage": usage_info, "stop_reason": stop_reason} # Verify and raise a callback error if any intervention occurs or a signal is # sent from a Bedrock service, @@ -620,8 +632,6 @@ def _prepare_input_and_invoke( ), **services_trace, ) - - return text, llm_output @@ -894,7 +904,9 @@ def _call( """ provider = self._get_provider() - provider_stop_reason_code = self.provider_stop_reason_key_map.get(provider, "stop_reason") + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + provider, "stop_reason" + ) if self.streaming: all_chunks = [] @@ -907,16 +919,22 @@ def _call( if run_manager is not None: chunks_generation_info = [x.generation_info for x in all_chunks] - llm_output = _combine_generation_info_for_llm_result(chunks_generation_info, provider_stop_code=provider_stop_reason_code) - run_manager.on_llm_end(LLMResult(generations=[all_chunks], llm_output=llm_output)) - + llm_output = _combine_generation_info_for_llm_result( + chunks_generation_info, provider_stop_code=provider_stop_reason_code + ) + run_manager.on_llm_end( + LLMResult(generations=[all_chunks], llm_output=llm_output) + ) + return completion text, llm_output = self._prepare_input_and_invoke( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ) if run_manager is not None: - run_manager.on_llm_end(LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output)) + run_manager.on_llm_end( + LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output) + ) return text @@ -972,7 +990,9 @@ async def _acall( raise ValueError("Streaming must be set to True for async operations. ") provider = self._get_provider() - provider_stop_reason_code = self.provider_stop_reason_key_map.get(provider, "stop_reason") + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + provider, "stop_reason" + ) chunks = [ chunk @@ -982,9 +1002,13 @@ async def _acall( ] if run_manager is not None: - chunks_generation_info = [x.generation_info for x in chunks] - llm_output = _combine_generation_info_for_llm_result(chunks_generation_info, provider_stop_code=provider_stop_reason_code) - run_manager.on_llm_end(LLMResult(generations=[chunks], llm_output=llm_output)) + chunks_generation_info = [x.generation_info for x in chunks] + llm_output = _combine_generation_info_for_llm_result( + chunks_generation_info, provider_stop_code=provider_stop_reason_code + ) + run_manager.on_llm_end( + LLMResult(generations=[chunks], llm_output=llm_output) + ) return "".join([chunk.text for chunk in chunks])
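
The streaming paths reworked above all funnel per-chunk `generation_info` dictionaries into a single `llm_output` via `_combine_generation_info_for_llm_result`. Below is a minimal standalone sketch of that aggregation, using hypothetical chunk data and simplified to the usage/stop-reason handling visible in this patch; it is an illustration, not the library's exact implementation.

```python
from typing import Any, Dict, List


def combine_generation_info(
    chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str
) -> Dict[str, Any]:
    """Simplified sketch: sum token usage across chunks, keep the last stop reason."""
    total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    stop_reason = ""
    for info in chunks_generation_info:
        usage = (info or {}).get("usage")
        if usage:
            total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
            total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
        if provider_stop_code in (info or {}):
            # the last chunk that reports a stop reason wins
            stop_reason = info[provider_stop_code]
    total_usage["total_tokens"] = (
        total_usage["prompt_tokens"] + total_usage["completion_tokens"]
    )
    return {"usage": total_usage, "stop_reason": stop_reason}


# Hypothetical streamed chunks; real chunk shapes vary by Bedrock provider.
chunks = [
    {"usage": {"prompt_tokens": 12, "completion_tokens": 0}},
    {"usage": {"prompt_tokens": 0, "completion_tokens": 7}, "stop_reason": "end_turn"},
]
print(combine_generation_info(chunks, "stop_reason"))
# {'usage': {'prompt_tokens': 12, 'completion_tokens': 7, 'total_tokens': 19}, 'stop_reason': 'end_turn'}
```

The `provider_stop_reason_key_map` added to `BedrockBase` supplies the `provider_stop_code` argument per provider (e.g. `"completionReason"` for Amazon models, `"finishReason"` for AI21), with `"stop_reason"` as the fallback.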