diff --git a/src/sbl_filing_api/services/request_action_validator.py b/src/sbl_filing_api/services/request_action_validator.py index 6acb066a..a0493245 100644 --- a/src/sbl_filing_api/services/request_action_validator.py +++ b/src/sbl_filing_api/services/request_action_validator.py @@ -3,6 +3,7 @@ import logging from abc import ABC, abstractmethod from enum import StrEnum +from http import HTTPStatus from typing import Any, Dict, List, Set import httpx @@ -44,12 +45,22 @@ def __eq__(self, other: "FiRequest"): @alru_cache(ttl=60 * 60) async def get_institution_data(fi_request: FiRequest): - async with httpx.AsyncClient() as client: - res = await client.get( - settings.user_fi_api_url + fi_request.lei, - headers={"authorization": fi_request.request.headers["authorization"]}, - ) - return res.json() + try: + async with httpx.AsyncClient() as client: + res = await client.get( + settings.user_fi_api_url + fi_request.lei, + headers={"authorization": fi_request.request.headers["authorization"]}, + ) + if res.status_code == HTTPStatus.OK: + return res.json() + except Exception: + log.exception("Failed to retrieve fi data for %s", fi_request.lei) + + """ + `alru_cache` seems to cache `None` results, even though documentation for normal `lru_cache` seems to indicate it doesn't by cache `None` by default. + So manually invalidate the cache if no returnable result found + """ + get_institution_data.cache_invalidate(fi_request) class ActionValidator(ABC): @@ -86,7 +97,7 @@ def __init__(self): super().__init__("check_lei_tin") def __call__(self, institution: Dict[str, Any], **kwargs): - if not institution["tax_id"]: + if not (institution and institution.get("tax_id")): return "Cannot sign filing. TIN is required to file." @@ -130,7 +141,9 @@ def __call__(self, filing: FilingDAO, **kwargs): return f"Cannot sign filing. Filing for {filing.lei} for period {filing.filing_period} does not have contact info defined." -validation_registry = {Validator() for Validator in ActionValidator.__subclasses__()} +validation_registry = { + validator.name: validator for validator in {Validator() for Validator in ActionValidator.__subclasses__()} +} def set_context(requirements: Set[UserActionContext]): @@ -168,12 +181,16 @@ def validate_user_action(validator_names: Set[str], exception_name: str): async def _run_validations(request: Request): res = [] - validators = set(filter(lambda validator: validator.name in validator_names, validation_registry)) - for validator in validators: + for validator_name in validator_names: + validator = validation_registry.get(validator_name) + if not validator: + log.warning("Action validator [%s] not found.", validator_name) + continue if inspect.iscoroutinefunction(validator.__call__): res.append(await validator(**request.state.context)) else: res.append(validator(**request.state.context)) + res = [r for r in res if r] if len(res): raise RegTechHttpException( diff --git a/tests/services/test_request_action_validator.py b/tests/services/test_request_action_validator.py new file mode 100644 index 00000000..41cd8d0d --- /dev/null +++ b/tests/services/test_request_action_validator.py @@ -0,0 +1,157 @@ +from http import HTTPStatus +from logging import Logger + +import pytest +from fastapi import Request +from pytest_mock import MockerFixture +from regtech_api_commons.api.exceptions import RegTechHttpException + +from sbl_filing_api.entities.models.dao import ContactInfoDAO, FilingDAO, SubmissionDAO +from sbl_filing_api.entities.models.model_enums import SubmissionState +from sbl_filing_api.services.request_action_validator import UserActionContext, set_context, validate_user_action + + +@pytest.fixture +def httpx_unauthed_mock(mocker: MockerFixture) -> None: + mock_client_get = mocker.patch("httpx.AsyncClient.get") + mock_response = mocker.patch("httpx.Response") + mock_response.status_code = HTTPStatus.FORBIDDEN + mock_client_get.return_value = mock_response + + +@pytest.fixture +def httpx_authed_mock(mocker: MockerFixture) -> None: + mock_client_get = mocker.patch("httpx.AsyncClient.get") + mock_response = mocker.patch("httpx.Response") + mock_response.status_code = HTTPStatus.OK + mock_response.json.return_value = { + "tax_id": "12-3456789", + "lei_status_code": "LAPSED", + "lei_status": {"name": "Lapsed", "code": "LAPSED", "can_file": False}, + } + mock_client_get.return_value = mock_response + + +@pytest.fixture +async def filing_mock(mocker: MockerFixture) -> FilingDAO: + sub_mock = mocker.patch("sbl_filing_api.entities.models.dao.SubmissionDAO") + sub_mock.state = SubmissionState.UPLOAD_FAILED + filing = FilingDAO(lei="1234567890ABCDEFGH00", filing_period="2024", submissions=[sub_mock]) + return filing + + +@pytest.fixture +def request_mock(mocker: MockerFixture) -> Request: + mock = mocker.patch("fastapi.Request") + mock.path_params = {"lei": "1234567890ABCDEFGH00", "period_code": "2024"} + return mock + + +@pytest.fixture +def request_mock_valid_context(mocker: MockerFixture, request_mock: Request, filing_mock: FilingDAO) -> Request: + filing_mock.is_voluntary = True + filing_mock.submissions = [SubmissionDAO(state=SubmissionState.SUBMISSION_ACCEPTED)] + filing_mock.contact_info = ContactInfoDAO() + + request_mock.state.context = { + "lei": "1234567890ABCDEFGH00", + "period": "2024", + UserActionContext.INSTITUTION: { + "tax_id": "12-3456789", + "lei_status_code": "ISSUED", + "lei_status": {"name": "Issued", "code": "ISSUED", "can_file": True}, + }, + UserActionContext.FILING: filing_mock, + } + return request_mock + + +@pytest.fixture +def request_mock_invalid_context(mocker: MockerFixture, request_mock: Request, filing_mock: FilingDAO) -> Request: + request_mock.state.context = { + "lei": "1234567890ABCDEFGH00", + "period": "2024", + UserActionContext.INSTITUTION: { + "lei_status_code": "LAPSED", + "lei_status": {"name": "Lapsed", "code": "LAPSED", "can_file": False}, + }, + UserActionContext.FILING: filing_mock, + } + return request_mock + + +@pytest.fixture +def log_mock(mocker: MockerFixture) -> Logger: + return mocker.patch("sbl_filing_api.services.request_action_validator.log") + + +async def test_validations_with_errors(request_mock_invalid_context: Request): + run_validations = validate_user_action( + { + "check_lei_status", + "check_lei_tin", + "check_filing_exists", + "check_sub_accepted", + "check_voluntary_filer", + "check_contact_info", + }, + "Test Exception", + ) + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock_invalid_context) + assert e.value.name == "Test Exception" + errors = e.value.detail + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a latest submission in the SUBMISSION_ACCEPTED state." + in errors + ) + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a selection of is_voluntary defined." + in errors + ) + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have contact info defined." + in errors + ) + assert "Cannot sign filing. TIN is required to file." in errors + assert "Cannot sign filing. LEI status of LAPSED cannot file." in errors + + +async def test_validations_no_errors(request_mock_valid_context: Request): + run_validations = validate_user_action( + { + "check_lei_status", + "check_lei_tin", + "check_filing_exists", + "check_sub_accepted", + "check_voluntary_filer", + "check_contact_info", + }, + "Test Exception", + ) + await run_validations(request_mock_valid_context) + + +async def test_lei_status_bad_api_res(request_mock: Request, httpx_unauthed_mock): + run_validations = validate_user_action({"check_lei_status"}, "Test Exception") + context_setter = set_context({UserActionContext.INSTITUTION}) + await context_setter(request_mock) + + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock) + assert "Unable to determine LEI status." in e.value.detail + + +async def test_lei_status_good_api_res(request_mock: Request, httpx_authed_mock): + run_validations = validate_user_action({"check_lei_status"}, "Test Exception") + context_setter = set_context({UserActionContext.INSTITUTION}) + await context_setter(request_mock) + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock) + assert "Cannot sign filing. LEI status of LAPSED cannot file." in e.value.detail + + +async def test_invalid_validation(request_mock_invalid_context: Request, log_mock: Logger): + run_validations = validate_user_action({"fake_validation"}, "Test Exception") + await run_validations(request_mock_invalid_context) + log_mock.warning.assert_called_with("Action validator [%s] not found.", "fake_validation")