diff --git a/application/fhir_logging_client/service.py b/application/fhir_logging_client/service.py index ecf9a7b..b370e08 100644 --- a/application/fhir_logging_client/service.py +++ b/application/fhir_logging_client/service.py @@ -1,4 +1,3 @@ -import json import logging from datetime import datetime @@ -6,7 +5,8 @@ from flask import current_app from fhir.resources.auditevent import AuditEvent -from application.oauth_server.service import TokenService, get_timestamp_now +from application.oauth_server.service import TokenService +from application.utils import new_trace_headers logger = logging.getLogger('fhir_logging_service') logger.setLevel(logging.DEBUG) @@ -15,18 +15,20 @@ class FhirLoggingService: @staticmethod - def register_idp_interaction(entity_what_reference: str): + def register_idp_interaction(entity_what_reference: str, trace_headers: dict): logger.info(f"Registering idp interaction for entity: [{entity_what_reference}]") - audit_event = FhirLoggingService._get_audit_event(entity_what_reference) + audit_event = FhirLoggingService._get_audit_event(entity_what_reference, trace_headers) access_token = token_service.get_system_access_token() endpoint = f'{current_app.config["FHIR_CLIENT_SERVERURL"]}/AuditEvent' logger.info(f"About to submit AuditEvent to endpoint [{endpoint}]") - audit_event_json = json.dumps(audit_event) logger.info(f"generated audit event json: {audit_event}") - response = requests.post(endpoint, json=audit_event, headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/fhir+json;charset=utf-8"}) + headers = new_trace_headers(trace_headers, + {"Authorization": f"Bearer {access_token}", + "Content-Type": "application/fhir+json;charset=utf-8"}) + response = requests.post(endpoint, json=audit_event, headers=headers) if response.ok: logger.info(f"Audit event created successfully with code [{response.status_code}]") @@ -36,12 +38,28 @@ def register_idp_interaction(entity_what_reference: str): return response @staticmethod - def _get_audit_event(entity_what_reference: str): + def _get_audit_event(entity_what_reference: str, trace_headers: dict): entity_type = entity_what_reference.split("/")[0] if entity_type != "Patient" and entity_type != "Practitioner": raise Exception(f"Cannot log IDP interaction - Entity type must be Patient or Practitioner. Got [{entity_type}] instead.") + extension_ = [] + if 'X-Request-Id' in trace_headers: + extension_.append({ + "url": "http://koppeltaal.nl/fhir/StructureDefinition/request-id", + "valueId": trace_headers['X-Request-Id'] + }) + if 'X-Correlation-Id' in trace_headers: + extension_.append({ + "url": "http://koppeltaal.nl/fhir/StructureDefinition/correlation-id", + "valueId": trace_headers['X-Correlation-Id'] + }) + if 'X-Trace-Id' in trace_headers: + extension_.append({ + "url": "http://koppeltaal.nl/fhir/StructureDefinition/trace-id", + "valueId": trace_headers['X-Trace-Id'] + }) data = { "resourceType": "AuditEvent", "meta": { @@ -49,6 +67,7 @@ def _get_audit_event(entity_what_reference: str): "http://koppeltaal.nl/fhir/StructureDefinition/KT2AuditEvent" ] }, + "extension": extension_, "type": { "system": "http://dicom.nema.org/resources/ontology/DCM", "code": "110114", diff --git a/application/idp_client/service.py b/application/idp_client/service.py index a3aa96b..4b8f9b8 100644 --- a/application/idp_client/service.py +++ b/application/idp_client/service.py @@ -1,6 +1,7 @@ import logging from typing import Tuple from urllib.parse import urlencode +from uuid import uuid4 import jwt as pyjwt import requests @@ -9,6 +10,7 @@ from application.fhir_logging_client.service import fhir_logging_service from application.oauth_server.model import Oauth2Session, IdentityProvider from application.oauth_server.service import token_service +from application.utils import new_trace_headers logger = logging.getLogger('idp_service') logger.setLevel(logging.DEBUG) @@ -24,6 +26,8 @@ def consume_idp_code(self) -> Tuple[str, int]: user_claim = "email" state = request.values.get('state') + trace_headers = self._get_trace_headers() + if not state: logger.error('No state found on the authentication response') return 'Bad request, no state found on the authentication response', 400 @@ -36,6 +40,9 @@ def consume_idp_code(self) -> Tuple[str, int]: hti_launch_token = pyjwt.decode(oauth2_session.launch, options={"verify_signature": False}) logger.info(f'[{oauth2_session.id}] Consuming idp oidc code for user {hti_launch_token["sub"]}') + if 'X-Trace-Id' not in trace_headers: + trace_headers['X-Trace-Id'] = hti_launch_token['jti'] # The JTI token is the trace id if not set + code = request.values.get('code') if not code: logger.error(f'[{oauth2_session.id}] no code parameter found') @@ -68,7 +75,9 @@ def consume_idp_code(self) -> Tuple[str, int]: # get the user from the FHIR server, to verify if the Patient has this email set as an identifier access_token = token_service.get_system_access_token() - user_response = requests.get(f'{current_app.config["FHIR_CLIENT_SERVERURL"]}/{hti_launch_token["sub"]}', headers={"Authorization": "Bearer " + access_token}) + headers = new_trace_headers(trace_headers, {"Authorization": "Bearer " + access_token}) + + user_response = requests.get(f'{current_app.config["FHIR_CLIENT_SERVERURL"]}/{hti_launch_token["sub"]}', headers=headers) if not user_response.ok: logger.error(f'Failed to fetch user {hti_launch_token["sub"]} with error code [{user_response.status_code}] and message: \n{user_response.reason}') return 'Bad request, user could not be fetched from store', 400 @@ -84,11 +93,22 @@ def consume_idp_code(self) -> Tuple[str, int]: logger.info(f'[{oauth2_session.id}] user id matched between HTI and IDP by user_identifier [{user_identifier}]') - fhir_logging_service.register_idp_interaction(f'Patient/{launching_user_resource["id"]}') + fhir_logging_service.register_idp_interaction(f'Patient/{launching_user_resource["id"]}', trace_headers) # As the user has been verified, finish the initial OAuth launch flow by responding with the code return f'{oauth2_session.redirect_uri}?{urlencode({"code": oauth2_session.code, "state": oauth2_session.state})}', 302 + def _get_trace_headers(self): + trace_headers = { + 'X-Request-Id': request.headers.get('X-Request-Id', str(uuid4())) + } + if 'X-Correlation-Id' in request.headers: + trace_headers['X-Correlation-Id'] = request.headers['X-Correlation-Id'] + if 'X-Trace-Id' in request.headers: + trace_headers['X-Trace-Id'] = request.headers['X-Trace-Id'] + + return trace_headers + @staticmethod def exchange_idp_code(code, oauth2_session: Oauth2Session): diff --git a/application/utils.py b/application/utils.py index 0ad5e70..dd9df51 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta from functools import wraps +from uuid import uuid4 from authlib.jose import Key from cryptography.hazmat.primitives import serialization @@ -18,6 +19,17 @@ def get_private_key_as_pem(key: Key): return private_key_bytes +def new_trace_headers(trace_headers: dict, headers: dict = None): + rv = {"X-Request-Id": str(uuid4())} + if headers: + rv.update(headers) + if "X-Request-Id" in trace_headers: + # No, this is not a typo, the X-Request-Id goes into X-Correlation-Id + rv["X-Correlation-Id"] = trace_headers['X-Request-Id'] + if "X-Trace-Id" in trace_headers: + rv["X-Trace-Id"] = trace_headers['X-Trace-Id'] + return rv + def oidc_smart_config_cached(): """ Flask decorator that allow to set Expire and Cache headers. """ diff --git a/test/test_fhir_logging_service.py b/test/test_fhir_logging_service.py index dc0a8fc..8338702 100644 --- a/test/test_fhir_logging_service.py +++ b/test/test_fhir_logging_service.py @@ -65,7 +65,7 @@ def testing_app(server_key: Key): def test_happy(mock1, testing_app: FlaskClient): testing_app.get("test") # TODO: Ugly fix to initialize app context - mocking the flask.request would be nicer - resp = fhir_logging_service.register_idp_interaction("Patient/123") + resp = fhir_logging_service.register_idp_interaction("Patient/123", {}) json_content = resp.json()['json'] resp_audit_event = AuditEvent(**json_content) @@ -75,4 +75,32 @@ def test_happy(mock1, testing_app: FlaskClient): assert resp_audit_event.source.observer.reference == "Device/my-unit-test-device-id" assert resp_audit_event.outcome == "0" assert 'Authorization' in resp.json()['headers'] + assert 'X-Request-Id' in resp.json()['headers'] +@mock.patch('requests.post', side_effect=_test_fhir_logging_happy_post) +def test_happy_headers(mock1, testing_app: FlaskClient): + + testing_app.get("test") # TODO: Ugly fix to initialize app context - mocking the flask.request would be nicer + trace_headers = { + 'X-Request-Id': str(uuid4()), + 'X-Correlation-Id': str(uuid4()), + 'X-Trace-Id': str(uuid4()) + } + resp = fhir_logging_service.register_idp_interaction("Patient/123", trace_headers) + + json_content = resp.json()['json'] + resp_audit_event = AuditEvent(**json_content) + + assert resp_audit_event.entity[0].what.reference == "Patient/123" + assert resp_audit_event.agent[0].who.reference == "Device/my-unit-test-device-id" + assert resp_audit_event.source.observer.reference == "Device/my-unit-test-device-id" + assert resp_audit_event.extension[0].valueId == trace_headers['X-Request-Id'] + assert resp_audit_event.extension[1].valueId == trace_headers['X-Correlation-Id'] + assert resp_audit_event.extension[2].valueId == trace_headers['X-Trace-Id'] + assert resp_audit_event.outcome == "0" + assert 'Authorization' in resp.json()['headers'] + assert 'X-Request-Id' in resp.json()['headers'] + # Correlation ID should be the original Request ID + assert trace_headers['X-Request-Id'] == resp.json()['headers']['X-Correlation-Id'] + # Trace ID should remain the same + assert trace_headers['X-Trace-Id'] == resp.json()['headers']['X-Trace-Id']