diff --git a/lambda/orc_ingestion.py b/lambda/orc_ingestion.py index 49b9ff0d..ea089726 100644 --- a/lambda/orc_ingestion.py +++ b/lambda/orc_ingestion.py @@ -1,8 +1,11 @@ import json import logging import os +import re import uuid + from datetime import datetime, timezone +from typing import List import boto3 @@ -13,83 +16,88 @@ s3_client = boto3.client("s3") -def lambda_handler(event, context): - """ - When using this lambda, please include previous_ods_code in the event json. - Example: - {"previous_ods_code": "ods_code"} - """ - try: - previous_ods_code = event["previous_ods_code"] - except KeyError as e: - return {"statusCode": 400, "error": "missing param 'previous_ods_code'"} - - file_to_ingest = os.environ["INGEST_FILE_NAME"] # this env var is set as "Patient-List-Test" in terraform +class DuplicateNhsNumberException(Exception): + pass - ingestion_bucket_name = os.environ["S3_BUCKET_NAME"] - suspension_queue_url = os.environ["SUSPENSION_QUEUE_URL"] - nhs_number_list: list[str] = get_nhs_number_list_from_s3(filename=file_to_ingest, bucket_name=ingestion_bucket_name) - all_trace_ids = [] - - for nhs_number in nhs_number_list: - trace_id = new_uuid() - all_trace_ids.append(trace_id) - logger.info(f"Assigned a new traceId: {trace_id}") - - suspension_message_json = build_suspension_message( - nhs_number=nhs_number, - previous_ods_code=previous_ods_code, - nems_message_id=trace_id, - ) - send_message_with_trace_id( - message_body=suspension_message_json, - queue_url=suspension_queue_url, - trace_id=trace_id, - ) - logger.info( - f"sent message for a patient to queue with traceId: {trace_id}" - ) +class InvalidFileFormatException(Exception): + pass - logger.info("Here are all the trace ids that related to this ingest:") - logger.info(str(all_trace_ids)) - return {"statusCode": 200, "body": str(all_trace_ids)} +class InvalidNhsNumberException(Exception): + pass -def get_nhs_number_list_from_s3(filename: str, bucket_name: str) -> list: - local_file_path = f"/tmp/{filename}" - s3_client.download_file( - bucket_name, filename, local_file_path - ) - with open(local_file_path, "r") as f: - nhs_number_list = f.read() - os.remove(local_file_path) - return nhs_number_list.split(',') +class NHSNumberValidator: + @staticmethod + def validate(nhs_numbers: str) -> List[str]: + if not bool(re.match('^[0-9,]+$', nhs_numbers)): + raise InvalidFileFormatException('File should only contain numbers and commas.') + + nhs_number_list = nhs_numbers.split(',') + if not all(len(x) == 10 for x in nhs_number_list): + raise InvalidNhsNumberException('All NHS numbers must be 10 digits long and there must be no trailing commas.') + + if len(nhs_number_list) != len(set(nhs_number_list)): + raise DuplicateNhsNumberException('Duplicate NHS numbers found.') + + return nhs_number_list -def new_uuid() -> str: - return str(uuid.uuid4()) +class S3FileReader: + @staticmethod + def read(filename, bucket_name) -> str: + response = s3_client.get_object(Bucket=bucket_name, Key=filename) + file_lines = response["Body"].readlines() + if len(file_lines) > 1: + raise InvalidFileFormatException('All NHS numbers must be contained on a single line, separated by commas.') -def get_timestamp(): - return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + return file_lines[0].decode('utf-8') -def build_suspension_message( - nhs_number: str, previous_ods_code: str, nems_message_id: str -) -> str: - return json.dumps({ +def process_nhs_number(nhs_number, previous_ods_code, suspension_queue_url) -> str: + trace_id = str(uuid.uuid4()) + logger.info(f"Assigned a new traceId: {trace_id}") + + suspension_message_json = json.dumps({ "nhsNumber": nhs_number, - "lastUpdated": get_timestamp(), + "lastUpdated": datetime.now(timezone.utc).replace(microsecond=0).isoformat(), "previousOdsCode": previous_ods_code, - "nemsMessageId": nems_message_id, + "nemsMessageId": trace_id, }) - -def send_message_with_trace_id(message_body: str, queue_url: str, trace_id: str): - return sqs_client.send_message( - QueueUrl=queue_url, - MessageAttributes={"traceId": {"DataType": "String", "StringValue": trace_id}}, - MessageBody=message_body, + sqs_client.send_message( + QueueUrl=suspension_queue_url, + MessageAttributes={ + "traceId": { + "DataType": "String", + "StringValue": trace_id + } + }, + MessageBody=suspension_message_json, ) + logger.info(f"Sent message for a patient to queue with traceId: {trace_id}") + + return trace_id + + +def lambda_handler(event, context) -> dict: + try: + previous_ods_code = event["previous_ods_code"] + except KeyError: + return {"statusCode": 400, "error": "missing attribute 'previous_ods_code'"} + + file_to_ingest = os.environ["INGEST_FILE_NAME"] + ingestion_bucket_name = os.environ["S3_BUCKET_NAME"] + suspension_queue_url = os.environ["SUSPENSION_QUEUE_URL"] + + nhs_numbers = S3FileReader.read(filename=file_to_ingest, bucket_name=ingestion_bucket_name) + nhs_number_list = NHSNumberValidator.validate(nhs_numbers) + + all_trace_ids = [process_nhs_number(nhs_number, previous_ods_code, suspension_queue_url) for nhs_number in nhs_number_list] + + logger.info("Here are all the trace ids that related to this ingest:") + logger.info(str(all_trace_ids)) + + return {"statusCode": 200, "body": str(all_trace_ids)}