Commit

Linting/Formatting/codespell changes
NAPTlME committed Apr 22, 2024
1 parent 1eeb60f commit afda911
Showing 3 changed files with 71 additions and 38 deletions.
2 changes: 1 addition & 1 deletion libs/aws/README.md
@@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever(
retriever.get_relevant_documents(query="What is the meaning of life?")
```

-`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowlege Bases.
+`AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases.

```python
from langchain_aws import AmazonKnowledgeBasesRetriever
# …
```
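
For reference, a minimal usage sketch of the renamed class; the knowledge base ID and `retrieval_config` shape below are illustrative placeholders, not part of this diff:

```python
from langchain_aws import AmazonKnowledgeBasesRetriever

# Hypothetical knowledge base ID and retrieval config, for illustration only.
retriever = AmazonKnowledgeBasesRetriever(
    knowledge_base_id="EXAMPLEKBID",
    retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},
)
docs = retriever.get_relevant_documents(query="What is the meaning of life?")
```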
23 changes: 16 additions & 7 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -18,7 +18,10 @@
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

-from langchain_aws.llms.bedrock import BedrockBase, _combine_generation_info_for_llm_result
+from langchain_aws.llms.bedrock import (
+    BedrockBase,
+    _combine_generation_info_for_llm_result,
+)
from langchain_aws.utils import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
@@ -321,7 +324,11 @@ def _stream(
**kwargs,
):
delta = chunk.text
-yield ChatGenerationChunk(message=AIMessageChunk(content=delta, response_metadata=chunk.generation_info))
+yield ChatGenerationChunk(
+    message=AIMessageChunk(
+        content=delta, response_metadata=chunk.generation_info
+    )
+)

def _generate(
self,
@@ -332,13 +339,17 @@ def _generate(
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {}
-provider_stop_reason_code = self.provider_stop_reason_key_map.get(self._get_provider(), "stop_reason")
+provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+    self._get_provider(), "stop_reason"
+)
if self.streaming:
response_metadata: list[Dict[str, Any]] = []
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
response_metadata.append(chunk.message.response_metadata)
-llm_output = _combine_generation_info_for_llm_result(response_metadata, provider_stop_reason_code)
+llm_output = _combine_generation_info_for_llm_result(
+    response_metadata, provider_stop_reason_code
+)
else:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
@@ -369,9 +380,7 @@ def _generate(
return ChatResult(
generations=[
ChatGeneration(
-message=AIMessage(
-    content=completion, additional_kwargs=llm_output
-)
+message=AIMessage(content=completion, additional_kwargs=llm_output)
)
],
llm_output=llm_output,
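
The `_stream` change above attaches each chunk's `generation_info` to the outgoing message as `response_metadata`. A hedged consumer-side sketch, assuming a configured `ChatBedrock` instance with AWS credentials available (model ID and prompt are placeholders):

```python
from langchain_aws import ChatBedrock

chat = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0", streaming=True)

for chunk in chat.stream("Tell me a joke"):
    # Each AIMessageChunk now carries provider metadata (e.g. usage counts,
    # stop reason) alongside its text delta.
    print(chunk.content, chunk.response_metadata)
```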
84 changes: 54 additions & 30 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -80,10 +80,7 @@ def _human_assistant_format(input_text: str) -> str:


def _stream_response_to_generation_chunk(
-    stream_response: Dict[str, Any],
-    provider,
-    output_key,
-    messages_api
+    stream_response: Dict[str, Any], provider, output_key, messages_api
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
if messages_api:
@@ -92,7 +89,7 @@ def _stream_response_to_generation_chunk(
usage_info = stream_response.get("message", {}).get("usage", None)
generation_info = {"usage": usage_info}
return GenerationChunk(text="", generation_info=generation_info)
case "content_block_delta":
case "content_block_delta":
if not stream_response["delta"]:
return GenerationChunk(text="")
return GenerationChunk(
@@ -110,7 +107,7 @@
return None
else:
# chunk obj format varies with provider
-generation_info = {k:v for k, v in stream_response.items() if k != output_key}
+generation_info = {k: v for k, v in stream_response.items() if k != output_key}
return GenerationChunk(
text=(
stream_response[output_key]
Expand All @@ -119,8 +116,11 @@ def _stream_response_to_generation_chunk(
),
generation_info=generation_info,
)
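
A hedged sketch of how this helper handles an Anthropic messages-API `message_start` event, per the visible branch above; the event payload is a guess at the wire format, and the usage key names are assumptions:

```python
# Hypothetical stream event; real Bedrock payloads may differ.
event = {
    "type": "message_start",
    "message": {"usage": {"input_tokens": 10, "output_tokens": 0}},
}
chunk = _stream_response_to_generation_chunk(
    event, provider="anthropic", output_key="completion", messages_api=True
)
# Per the message_start branch:
# GenerationChunk(text="", generation_info={"usage": {...}})
```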

-def _combine_generation_info_for_llm_result(chunks_generation_info: list[Dict[str, Any]], provider_stop_code) -> Dict[str, Any]:
+def _combine_generation_info_for_llm_result(
+    chunks_generation_info: list[Dict[str, Any]], provider_stop_code
+) -> Dict[str, Any]:
"""
Returns usage and stop reason information with the intent to pack into an LLMResult
Takes a list of GenerationChunks
@@ -147,7 +147,9 @@ def _combine_generation_info_for_llm_result(
# uses the last stop reason
stop_reason = generation_info[provider_stop_code]

total_usage_info["total_tokens"] = total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
total_usage_info["total_tokens"] = (
total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
)

return {"usage": total_usage_info, "stop_reason": stop_reason}

@@ -235,7 +237,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"stop_reason": response_body["stop_reason"]
"stop_reason": response_body["stop_reason"],
}

@classmethod
@@ -282,16 +284,24 @@ def prepare_output_stream(
elif messages_api and (chunk_obj.get("type") == "message_stop"):
return

-
-generation_chunk = _stream_response_to_generation_chunk(chunk_obj, provider=provider, output_key=output_key, messages_api=messages_api)
+generation_chunk = _stream_response_to_generation_chunk(
+    chunk_obj,
+    provider=provider,
+    output_key=output_key,
+    messages_api=messages_api,
+)
if generation_chunk:
yield generation_chunk
else:
continue

@classmethod
async def aprepare_output_stream(
-    cls, provider: str, response: Any, stop: Optional[List[str]] = None, messages_api: bool = False
+    cls,
+    provider: str,
+    response: Any,
+    stop: Optional[List[str]] = None,
+    messages_api: bool = False,
) -> AsyncIterator[GenerationChunk]:
stream = response.get("body")

@@ -323,7 +333,12 @@ async def aprepare_output_stream(
):
return

-generation_chunk = _stream_response_to_generation_chunk(chunk_obj, provider=provider, output_key=output_key, messages_api=messages_api)
+generation_chunk = _stream_response_to_generation_chunk(
+    chunk_obj,
+    provider=provider,
+    output_key=output_key,
+    messages_api=messages_api,
+)
if generation_chunk:
yield generation_chunk
else:
@@ -385,7 +400,7 @@ class BedrockBase(BaseLanguageModel, ABC):
"amazon": "completionReason",
"ai21": "finishReason",
"cohere": "finish_reason",
"mistral": "stop_reason"
"mistral": "stop_reason",
}

guardrails: Optional[Mapping[str, Any]] = {
@@ -603,10 +618,7 @@ def _prepare_input_and_invoke(
if stop is not None:
text = enforce_stop_tokens(text, stop)

-llm_output = {
-    "usage": usage_info,
-    "stop_reason": stop_reason
-}
+llm_output = {"usage": usage_info, "stop_reason": stop_reason}

# Verify and raise a callback error if any intervention occurs or a signal is
# sent from a Bedrock service,
@@ -620,8 +632,6 @@
),
**services_trace,
)
-
-

return text, llm_output

@@ -894,7 +904,9 @@ def _call(
"""

provider = self._get_provider()
-provider_stop_reason_code = self.provider_stop_reason_key_map.get(provider, "stop_reason")
+provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+    provider, "stop_reason"
+)

if self.streaming:
all_chunks = []
@@ -907,16 +919,22 @@

if run_manager is not None:
chunks_generation_info = [x.generation_info for x in all_chunks]
-llm_output = _combine_generation_info_for_llm_result(chunks_generation_info, provider_stop_code=provider_stop_reason_code)
-run_manager.on_llm_end(LLMResult(generations=[all_chunks], llm_output=llm_output))
+llm_output = _combine_generation_info_for_llm_result(
+    chunks_generation_info, provider_stop_code=provider_stop_reason_code
+)
+run_manager.on_llm_end(
+    LLMResult(generations=[all_chunks], llm_output=llm_output)
+)

return completion

text, llm_output = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
if run_manager is not None:
-run_manager.on_llm_end(LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output))
+run_manager.on_llm_end(
+    LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output)
+)

return text

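Downstream effect of the `_call` changes: with `streaming=True`, any `on_llm_end` callback now receives the combined usage and stop-reason info in `llm_output`. A hedged end-to-end sketch; the model ID is a placeholder and AWS credentials are assumed to be configured:

```python
from langchain_aws import BedrockLLM

llm = BedrockLLM(model_id="anthropic.claude-v2", streaming=True)
# Handlers registered via callbacks would see
# llm_output={"usage": {...}, "stop_reason": ...} in on_llm_end.
print(llm.invoke("Tell me a joke"))
```
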
@@ -972,7 +990,9 @@ async def _acall(
raise ValueError("Streaming must be set to True for async operations. ")

provider = self._get_provider()
-provider_stop_reason_code = self.provider_stop_reason_key_map.get(provider, "stop_reason")
+provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+    provider, "stop_reason"
+)

chunks = [
chunk
@@ -982,9 +1002,13 @@
]

if run_manager is not None:
-chunks_generation_info = [x.generation_info for x in chunks]
-llm_output = _combine_generation_info_for_llm_result(chunks_generation_info, provider_stop_code=provider_stop_reason_code)
-run_manager.on_llm_end(LLMResult(generations=[chunks], llm_output=llm_output))
+chunks_generation_info = [x.generation_info for x in chunks]
+llm_output = _combine_generation_info_for_llm_result(
+    chunks_generation_info, provider_stop_code=provider_stop_reason_code
+)
+run_manager.on_llm_end(
+    LLMResult(generations=[chunks], llm_output=llm_output)
+)

return "".join([chunk.text for chunk in chunks])

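And the async path, which per the check above requires `streaming=True`; again a hedged sketch with a placeholder model ID:

```python
import asyncio

from langchain_aws import BedrockLLM


async def main() -> None:
    llm = BedrockLLM(model_id="anthropic.claude-v2", streaming=True)
    print(await llm.ainvoke("Tell me a joke"))


asyncio.run(main())
```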
