Skip to content

Commit

Permalink
add knowledge_base_handler
Browse files Browse the repository at this point in the history
  • Loading branch information
kris committed Mar 22, 2024
1 parent 04eda23 commit 8b0198f
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deployment/cdk.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"aws",
"aws-cn"
],
"vpc_deployment": false,
"vpc_deployment": true,
"vpc_name": "smart_search-vpc",
"subnet_name": "smart_search-subnet-private1-us-east-1a",
"subnet_id": "subnet-1234",
Expand Down
367 changes: 366 additions & 1 deletion deployment/lib/ss_lambdavpcstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def create_langchain_qa_func(self, search_engine_key, vpc=None, vpc_subnets=None
timeout=Duration.minutes(10),
vpc=vpc,
vpc_subnets=vpc_subnets,
reserved_concurrent_executions=100
reserved_concurrent_executions=10
)
langchain_processor_qa_function.add_environment("host", search_engine_key)
langchain_processor_qa_function.add_environment("index", index)
Expand Down Expand Up @@ -441,6 +441,371 @@ def create_apigw_resource_method_for_langchain_qa(self, api, langchain_processor
langchain_processor_qa_function.add_environment("dynamodb_table_name", chat_table.table_name)
cdk.CfnOutput(self, 'chat_table_name', value=chat_table.table_name, export_name='ChatTableName')

def create_apigw_resource_method_for_endpoint_list(self, api, endpoint_list_function):
    """Expose ``endpoint_list_function`` as ``GET /endpoint_list`` on ``api``.

    Adds an ``endpoint_list`` resource under ``api.root`` with a CORS
    preflight configuration (GET/OPTIONS, all origins) and attaches a GET
    method backed by a Lambda proxy integration.

    Args:
        api: API Gateway API object whose ``root`` receives the new resource.
        endpoint_list_function: Lambda function passed to
            ``apigw.LambdaIntegration``.
    """
    # The same response-header key is declared on the method response and
    # populated on the integration response.
    allow_origin_header = 'method.response.header.Access-Control-Allow-Origin'

    preflight_options = apigw.CorsOptions(
        allow_methods=['GET', 'OPTIONS'],
        allow_origins=apigw.Cors.ALL_ORIGINS,
    )
    resource = api.root.add_resource(
        'endpoint_list',
        default_cors_preflight_options=preflight_options,
    )

    ok_integration_response = apigw.IntegrationResponse(
        status_code="200",
        response_parameters={allow_origin_header: "'*'"},
    )
    integration = apigw.LambdaIntegration(
        endpoint_list_function,
        proxy=True,
        integration_responses=[ok_integration_response],
    )

    ok_method_response = apigw.MethodResponse(
        status_code="200",
        response_parameters={allow_origin_header: True},
    )
    resource.add_method(
        'GET',
        integration,
        method_responses=[ok_method_response],
    )

def create_file_upload_prerequisites(self, api, search_engine_key, vpc=None, vpc_subnets=None):
    """Provision everything needed to upload files through API Gateway.

    Steps, in order:
      1. Create the S3 bucket for uploaded files (stored on ``self.bucket``).
      2. Create the role API Gateway assumes to put objects into that bucket.
      3. Create the ``langchain_processor_dataload`` Lambda.
      4. Trigger that Lambda on objects created under ``source_data/``.
      5. Create ``PUT /file_upload/{bucket}/{prefix}/{sub_prefix}/{filename}``
         backed by a direct S3 integration.
      6. Create the ``endpoint_list`` Lambda and its ``GET /endpoint_list``
         route.

    Args:
        api: API Gateway API object whose ``root`` receives the new resources.
        search_engine_key: Value exported to the data-load Lambda as the
            ``host`` environment variable.
        vpc: Optional VPC passed through to both Lambda functions.
        vpc_subnets: Optional subnet selection passed through to both
            Lambda functions.

    Raises:
        ValueError: If the ``execution_role_name`` CDK context key is not set.
    """
    account = os.getenv('AWS_ACCOUNT_ID', '')
    region = os.getenv('AWS_REGION', '')
    # China regions use the "aws-cn" ARN partition.
    partition_suffix = "-cn" if "cn-" in region else ""

    context = self.node.try_get_context
    role_name_prefix = context("execution_role_name")
    if role_name_prefix is None:
        # Fail with a clear message instead of the opaque
        # "NoneType + str" TypeError the concatenation below would raise.
        raise ValueError(
            'CDK context key "execution_role_name" must be set to name the '
            'API Gateway upload role'
        )
    execution_role_name = role_name_prefix + region
    index = context("index")
    language = context("language")
    embedding_endpoint_name = context("embedding_endpoint_name")
    search_engine_opensearch = context("search_engine_opensearch")
    search_engine_zilliz = context("search_engine_zilliz")
    zilliz_endpoint = context("zilliz_endpoint")
    zilliz_token = context("zilliz_token")

    # ------------------------------------------------------------------
    # 1. S3 bucket for storing uploaded files
    # ------------------------------------------------------------------
    bucket_name = "intelligent-search-data-bucket" + "-" + account + "-" + region

    bucket = s3.Bucket(
        self,
        id=bucket_name,
        bucket_name=bucket_name,
        block_public_access=s3.BlockPublicAccess.BLOCK_ALL,
        encryption=s3.BucketEncryption.S3_MANAGED,
        enforce_ssl=True,
        versioned=False,
        removal_policy=RemovalPolicy.DESTROY,
    )
    bucket.add_cors_rule(
        allowed_headers=["*"],
        allowed_methods=[
            s3.HttpMethods.GET,
            s3.HttpMethods.PUT,
            s3.HttpMethods.POST,
        ],
        allowed_origins=["*"],
    )
    self.bucket = bucket

    # ------------------------------------------------------------------
    # 2. Execution role used by API Gateway to upload files to S3
    # ------------------------------------------------------------------
    upload_role_policies = {
        "AllowS3UploadPermission": _iam.PolicyDocument(
            statements=[
                _iam.PolicyStatement(
                    actions=["s3:PutObject"],
                    resources=[f"arn:aws{partition_suffix}:s3:::{bucket_name}/*"],
                ),
            ]
        ),
        "AllowLogCreation": _iam.PolicyDocument(
            statements=[
                _iam.PolicyStatement(
                    actions=[
                        "logs:CreateLogGroup",
                        "logs:CreateLogStream",
                        "logs:DescribeLogGroups",
                        "logs:DescribeLogStreams",
                        "logs:PutLogEvents",
                        "logs:GetLogEvents",
                        "logs:FilterLogEvents",
                    ],
                    resources=["*"],
                ),
            ]
        ),
    }

    upload_role = _iam.Role(
        self,
        id=execution_role_name,
        role_name=execution_role_name,
        assumed_by=_iam.ServicePrincipal("apigateway.amazonaws.com"),
        description="Execution role for uploading file from APIGW to S3 directly.",
        inline_policies=upload_role_policies,
    )

    # ------------------------------------------------------------------
    # 3. Lambda that processes uploaded files and loads them into the
    #    search engine
    # ------------------------------------------------------------------
    data_load_function_name = 'langchain_processor_dataload'

    # Only real IAM actions are listed here. The previous entries
    # 's3:AmazonS3FullAccess', 'lambda:AWSLambdaBasicExecutionRole' and
    # 'secretsmanager:SecretsManagerReadWrite' are managed-policy names, not
    # actions, so they granted nothing; those managed policies are attached
    # to the data-load role explicitly below.
    model_access_statement = _iam.PolicyStatement(
        actions=[
            'sagemaker:InvokeEndpointAsync',
            'sagemaker:InvokeEndpoint',
            'bedrock:*',
        ],
        resources=['*'],  # can be narrowed as needed
    )

    data_load_role = _iam.Role(
        self, 'data_load_role',
        assumed_by=_iam.ServicePrincipal('lambda.amazonaws.com'),
    )
    data_load_role.add_to_policy(model_access_statement)
    for managed_policy_name in (
        "service-role/AWSLambdaBasicExecutionRole",
        "SecretsManagerReadWrite",
        "AmazonS3FullAccess",
    ):
        data_load_role.add_managed_policy(
            _iam.ManagedPolicy.from_aws_managed_policy_name(managed_policy_name)
        )

    data_load_function = _lambda.Function(
        self, data_load_function_name,
        function_name=data_load_function_name,
        runtime=_lambda.Runtime.PYTHON_3_9,
        role=data_load_role,
        layers=[self.langchain_processor_qa_layer],
        code=_lambda.Code.from_asset('../lambda/' + data_load_function_name),
        handler='lambda_function.lambda_handler',
        timeout=Duration.minutes(10),
        vpc=vpc,
        vpc_subnets=vpc_subnets,
        reserved_concurrent_executions=10,
    )
    data_load_function.add_environment("host", search_engine_key)
    data_load_function.add_environment("index", index)
    data_load_function.add_environment("language", language)
    data_load_function.add_environment("embedding_endpoint_name", embedding_endpoint_name)
    data_load_function.add_environment("search_engine_opensearch", str(search_engine_opensearch))
    data_load_function.add_environment("search_engine_zilliz", str(search_engine_zilliz))
    data_load_function.add_environment("zilliz_endpoint", str(zilliz_endpoint))
    data_load_function.add_environment("zilliz_token", str(zilliz_token))

    # ------------------------------------------------------------------
    # 4. Invoke the data-load Lambda for objects created under
    #    source_data/ in the bucket
    # ------------------------------------------------------------------
    bucket.add_event_notification(
        s3.EventType.OBJECT_CREATED,
        s3n.LambdaDestination(data_load_function),
        s3.NotificationKeyFilter(
            prefix="source_data/",
        ),
    )

    # ------------------------------------------------------------------
    # 5. S3-backed API Gateway resources:
    #    /file_upload/{bucket}/{prefix}/{sub_prefix}/{filename}
    # ------------------------------------------------------------------
    file_upload_root = api.root.add_resource(path_part="file_upload")
    bucket_resource = file_upload_root.add_resource(path_part="{bucket}")
    prefix_resource = bucket_resource.add_resource(path_part="{prefix}")
    sub_prefix_resource = prefix_resource.add_resource(path_part="{sub_prefix}")
    filename_resource = sub_prefix_resource.add_resource(
        path_part="{filename}",
        default_cors_preflight_options=apigw.CorsOptions(
            allow_methods=['PUT', 'POST', 'OPTIONS'],
            allow_origins=apigw.Cors.ALL_ORIGINS,
        ),
    )

    # Every path parameter is mandatory (True).
    method_request_parameters = {
        "method.request.path.bucket": True,
        "method.request.path.filename": True,
        "method.request.path.prefix": True,
        "method.request.path.sub_prefix": True,
    }

    # Map method path parameters onto the integration path; the request's
    # {filename} becomes the integration's {key}.
    integration_request_parameters = {
        "integration.request.path.bucket": "method.request.path.bucket",
        "integration.request.path.key": "method.request.path.filename",
        "integration.request.path.prefix": "method.request.path.prefix",
        "integration.request.path.sub_prefix": "method.request.path.sub_prefix",
    }

    # ------------------------------------------------------------------
    # 6. Lambda that lists SageMaker endpoints for the front end
    # ------------------------------------------------------------------
    endpoint_list_function_name = 'endpoint_list'

    endpoint_list_role = _iam.Role(
        self, 'endpoint_list_role',
        assumed_by=_iam.ServicePrincipal('lambda.amazonaws.com'),
    )
    # NOTE(review): this role receives the same SageMaker-invoke / Bedrock
    # statement as the data-load role. A separate statement used to be built
    # for it but was never attached (and only named managed policies, not
    # actions), so it has been dropped. Confirm whether the endpoint_list
    # Lambda needs these permissions; if not, remove this grant.
    endpoint_list_role.add_to_policy(model_access_statement)
    for managed_policy_name in (
        "service-role/AWSLambdaBasicExecutionRole",
        "AmazonSageMakerReadOnly",
    ):
        endpoint_list_role.add_managed_policy(
            _iam.ManagedPolicy.from_aws_managed_policy_name(managed_policy_name)
        )

    # No Lambda layer is attached to this function.
    endpoint_list_function = _lambda.Function(
        self, endpoint_list_function_name,
        function_name=endpoint_list_function_name,
        runtime=_lambda.Runtime.PYTHON_3_9,
        role=endpoint_list_role,
        code=_lambda.Code.from_asset('../lambda/' + endpoint_list_function_name),
        handler='lambda_function.lambda_handler',
        timeout=Duration.minutes(10),
        vpc=vpc,
        vpc_subnets=vpc_subnets,
        reserved_concurrent_executions=10,
    )

    self.create_apigw_resource_method_for_endpoint_list(
        api=api,
        endpoint_list_function=endpoint_list_function,
    )

    # ------------------------------------------------------------------
    # 5 (continued). Integration options for the S3 PUT:
    #   - content handling: left unspecified (passthrough by default)
    #   - URL path parameter mapping
    #   - credentials role assumed by API Gateway
    # ------------------------------------------------------------------
    s3_put_integration_options = apigw.IntegrationOptions(
        request_parameters=integration_request_parameters,
        credentials_role=upload_role,
        integration_responses=[
            apigw.IntegrationResponse(
                status_code="200",
                response_parameters={
                    "method.response.header.Access-Control-Allow-Headers": "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token'",
                    "method.response.header.Access-Control-Allow-Origin": "'*'",
                    "method.response.header.Access-Control-Allow-Methods": "'PUT,POST,OPTIONS'",
                },
                response_templates={
                    "application/json": "",
                },
            )
        ],
    )

    s3_put_integration = apigw.AwsIntegration(
        service="s3",
        path="{bucket}/{prefix}/{sub_prefix}/{key}",
        # NOTE(review): this reads the lower-case `region` environment
        # variable, unlike `AWS_REGION` used above. When it is unset, None is
        # passed and CDK presumably falls back to the stack's region --
        # confirm this is intended.
        region=os.getenv('region'),
        integration_http_method="PUT",
        options=s3_put_integration_options,
    )

    filename_resource.add_method(
        http_method='PUT',
        integration=s3_put_integration,
        request_parameters=method_request_parameters,
        method_responses=[
            apigw.MethodResponse(
                status_code="200",
                response_parameters={
                    'method.response.header.Access-Control-Allow-Origin': True,
                    'method.response.header.Access-Control-Allow-Headers': True,
                    'method.response.header.Access-Control-Allow-Methods': True,
                },
            )
        ],
    )

def create_content_moderation_func(self, vpc=None, vpc_subnets=None):
    """Create the ``content_moderation`` Lambda function and its role.

    The moderation API URL is read from the CDK context key
    ``content_moderation_api`` (``content_moderation_api_cn`` when the
    ``AWS_REGION`` environment variable contains ``cn-``) and exported to the
    function as the ``content_moderation_api`` environment variable.

    Args:
        vpc: Optional VPC passed through to the Lambda function.
        vpc_subnets: Optional subnet selection passed through to the Lambda
            function.

    Returns:
        The created ``_lambda.Function``.
    """
    context_key_suffix = "_cn" if 'cn-' in os.getenv('AWS_REGION', '') else ""
    content_moderation_api = self.node.try_get_context(
        f"content_moderation_api{context_key_suffix}"
    )

    # 'lambda:AWSLambdaBasicExecutionRole' used to be listed here, but that is
    # a managed-policy name rather than an IAM action and granted nothing;
    # the managed policy itself is attached to the role below.
    role_statement = _iam.PolicyStatement(
        actions=[
            'secretsmanager:GetSecretValue',
            "logs:CreateLogGroup",
            "logs:CreateLogStream",
            "logs:DescribeLogGroups",
            "logs:DescribeLogStreams",
            "logs:PutLogEvents",
            "logs:GetLogEvents",
            "logs:FilterLogEvents",
        ],
        resources=['*'],
    )

    content_moderation_role = _iam.Role(
        self, 'content_moderation_role',
        assumed_by=_iam.ServicePrincipal('lambda.amazonaws.com'),
    )
    content_moderation_role.add_to_policy(role_statement)
    content_moderation_role.add_managed_policy(
        _iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AWSLambdaBasicExecutionRole")
    )

    # The function's code is loaded from ../lambda/content_moderation.
    function_name = 'content_moderation'
    content_moderation_func = _lambda.Function(
        self, function_name,
        function_name=function_name,
        runtime=_lambda.Runtime.PYTHON_3_9,
        role=content_moderation_role,
        layers=[self.langchain_processor_qa_layer],
        code=_lambda.Code.from_asset('../lambda/' + function_name),
        handler='lambda_function.lambda_handler',
        memory_size=256,
        timeout=Duration.minutes(10),
        vpc=vpc,
        vpc_subnets=vpc_subnets,
        reserved_concurrent_executions=10,
    )
    content_moderation_func.add_environment("content_moderation_api", content_moderation_api)

    return content_moderation_func

def create_apigw_resource_method_for_knowledge_base_handler(self, api, knowledge_base_handler_function):

knowledge_base_handler_resource = api.root.add_resource(
Expand Down
Binary file modified lambda/.DS_Store
Binary file not shown.

0 comments on commit 8b0198f

Please sign in to comment.