Skip to content

Commit

Permalink
Merge pull request #112 from aws-samples/spy_dev
Browse files Browse the repository at this point in the history
add llm support ak/sk connect
  • Loading branch information
supinyu authored Jun 24, 2024
2 parents e0ce07a + e6d9d74 commit 31c5376
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
4 changes: 4 additions & 0 deletions application/.env.cntemplate
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ SAGEMAKER_ENDPOINT_SQL=sql-sqlcoder-7b-2-7e5b6
SAGEMAKER_ENDPOINT_EXPLAIN=llm-internlm2-chat-7b-3ab71

EMBEDDING_DIMENSION=1024

BEDROCK_SECRETS_AK_SK=bedrock-ak-sk
OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url
OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user
4 changes: 4 additions & 0 deletions application/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ DYNAMODB_AWS_REGION=us-west-2

EMBEDDING_DIMENSION=1536
BEDROCK_EMBEDDING_MODEL=amazon.titan-embed-text-v1

BEDROCK_SECRETS_AK_SK=bedrock-ak-sk
OPENSEARCH_SECRETS_URL_HOST=opensearch-host-url
OPENSEARCH_SECRETS_USERNAME_PASSWORD=opensearch-master-user
33 changes: 30 additions & 3 deletions application/utils/env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
OPENSEARCH_REGION = os.getenv('AOS_AWS_REGION')

AOS_HOST = os.getenv('AOS_HOST')
AOS_PORT = os.getenv('AOS_PORT')
AOS_PORT = int(os.getenv('AOS_PORT'))
AOS_USER = os.getenv('AOS_USER')
AOS_PASSWORD = os.getenv('AOS_PASSWORD')
AOS_DOMAIN = os.getenv('AOS_DOMAIN')
Expand All @@ -35,19 +35,25 @@

OPENSEARCH_TYPE = os.getenv('OPENSEARCH_TYPE')

OPENSEARCH_SECRETS_URL_HOST = os.getenv('OPENSEARCH_SECRETS_URL_HOST', 'opensearch-host-url')

OPENSEARCH_SECRETS_USERNAME_PASSWORD = os.getenv('OPENSEARCH_SECRETS_USERNAME_PASSWORD', 'opensearch-master-user')

BEDROCK_SECRETS_AK_SK = os.getenv('BEDROCK_SECRETS_AK_SK')


def get_opensearch_parameter():
try:
session = boto3.session.Session()
sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
master_user = sm_client.get_secret_value(SecretId='opensearch-host-url')['SecretString']
master_user = sm_client.get_secret_value(SecretId=OPENSEARCH_SECRETS_URL_HOST)['SecretString']
data = json.loads(master_user)
es_host_name = data.get('host')
# cluster endpoint, for example: my-test-domain.us-east-1.es.amazonaws.com/
host = es_host_name + '/' if es_host_name[-1] != '/' else es_host_name

sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
master_user = sm_client.get_secret_value(SecretId='opensearch-master-user')['SecretString']
master_user = sm_client.get_secret_value(SecretId=OPENSEARCH_SECRETS_USERNAME_PASSWORD)['SecretString']
data = json.loads(master_user)
username = data.get('username')
password = data.get('password')
Expand All @@ -59,6 +65,25 @@ def get_opensearch_parameter():
raise e


def get_bedrock_parameter():
bedrock_ak_sk_info = {}
try:
session = boto3.session.Session()
sm_client = session.client(service_name='secretsmanager', region_name=AWS_DEFAULT_REGION)
if BEDROCK_SECRETS_AK_SK is not None and BEDROCK_SECRETS_AK_SK != "":
bedrock_info = sm_client.get_secret_value(SecretId=BEDROCK_SECRETS_AK_SK)['SecretString']
data = json.loads(bedrock_info)
access_key = data.get('access_key_id')
secret_key = data.get('secret_access_key')
bedrock_ak_sk_info['access_key_id'] = access_key
bedrock_ak_sk_info['secret_access_key'] = secret_key
return bedrock_ak_sk_info
except ClientError as e:
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
return bedrock_ak_sk_info


if OPENSEARCH_TYPE == "service":
opensearch_host, opensearch_port, opensearch_username, opensearch_password = get_opensearch_parameter()
AOS_HOST = opensearch_host
Expand All @@ -78,3 +103,5 @@ def get_opensearch_parameter():
'agent_index': AOS_INDEX_AGENT,
'embedding_dimension': EMBEDDING_DIMENSION
}

bedrock_ak_sk_info = get_bedrock_parameter()
9 changes: 8 additions & 1 deletion application/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
generate_agent_analyse_prompt, generate_data_summary_prompt, generate_suggest_question_prompt, \
generate_query_rewrite_prompt

from utils.env_var import bedrock_ak_sk_info
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

Expand All @@ -38,7 +39,13 @@
def get_bedrock_client():
global bedrock
if not bedrock:
bedrock = boto3.client(service_name='bedrock-runtime', config=config)
if len(bedrock_ak_sk_info) == 0:
bedrock = boto3.client(service_name='bedrock-runtime', config=config)
else:
bedrock = boto3.client(
service_name='bedrock-runtime', config=config,
aws_access_key_id=bedrock_ak_sk_info['access_key_id'],
aws_secret_access_key=bedrock_ak_sk_info['secret_access_key'])
return bedrock


Expand Down

0 comments on commit 31c5376

Please sign in to comment.