diff --git a/lib/constructs/common/lambda-functions.ts b/lib/constructs/common/lambda-functions.ts index c9be9cf18..affe8d4dd 100644 --- a/lib/constructs/common/lambda-functions.ts +++ b/lib/constructs/common/lambda-functions.ts @@ -1197,3 +1197,75 @@ export class ForumThreadAIFunctions extends Construct { }); } } + +export class ForumCommentAIFunctions extends Construct { + readonly injectFunction: lambda.Function; + + constructor(scope: Construct, id: string, props: FunctionsProps) { + super(scope, id); + + const latestBoto3Layer = new lambda_py.PythonLayerVersion( + this, + 'Boto3PythonLayerVersion', + { + entry: 'lib/configs/lambda/python_packages', + compatibleRuntimes: [lambda.Runtime.PYTHON_3_9], + layerVersionName: 'latest-boto3-python-layer', + description: 'Layer containing updated boto3 and botocore', + }, + ); + + const bedrockAccessPolicy = new iam.PolicyStatement({ + effect: iam.Effect.ALLOW, + actions: ['bedrock:InvokeModel'], + resources: ['*'], + }); + + const DBSyncRole: iam.LazyRole = new iam.LazyRole( + this, + 'dynamodb-s3-forum-comment-ai-role', + { + assumedBy: new iam.ServicePrincipal(AwsServicePrincipal.LAMBDA), + description: + 'Allow lambda function to perform crud operation on dynamodb and s3', + path: `/service-role/${AwsServicePrincipal.LAMBDA}/`, + roleName: 'dynamodb-s3-forum-thread-ai-role', + managedPolicies: [ + iam.ManagedPolicy.fromManagedPolicyArn( + this, + 'basic-exec1', + 'arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole', + ), + iam.ManagedPolicy.fromManagedPolicyArn( + this, + 'db-full-access', + 'arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess', + ), + iam.ManagedPolicy.fromManagedPolicyArn( + this, + 's3-full-access', + 'arn:aws:iam::aws:policy/AmazonS3FullAccess', + ), + ], + }, + ); + DBSyncRole.addToPolicy(bedrockAccessPolicy); + + this.injectFunction = new lambda_py.PythonFunction( + this, + 'inject-comments', + { + entry: 'src/lambda/inject-comments', + description: 'inject ai generated comment data into the database', + functionName: 'inject-comment', + logRetention: logs.RetentionDays.ONE_MONTH, + memorySize: 128, + role: DBSyncRole, + runtime: lambda.Runtime.PYTHON_3_9, + timeout: Duration.seconds(60), + environment: props.envVars, + layers: [latestBoto3Layer], + }, + ); + } +} diff --git a/lib/constructs/persistence/data-pipeline.ts b/lib/constructs/persistence/data-pipeline.ts index 389660b4e..9e994e8dd 100644 --- a/lib/constructs/persistence/data-pipeline.ts +++ b/lib/constructs/persistence/data-pipeline.ts @@ -28,6 +28,7 @@ export enum Worker { THREADIMG, ADS, FORUMAI, + COMMENTAI, } export interface DataPipelineProps { @@ -355,3 +356,33 @@ export class ForumThreadAIDataPipeline extends AbstractDataPipeline { ).injectFunction; } } + +export class ForumCommentAIDataPipeline extends AbstractDataPipeline { + readonly dataSource?: s3.Bucket; + readonly processor: lambda.Function; + readonly dataWarehouse: dynamodb.Table; + readonly commentWarehouse: dynamodb.Table; + + constructor(scope: Construct, id: string, props: DataPipelineProps) { + super(scope, id); + + // this.dataSource = props.dataSource!; + this.dataWarehouse = props.threadWareHouse!; + this.commentWarehouse = props.commentWareHouse!; + + const UID = process.env.UID!; + + this.processor = new ForumThreadAIFunctions( + this, + 'forum-thread-ai-function', + { + envVars: { + // ['BUCKET_NAME']: this.dataSource.bucketName, + ['THREAD_TABLE_NAME']: this.dataWarehouse.tableName, + ['COMMENT_TABLE_NAME']: this.commentWarehouse.tableName, + ['UID']: UID, + }, + }, + ).injectFunction; + } +} diff --git a/lib/stacks/persistence.ts b/lib/stacks/persistence.ts index acfd5607b..61802bef5 100644 --- a/lib/stacks/persistence.ts +++ b/lib/stacks/persistence.ts @@ -10,6 +10,7 @@ import { ThreadImgDataPipeline, AdsDataPipeline, ForumThreadAIDataPipeline, + ForumCommentAIDataPipeline, Worker, } from '../constructs/persistence/data-pipeline'; import { Collection, DynamoDatabase } from '../constructs/persistence/database'; @@ -64,6 +65,17 @@ export class WasedaTimePersistenceLayer extends PersistenceLayer { ); this.dataPipelines[Worker.FORUMAI] = forumThreadAIDataPipeline; + const forumCommentAIDataPipeline = new ForumCommentAIDataPipeline( + this, + 'forum-comment-ai-data-pipeline', // Error fixed + { + // dataSource: syllabusDataPipeline.dataWarehouse, + threadWareHouse: dynamoDatabase.tables[Collection.THREAD], + commentWareHouse: dynamoDatabase.tables[Collection.COMMENT], + }, + ); + this.dataPipelines[Worker.COMMENTAI] = forumCommentAIDataPipeline; + //! New pipeline for ads const adsDataPipeline = new AdsDataPipeline(this, 'ads-data-pipeline', { dataWarehouse: dynamoDatabase.tables[Collection.ADS], diff --git a/src/lambda/get-ads/index.py b/src/lambda/get-ads/index.py index 0d95a0996..b24100738 100644 --- a/src/lambda/get-ads/index.py +++ b/src/lambda/get-ads/index.py @@ -20,7 +20,7 @@ def get_imgs_list(board_id, ad_id): # If the count propperty doesn't exist yet, set to 1, if existed increase by 1 table.update_item( - Key={ + Key={ "board_id": board_id, "ad_id": ad_id, }, diff --git a/src/lambda/inject-comments/index.py b/src/lambda/inject-comments/index.py new file mode 100644 index 000000000..5294e37ea --- /dev/null +++ b/src/lambda/inject-comments/index.py @@ -0,0 +1,34 @@ +from datetime import datetime +from utils import JsonPayloadBuilder, resp_handler, UID, table_comment, UNIV_ID, get_bedrock_response + + +@resp_handler +def inject_comment(content, thread_id): + + dt_now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' + + comment_item = { + "thread_id": thread_id, + "created_at": dt_now, + "updated_at": dt_now, + "body": content, + "uid": UID + } + + table_comment.put_item(Item=comment_item) + + body = JsonPayloadBuilder().add_status( + True).add_data('').add_message('').compile() + return body + + +def handler(event, context): + resp = get_bedrock_response() + + content, thread_id = resp + + if resp is None: + # No threads were found; end the Lambda function. + return + + return inject_comment(content, thread_id) diff --git a/src/lambda/inject-comments/utils.py b/src/lambda/inject-comments/utils.py new file mode 100644 index 000000000..8496351c3 --- /dev/null +++ b/src/lambda/inject-comments/utils.py @@ -0,0 +1,211 @@ +from concurrent.futures import thread +import boto3 +import json +import logging +import os +from decimal import Decimal +from datetime import datetime +import uuid +from boto3.dynamodb.conditions import Key +import random +import re + +db = boto3.resource("dynamodb", region_name="ap-northeast-1") +table_thread = db.Table(os.getenv('THREAD_TABLE_NAME')) +table_comment = db.Table(os.getenv('COMMENT_TABLE_NAME')) + +UID = os.getenv('UID') + +UNIV_ID = "1" + +bedrock_client = boto3.client('bedrock-runtime', region_name='ap-northeast-1') + + +class DecimalEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Decimal): + return float(obj) + return json.JSONEncoder.default(self, obj) + + +class ExtendedEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Decimal): + return float(obj) + if isinstance(obj, set): + return list(obj) + return super(ExtendedEncoder, self).default(obj) + + +class JsonPayloadBuilder: + payload = {} + + def add_status(self, success): + self.payload['success'] = success + return self + + def add_data(self, data): + self.payload['data'] = data + return self + + def add_message(self, msg): + self.payload['message'] = msg + return self + + def compile(self): + return json.dumps(self.payload, cls=ExtendedEncoder, ensure_ascii=False).encode('utf8') + + +def api_response(code, body): + return { + "isBase64Encoded": False, + "statusCode": code, + 'headers': { + "Access-Control-Allow-Origin": '*', + "Content-Type": "application/json", + "Referrer-Policy": "origin" + }, + "multiValueHeaders": {"Access-Control-Allow-Methods": ["POST", "OPTIONS", "GET", "PATCH", "DELETE"]}, + "body": body + } + + +def resp_handler(func): + def handle(*args, **kwargs): + try: + resp = func(*args, **kwargs) + return api_response(200, resp) + except LookupError: + resp = JsonPayloadBuilder().add_status(False).add_data(None) \ + .add_message("Not found").compile() + return api_response(404, resp) + except Exception as e: + logging.error(str(e)) + resp = JsonPayloadBuilder().add_status(False).add_data(None) \ + .add_message("Internal error, please contact bugs@wasedatime.com.").compile() + return api_response(500, resp) + + return handle + + +def fetch_top_thread(): + """Fetch the latest thread in the databse and return its body and thread_id + + :return: Body and thread_id of the latest thread + """ + response = table_thread.scan( + Limit=1, + Select='ALL_ATTRIBUTES', + ScanIndexForward=False + ) + + items: list = response['Items'] + + item = items[0] + + body = item['body'] + thread_id = item['thread_id'] + + logging.info(f"Return value: {body}, {thread_id}") + + return body, thread_id + + +def fetch_comments(thread_id): + """Use the thread_id to fetch all its comments from database + + :param thread_id: thread_id of the latest thread + :return: All comments of the latest thread + """ + response = table_comment.query( + KeyConditionExpression = Key("thread_id").eq(thread_id) + ) + + items: list = response['Items'] + + all_comment_body = [] + + for item in items: + all_comment_body += item["body"] + + + logging.info(f"comments : {all_comment_body}") + + return json.dumps(all_comment_body) + + +def generate_prompt(): + body_thread_id = fetch_top_thread() + + thread_body, thread_id = body_thread_id + thread_body = json.dumps(thread_body) + + comments = fetch_comments(thread_id) + + prompt_recent_threads = f'''\n\nHuman: + Use the following example threads as your reference for topics and writing style of the students : original thread={thread_body}, comments={comments} + You are a helpful international university student who is active in an online university forum. + Generate 1 new comment for the thread after reading the original thread and comments. + Ensure: + - Do not repeat the examples. + - Do not make any offers. + - Respond strictly in format TOPIC: CONTENT + Assistant: + ''' + + logging.info(f"Chosen Prompt : {prompt_recent_threads}") + + return prompt_recent_threads, thread_id + + +def get_bedrock_response(): + + prompt, thread_id = generate_prompt() + + modelId = 'anthropic.claude-instant-v1' + accept = 'application/json' + contentType = 'application/json' + + body = json.dumps({ + "prompt": prompt, + "max_tokens_to_sample": 2000, + "temperature": 0.8 + }) + + response = bedrock_client.invoke_model( + modelId=modelId, + accept=accept, + contentType=contentType, + body=body + ) + + response_body = json.loads(response.get('body').read()) + + completion: dict = response_body.get('completion') + + return completion, thread_id + + +# def select_comment(): +# completion, thread_id = get_bedrock_response() + +# pattern = re.compile( +# r"(Academic|Job|Life|WTF):([\s\S]*?)(?=(Academic|Job|Life|WTF):|$)", re.IGNORECASE) + +# matches = pattern.findall(completion) + +# forum_posts = [{"topic": match[0], "content": match[1].strip()} +# for match in matches] + +# for post in forum_posts: +# post['topic'] = post['topic'].lower() + +# try: +# selected_thread = random.choice(forum_posts) +# logging.info(selected_thread) +# return selected_thread, thread_id +# except IndexError: +# logging.warning("LLM anomaly: No matching threads found.") +# return None + + s \ No newline at end of file