Skip to content

Commit

Permalink
feat: adding base function for bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonNotJson committed Oct 22, 2023
1 parent 0aaeaca commit b4370a6
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 4 deletions.
56 changes: 56 additions & 0 deletions lib/constructs/common/lambda-functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1129,3 +1129,59 @@ export class CareerRestFunctions extends Construct {
});
}
}

export class ForumThreadAIFunctions extends Construct {
readonly injectFunction: lambda.Function;

constructor(scope: Construct, id: string, props: FunctionsProps) {
super(scope, id);

const bedrockAccessPolicy = new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: ['bedrock:InvokeModel'],
resources: ['*'],
});

const DBSyncRole: iam.LazyRole = new iam.LazyRole(
this,
'dynamodb-s3-forum-thread-ai-role',
{
assumedBy: new iam.ServicePrincipal(AwsServicePrincipal.LAMBDA),
description:
'Allow lambda function to perform crud operation on dynamodb and s3',
path: `/service-role/${AwsServicePrincipal.LAMBDA}/`,
roleName: 'dynamodb-s3-career-sync-role',
managedPolicies: [
iam.ManagedPolicy.fromManagedPolicyArn(
this,
'basic-exec1',
'arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole',
),
iam.ManagedPolicy.fromManagedPolicyArn(
this,
'db-full-access',
'arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess',
),
iam.ManagedPolicy.fromManagedPolicyArn(
this,
's3-full-access',
'arn:aws:iam::aws:policy/AmazonS3FullAccess',
),
],
},
);
DBSyncRole.addToPolicy(bedrockAccessPolicy);

this.injectFunction = new lambda_py.PythonFunction(this, 'inject-threads', {
entry: 'src/lambda/inject-threads',
description: 'inject ai generated thread data into the database',
functionName: 'inject-thread',
logRetention: logs.RetentionDays.ONE_MONTH,
memorySize: 128,
role: DBSyncRole,
runtime: lambda.Runtime.PYTHON_3_9,
timeout: Duration.seconds(3),
environment: props.envVars,
});
}
}
31 changes: 29 additions & 2 deletions lib/constructs/persistence/data-pipeline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,29 @@ import {
ThreadImageProcessFunctions,
AdsImageProcessFunctionsPipeline,
CareerDBSyncFunction,
ForumThreadAIFunctions,
} from '../common/lambda-functions';

export enum Worker {
SYLLABUS,
CAREER,
FEEDS,
THREADIMG,
ADS, //! New ADS value
ADS,
FORUMAI,
}

export interface DataPipelineProps {
dataSource?: s3.Bucket;
dataWarehouse?: dynamodb.Table;
threadWareHouse?: dynamodb.Table;
commentWareHouse?: dynamodb.Table;
}

export abstract class AbstractDataPipeline extends Construct {
abstract readonly dataSource?: s3.Bucket;
abstract readonly processor: lambda.Function | sfn.StateMachine;
abstract readonly dataWarehouse: s3.Bucket | dynamodb.Table;
abstract readonly dataWarehouse?: s3.Bucket | dynamodb.Table;
}

export class SyllabusDataPipeline extends AbstractDataPipeline {
Expand Down Expand Up @@ -320,3 +324,26 @@ export class AdsDataPipeline extends AbstractDataPipeline {
);
}
}

export class ForumThreadAIDataPipeline extends AbstractDataPipeline {
readonly dataSource?: s3.Bucket;
readonly processor: lambda.Function;
readonly dataWarehouse: dynamodb.Table;
readonly commentWarehouse: dynamodb.Table;

constructor(scope: Construct, id: string, props: DataPipelineProps) {
super(scope, id);

this.dataSource = props.dataSource!;
this.dataWarehouse = props.threadWareHouse!;
this.commentWarehouse = props.commentWareHouse!;

this.processor = new ForumThreadAIFunctions(this, 'career-sync-function', {
envVars: {
['BUCKET_NAME']: this.dataSource.bucketName,
['THREAD_TABLE_NAME']: this.dataWarehouse.tableName,
['COMMENT_TABLE_NAME']: this.commentWarehouse.tableName,
},
}).injectFunction;
}
}
15 changes: 13 additions & 2 deletions lib/stacks/persistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import {
SyllabusDataPipeline,
SyllabusSyncPipeline,
ThreadImgDataPipeline,
AdsDataPipeline, //! New value
AdsDataPipeline,
ForumThreadAIDataPipeline,
Worker,
} from '../constructs/persistence/data-pipeline';
import { Collection, DynamoDatabase } from '../constructs/persistence/database';
Expand Down Expand Up @@ -52,6 +53,17 @@ export class WasedaTimePersistenceLayer extends PersistenceLayer {
);
this.dataPipelines[Worker.THREADIMG] = threadImgDataPipeline;

const forumThreadAIDataPipeline = new ForumThreadAIDataPipeline(
this,
'forum-thread-ai-data-pipeline',
{
dataSource: syllabusDataPipeline.dataWarehouse,
threadWareHouse: dynamoDatabase.tables[Collection.THREAD],
commentWareHouse: dynamoDatabase.tables[Collection.COMMENT],
},
);
this.dataPipelines[Worker.FORUMAI] = forumThreadAIDataPipeline;

//! New pipeline for ads
const adsDataPipeline = new AdsDataPipeline(this, 'ads-data-pipeline', {
dataWarehouse: dynamoDatabase.tables[Collection.ADS],
Expand Down Expand Up @@ -83,7 +95,6 @@ export class WasedaTimePersistenceLayer extends PersistenceLayer {
dynamoDatabase.tables[Collection.COMMENT].tableName,
);

//! new endpoint for ads
this.dataInterface.setEndpoint(
DataEndpoint.ADS,
dynamoDatabase.tables[Collection.ADS].tableName,
Expand Down
2 changes: 2 additions & 0 deletions src/lambda/inject-threads/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def handler(event, context):
pass
93 changes: 93 additions & 0 deletions src/lambda/inject-threads/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import boto3
import json
import logging
import os
from decimal import Decimal

db = boto3.resource("dynamodb", region_name="ap-northeast-1")
table = db.Table(os.getenv('TABLE_NAME'))

s3_client = boto3.client('s3')
bucket = os.getenv('BUCKET_NAME')

bedrock_client = boto3.client('bedrock-runtime', region_name='ap-northeast-1')


class DecimalEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return json.JSONEncoder.default(self, obj)


class ExtendedEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, set):
return list(obj)
return super(ExtendedEncoder, self).default(obj)


class JsonPayloadBuilder:
payload = {}

def add_status(self, success):
self.payload['success'] = success
return self

def add_data(self, data):
self.payload['data'] = data
return self

def add_message(self, msg):
self.payload['message'] = msg
return self

def compile(self):
return json.dumps(self.payload, cls=ExtendedEncoder, ensure_ascii=False).encode('utf8')


def api_response(code, body):
return {
"isBase64Encoded": False,
"statusCode": code,
'headers': {
"Access-Control-Allow-Origin": '*',
"Content-Type": "application/json",
"Referrer-Policy": "origin"
},
"multiValueHeaders": {"Access-Control-Allow-Methods": ["POST", "OPTIONS", "GET", "PATCH", "DELETE"]},
"body": body
}


def resp_handler(func):
def handle(*args, **kwargs):
try:
resp = func(*args, **kwargs)
return api_response(200, resp)
except LookupError:
resp = JsonPayloadBuilder().add_status(False).add_data(None) \
.add_message("Not found").compile()
return api_response(404, resp)
except Exception as e:
logging.error(str(e))
resp = JsonPayloadBuilder().add_status(False).add_data(None) \
.add_message("Internal error, please contact [email protected].").compile()
return api_response(500, resp)

return handle


def generate_url(bucket_name, object_key, expiration=3600):
try:
response = s3_client.generate_presigned_url('get_object',
Params={'Bucket': bucket_name,
'Key': object_key},
ExpiresIn=expiration)
except Exception as e:
logging.error(str(e))
return None

return response

0 comments on commit b4370a6

Please sign in to comment.