diff --git a/lib/constructs/business/api-endpoint.ts b/lib/constructs/business/api-endpoint.ts index 1af9f7bce..12323a27a 100644 --- a/lib/constructs/business/api-endpoint.ts +++ b/lib/constructs/business/api-endpoint.ts @@ -15,7 +15,12 @@ import { API_DOMAIN } from '../../configs/route53/domain'; import { GraphqlApiService } from './graphql-api-service'; import { AbstractHttpApiService } from './http-api-service'; import { RestApiService } from './rest-api-service'; -import { GraphqlApiServiceId, graphqlApiServiceMap, RestApiServiceId, restApiServiceMap } from './service'; +import { + GraphqlApiServiceId, + graphqlApiServiceMap, + RestApiServiceId, + restApiServiceMap, +} from './service'; export interface ApiEndpointProps { zone: route53.IHostedZone; @@ -23,9 +28,18 @@ export interface ApiEndpointProps { } export abstract class AbstractApiEndpoint extends Construct { - abstract readonly apiEndpoint: apigw.RestApi | apigw.LambdaRestApi | apigw.SpecRestApi | apigw2.HttpApi | appsync.GraphqlApi; + abstract readonly apiEndpoint: + | apigw.RestApi + | apigw.LambdaRestApi + | apigw.SpecRestApi + | apigw2.HttpApi + | appsync.GraphqlApi; - protected constructor(scope: Construct, id: string, props?: ApiEndpointProps) { + protected constructor( + scope: Construct, + id: string, + props?: ApiEndpointProps, + ) { super(scope, id); } } @@ -44,7 +58,8 @@ export abstract class AbstractRestApiEndpoint extends AbstractApiEndpoint { } public getDomain(): string { - const domainName: apigw.DomainName | undefined = this.apiEndpoint.domainName; + const domainName: apigw.DomainName | undefined = + this.apiEndpoint.domainName; if (typeof domainName === 'undefined') { throw RangeError('Domain not configured for this API endpoint.'); @@ -52,8 +67,12 @@ export abstract class AbstractRestApiEndpoint extends AbstractApiEndpoint { return domainName.domainName; } - public addService(name: RestApiServiceId, dataSource?: string, auth = false): this { - this.apiServices[name] = new restApiServiceMap[name](this, `${ name }-api`, { + public addService( + name: RestApiServiceId, + dataSource?: string, + auth = false, + ): this { + this.apiServices[name] = new restApiServiceMap[name](this, `${name}-api`, { dataSource: dataSource, authorizer: auth ? this.authorizer : undefined, validator: this.reqValidator, @@ -61,7 +80,7 @@ export abstract class AbstractRestApiEndpoint extends AbstractApiEndpoint { return this; } - public abstract deploy(): void + public abstract deploy(): void; } export abstract class AbstractGraphqlEndpoint extends AbstractApiEndpoint { @@ -75,16 +94,30 @@ export abstract class AbstractGraphqlEndpoint extends AbstractApiEndpoint { super(scope, id, props); } - public addService(name: GraphqlApiServiceId, dataSource: string, auth = 'apiKey'): this { - this.apiServices[name] = new graphqlApiServiceMap[name](this, `${ name }-api`, { - dataSource: dynamodb.Table.fromTableName(this, `${ name }-table`, dataSource), - auth: this.authMode[auth], - }); + public addService( + name: GraphqlApiServiceId, + dataSource: string, + auth = 'apiKey', + ): this { + this.apiServices[name] = new graphqlApiServiceMap[name]( + this, + `${name}-api`, + { + dataSource: dynamodb.Table.fromTableName( + this, + `${name}-table`, + dataSource, + ), + auth: this.authMode[auth], + }, + ); return this; } public getDomain(): string { - const domain = this.apiEndpoint.graphqlUrl.match(/https:\/\/(.*)\/graphql/g); + const domain = this.apiEndpoint.graphqlUrl.match( + /https:\/\/(.*)\/graphql/g, + ); if (domain === null) { return ''; } @@ -127,7 +160,12 @@ export class WasedaTimeRestApiEndpoint extends AbstractRestApiEndpoint { description: 'The main API endpoint for WasedaTime Web App.', endpointTypes: [apigw.EndpointType.REGIONAL], deploy: false, - binaryMediaTypes: ['application/pdf', 'image/png'], + binaryMediaTypes: [ + 'application/pdf', + 'image/png', + 'image/jpeg', + 'image/gif', + ], }); this.apiEndpoint.addGatewayResponse('4xx-resp', { type: apigw.ResponseType.DEFAULT_4XX, @@ -171,7 +209,9 @@ export class WasedaTimeRestApiEndpoint extends AbstractRestApiEndpoint { }); new route53.ARecord(this, 'alias-record', { zone: props.zone, - target: route53.RecordTarget.fromAlias(new route53_targets.ApiGatewayDomain(this.domain)), + target: route53.RecordTarget.fromAlias( + new route53_targets.ApiGatewayDomain(this.domain), + ), recordName: API_DOMAIN, }); } @@ -186,7 +226,10 @@ export class WasedaTimeRestApiEndpoint extends AbstractRestApiEndpoint { api: this.apiEndpoint, retainDeployments: false, }); - const hash = Buffer.from(flatted.stringify(this.apiServices), 'binary').toString('base64'); + const hash = Buffer.from( + flatted.stringify(this.apiServices), + 'binary', + ).toString('base64'); if (STAGE === 'dev') { devDeployment.addToLogicalId(hash); } else if (STAGE === 'prod') { @@ -232,7 +275,6 @@ export class WasedaTimeGraphqlEndpoint extends AbstractGraphqlEndpoint { readonly apiServices: { [name: string]: GraphqlApiService } = {}; constructor(scope: Construct, id: string, props: ApiEndpointProps) { - super(scope, id, props); const apiKeyAuth: appsync.AuthorizationMode = { diff --git a/lib/constructs/business/rest-api-service.ts b/lib/constructs/business/rest-api-service.ts index a4ea17539..7d336d9f7 100644 --- a/lib/constructs/business/rest-api-service.ts +++ b/lib/constructs/business/rest-api-service.ts @@ -717,6 +717,7 @@ export class ForumThreadsApiService extends RestApiService { const boardResource = root.addResource('{board_id}'); const threadResource = boardResource.addResource('{thread_id}'); const userResource = root.addResource('user'); + const testResource = root.addResource('test'); const optionsForumHome = root.addCorsPreflight({ allowOrigins: allowOrigins, @@ -766,6 +767,18 @@ export class ForumThreadsApiService extends RestApiService { ], }); + const optionsTestThreads = testResource.addCorsPreflight({ + allowOrigins: allowOrigins, + allowHeaders: allowHeaders, + allowMethods: [ + apigw2.HttpMethod.GET, + apigw2.HttpMethod.POST, + apigw2.HttpMethod.PATCH, + apigw2.HttpMethod.DELETE, + apigw2.HttpMethod.OPTIONS, + ], + }); + const getRespModel = scope.apiEndpoint.addModel('threads-get-resp-model', { schema: forumThreadGetRespSchema, contentType: 'application/json', @@ -820,6 +833,10 @@ export class ForumThreadsApiService extends RestApiService { forumThreadsFunctions.deleteFunction, { proxy: true }, ); + const testPostIntegration = new apigw.LambdaIntegration( + forumThreadsFunctions.testPostFunction, + { proxy: true }, + ); const getAllForumThreads = root.addMethod( apigw2.HttpMethod.GET, @@ -919,13 +936,28 @@ export class ForumThreadsApiService extends RestApiService { requestValidator: props.validator, }, ); + const testPostForumThreads = threadResource.addMethod( + apigw2.HttpMethod.POST, + testPostIntegration, + { + operationName: 'testThread', + methodResponses: [ + { + statusCode: '200', + responseParameters: lambdaRespParams, + }, + ], + authorizer: props.authorizer, + requestValidator: props.validator, + }, + ); this.resourceMapping = { '/forum': { [apigw2.HttpMethod.GET]: getAllForumThreads, [apigw2.HttpMethod.OPTIONS]: optionsForumHome, }, - '/forum/{uid}': { + '/forum/user': { [apigw2.HttpMethod.GET]: getUserForumThreads, [apigw2.HttpMethod.OPTIONS]: optionsUserThreads, }, @@ -939,6 +971,10 @@ export class ForumThreadsApiService extends RestApiService { [apigw2.HttpMethod.PATCH]: patchForumThreads, [apigw2.HttpMethod.DELETE]: deleteForumThreads, }, + '/forum/test': { + [apigw2.HttpMethod.POST]: testPostForumThreads, + [apigw2.HttpMethod.OPTIONS]: optionsTestThreads, + }, }; } } diff --git a/lib/constructs/common/lambda-functions.ts b/lib/constructs/common/lambda-functions.ts index f7917f010..4a27ef7b5 100644 --- a/lib/constructs/common/lambda-functions.ts +++ b/lib/constructs/common/lambda-functions.ts @@ -491,6 +491,7 @@ export class ForumThreadFunctions extends Construct { readonly postFunction: lambda.Function; readonly patchFunction: lambda.Function; readonly deleteFunction: lambda.Function; + readonly testPostFunction: lambda.Function; constructor(scope: Construct, id: string, props: FunctionsProps) { super(scope, id); @@ -642,6 +643,22 @@ export class ForumThreadFunctions extends Construct { timeout: Duration.seconds(3), environment: props.envVars, }); + + this.testPostFunction = new lambda_py.PythonFunction( + this, + 'test-post-thread', + { + entry: 'src/lambda/test-post-thread', + description: 'lambda to test forum functionalities', + functionName: 'test-forum-thread', + logRetention: logs.RetentionDays.ONE_MONTH, + memorySize: 128, + role: DBPutRole, + runtime: lambda.Runtime.PYTHON_3_9, + timeout: Duration.seconds(3), + environment: props.envVars, + }, + ); } } diff --git a/src/lambda/test-post-thread/index.py b/src/lambda/test-post-thread/index.py new file mode 100644 index 000000000..cd3be3447 --- /dev/null +++ b/src/lambda/test-post-thread/index.py @@ -0,0 +1,68 @@ +from boto3.dynamodb.conditions import Key +import json +from datetime import datetime +from utils import JsonPayloadBuilder, table, resp_handler, build_thread_id, s3_client, bucket +import uuid +import base64 + + +@resp_handler +def test_post_thread(thread, uid): + + thread_id = build_thread_id() + + text = thread["body"] + + dt_now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' + + object_key = None + if "image" in thread: + image_data = base64.b64decode(thread["image"]) + content_type = thread.get("contentType", "image/jpeg") + # Validate the content type + if content_type not in ["image/jpeg", "image/png", "image/gif"]: + raise ValueError("Invalid content type") + # Extracts 'jpeg', 'png', or 'gif' from the MIME type + extension = content_type.split("/")[-1] + object_key = f"{thread_id}/image.{extension}" + + s3_client.put_object(Bucket=bucket, Key=object_key, + Body=image_data, ContentType=content_type) + + thread_item = { + "board_id": thread["board_id"], + "created_at": dt_now, + "updated_at": dt_now, + "title": thread["title"], + "body": text, + "uid": uid, + "thread_id": thread_id, + "tag_id": thread["tag_id"], + "group_id": thread["group_id"], + "univ_id": thread["univ_id"], + "views": 0, + "comment_count": 0, + "new_comment": False, + "obj_key": object_key, + } + + table.put_item(Item=thread_item) + + thread_item.pop('uid', None) + thread_item["mod"] = True + + body = JsonPayloadBuilder().add_status( + True).add_data(thread_item).add_message('').compile() + return body + + +def handler(event, context): + + req = json.loads(event['body']) + params = { + # "board_id": event["pathParameters"]["board_id"], + "thread": req["data"], + "uid": event['requestContext']['authorizer']['claims']['sub'] + } + + return test_post_thread(**params) diff --git a/src/lambda/test-post-thread/utils.py b/src/lambda/test-post-thread/utils.py new file mode 100644 index 000000000..0bfe58ae6 --- /dev/null +++ b/src/lambda/test-post-thread/utils.py @@ -0,0 +1,84 @@ +import base64 +import boto3 +import json +import logging +import os +from decimal import Decimal +from datetime import datetime +import uuid + +# AWS DynamoDB Resources +db = boto3.resource("dynamodb", region_name="ap-northeast-1") +table = db.Table(os.getenv('TABLE_NAME')) + +s3_client = boto3.client('s3') +bucket = os.getenv('BUCKET_NAME') + + +class DecimalEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Decimal): + return float(obj) + return json.JSONEncoder.default(self, obj) + + +class JsonPayloadBuilder: + payload = {} + + def add_status(self, success): + self.payload['success'] = success + return self + + def add_data(self, data): + self.payload['data'] = data + return self + + def add_message(self, msg): + self.payload['message'] = msg + return self + + def compile(self): + return json.dumps(self.payload, cls=DecimalEncoder, ensure_ascii=False).encode('utf8') + + +def api_response(code, body): + return { + "isBase64Encoded": False, + "statusCode": code, + 'headers': { + "Access-Control-Allow-Origin": '*', + "Content-Type": "application/json", + "Referrer-Policy": "origin" + }, + "multiValueHeaders": {"Access-Control-Allow-Methods": ["POST", "OPTIONS", "GET", "PATCH", "DELETE"]}, + "body": body + } + + +def resp_handler(func=None, headers=None): + def handle(*args, **kwargs): + try: + resp = func(*args, **kwargs) + return api_response(200, resp) + except LookupError: + resp = JsonPayloadBuilder().add_status(False).add_data(None) \ + .add_message("Not found").compile() + return api_response(404, resp) + except Exception as e: + logging.error(str(e)) + resp = JsonPayloadBuilder().add_status(False).add_data(None) \ + .add_message("Internal error, please contact bugs@wasedatime.com.").compile() + return api_response(500, resp) + + return handle + + +def build_thread_id(): + + unique_id = str(uuid.uuid4()) + + ts = datetime.now().strftime('%Y%m%d%H%M%S') + + thread_id = f"{ts}_{unique_id}" + + return thread_id