Skip to content

Commit

Permalink
Made changes in response to comments on #30
Browse files Browse the repository at this point in the history
  • Loading branch information
tomron-aws committed Dec 17, 2024
1 parent c9785c1 commit 40c8306
Showing 1 changed file with 58 additions and 19 deletions.
77 changes: 58 additions & 19 deletions libs/aws/langchain_aws/llms/q_business.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun
)
from langchain_core.outputs import GenerationChunk
import logging
from langchain_core.language_models import LLM
from pydantic import ConfigDict
from pydantic import ConfigDict, model_validator
import json
import asyncio
import boto3
from typing_extensions import Self

class AmazonQ(LLM):
"""Amazon Q LLM wrapper.
Expand Down Expand Up @@ -45,6 +46,13 @@ class AmazonQ(LLM):
chat_mode: str = "RETRIEVAL_MODE"
"""AWS region name. If not provided, will be extracted from environment."""

credentials_profile_name: Optional[str] = None
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
has either access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
model_config = ConfigDict(
extra="forbid",
)
Expand Down Expand Up @@ -85,24 +93,15 @@ def _call(

# Prepare the request
request = {
'applicationId': "130f4ea4-855f-4ddf-b2a5-1e40923692d4",
'applicationId': self.application_id,
'userMessage': prompt,
'chatMode':self.chat_mode,
}
if not self.conversation_id:
request = {
'applicationId': self.application_id,
'userMessage': prompt,
'chatMode':self.chat_mode,
}
else:
request = {
'applicationId': self.application_id,
'userMessage': prompt,
'chatMode':self.chat_mode,
'conversationId':self.conversation_id,
'parentMessageId':self.parent_message_id,
}
if self.conversation_id:
request.update({
'conversationId': self.conversation_id,
'parentMessageId': self.parent_message_id,
})

# Call Amazon Q
response = self.client.chat_sync(**request)
Expand All @@ -115,6 +114,12 @@ def _call(
raise ValueError("Unexpected response format from Amazon Q")

except Exception as e:
if "Prompt Length" in str(e):
logging.info(f"Prompt Length: {len(prompt)}")
print(f"""Prompt:
{prompt}""")
raise ValueError(f"Error raised by Amazon Q service: {e}")

raise ValueError(f"Error raised by Amazon Q service: {e}")

def get_last_response(self) -> Dict:
Expand Down Expand Up @@ -143,4 +148,38 @@ def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"region_name": self.region_name,
}
}
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Dont do anything if client provided externally"""
if self.client is not None:
return self

"""Validate that AWS credentials to and python package exists in environment."""
try:
import boto3

try:
if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
# use default credentials
session = boto3.Session()

self.client = session.client(
"qbusiness", region_name=self.region_name
)

except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e

except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
return self

0 comments on commit 40c8306

Please sign in to comment.