Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added /v1/filing/{period_name} endpoint, corresponding repo and pytests #63

Merged
merged 8 commits into from
Feb 14, 2024
149 changes: 80 additions & 69 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ regtech-data-validator = {git = "https://github.com/cfpb/regtech-data-validator.
python-multipart = "^0.0.6"
boto3 = "^1.33.12"
alembic = "^1.12.0"
async-lru = "^2.0.4"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
Expand Down
2 changes: 1 addition & 1 deletion src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class FilingTaskStateDTO(BaseModel):
task: FilingTaskDTO
user: str | None = None
state: FilingTaskState
change_timestamp: datetime
change_timestamp: datetime | None = None


class FilingDTO(BaseModel):
Expand Down
102 changes: 68 additions & 34 deletions src/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, List, TypeVar
from entities.engine import get_session
from regtech_api_commons.models.auth import AuthenticatedUser

from copy import deepcopy

from async_lru import alru_cache

from entities.models import (
SubmissionDAO,
Expand All @@ -14,20 +19,21 @@
FilingDTO,
FilingDAO,
FilingTaskDAO,
FilingTaskStateDAO,
FilingTaskState,
)

logger = logging.getLogger(__name__)

T = TypeVar("T")


class NoFilingPeriodException(Exception):
pass


async def get_submissions(session: AsyncSession, filing_id: int = None) -> List[SubmissionDAO]:
async with session.begin():
stmt = select(SubmissionDAO)
if filing_id:
stmt = stmt.filter(SubmissionDAO.filing == filing_id)
results = await session.scalars(stmt)
return results.all()
return await query_helper(session, SubmissionDAO, "filing", filing_id)


async def get_latest_submission(session: AsyncSession, filing_id: int) -> List[SubmissionDAO]:
Expand All @@ -42,29 +48,44 @@ async def get_latest_submission(session: AsyncSession, filing_id: int) -> List[S


async def get_filing_periods(session: AsyncSession) -> List[FilingPeriodDAO]:
async with session.begin():
stmt = select(FilingPeriodDAO)
results = await session.scalars(stmt)
return results.all()
return await query_helper(session, FilingPeriodDAO)


async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO:
return await query_helper(session, submission_id, SubmissionDAO)
result = await query_helper(session, SubmissionDAO, "id", submission_id)
return result[0] if result else None


async def get_filing(session: AsyncSession, filing_id: int) -> FilingDAO:
return await query_helper(session, filing_id, FilingDAO)
result = await query_helper(session, FilingDAO, "id", filing_id)
if result:
result = await populate_missing_tasks(session, result)
return result[0] if result else None


async def get_period_filings_for_user(
session: AsyncSession, user: AuthenticatedUser, period_name: str
) -> List[FilingDAO]:
filing_period = await query_helper(session, FilingPeriodDAO, "name", period_name)
if filing_period:
filings = await query_helper(session, FilingDAO, "filing_period", filing_period[0].id)
filings = [f for f in filings if f.lei in user.institutions]
if filings:
filings = await populate_missing_tasks(session, filings)

return filings
else:
raise NoFilingPeriodException(f"There is no Filing Period with name {period_name} defined in the database.")


async def get_filing_period(session: AsyncSession, filing_period_id: int) -> FilingPeriodDAO:
return await query_helper(session, filing_period_id, FilingPeriodDAO)
result = await query_helper(session, FilingPeriodDAO, "id", filing_period_id)
return result[0] if result else None


@alru_cache(maxsize=128)
async def get_filing_tasks(session: AsyncSession) -> List[FilingTaskDAO]:
async with session.begin():
stmt = select(FilingTaskDAO)
results = await session.scalars(stmt)
return results.all()
return await query_helper(session, FilingTaskDAO)


async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> SubmissionDAO:
Expand Down Expand Up @@ -101,20 +122,33 @@ async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO:
return await upsert_helper(session, filing, FilingDAO)


async def upsert_helper(session: AsyncSession, original_data: Any, type: T) -> T:
async with session.begin():
copy_data = original_data.__dict__.copy()
# this is only for if a DAO is passed in
# Should be DTOs, but hey, it's python
if copy_data["id"] is not None and "_sa_instance_state" in copy_data:
del copy_data["_sa_instance_state"]
new_dao = type(**copy_data)
new_dao = await session.merge(new_dao)
await session.commit()
return new_dao


async def query_helper(session: AsyncSession, id: int, type: T) -> T:
async with session.begin():
stmt = select(type).filter(type.id == id)
return await session.scalar(stmt)
async def upsert_helper(session: AsyncSession, original_data: Any, table_obj: T) -> T:
copy_data = original_data.__dict__.copy()
# this is only for if a DAO is passed in
# Should be DTOs, but hey, it's python
if copy_data["id"] is not None and "_sa_instance_state" in copy_data:
del copy_data["_sa_instance_state"]
new_dao = table_obj(**copy_data)
new_dao = await session.merge(new_dao)
await session.commit()
return new_dao


async def query_helper(session: AsyncSession, table_obj: T, column_name: str = None, value: Any = None) -> List[T]:
stmt = select(table_obj)
if column_name and value:
stmt = stmt.filter(getattr(table_obj, column_name) == value)
return (await session.scalars(stmt)).all()


async def populate_missing_tasks(session: AsyncSession, filings: List[FilingDAO]) -> List[FilingDAO]:
filing_tasks = await get_filing_tasks(session)
filings_copy = deepcopy(filings)
for f in filings_copy:
tasks = [t.task for t in f.tasks]
missing_tasks = [t for t in filing_tasks if t not in tasks]
for mt in missing_tasks:
f.tasks.append(
FilingTaskStateDAO(filing=f.id, task_name=mt.name, state=FilingTaskState.NOT_STARTED, user="")
)
return filings_copy
13 changes: 11 additions & 2 deletions src/routers/filing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from http import HTTPStatus
from fastapi import Depends, Request, UploadFile, BackgroundTasks, status
from fastapi import Depends, Request, UploadFile, BackgroundTasks, status, HTTPException
from fastapi.responses import JSONResponse
from regtech_api_commons.api import Router
from services import submission_processor
from typing import Annotated, List

from entities.engine import get_session
from entities.models import FilingPeriodDTO, SubmissionDTO
from entities.models import FilingPeriodDTO, SubmissionDTO, FilingDTO
from entities.repos import submission_repo as repo

from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -27,6 +27,15 @@ async def get_filing_periods(request: Request):
return await repo.get_filing_periods(request.state.db_session)


# This has to come after the /periods endpoint
@router.get("/{period_name}", response_model=List[FilingDTO])
async def get_filings(request: Request, period_name: str):
try:
return await repo.get_period_filings_for_user(request.state.db_session, period_name)
except repo.NoFilingPeriodException as nfpe:
raise HTTPException(status_code=500, detail=str(nfpe))


@router.post("/{lei}/submissions/{submission_id}", status_code=HTTPStatus.ACCEPTED)
async def upload_file(
request: Request, lei: str, submission_id: str, file: UploadFile, background_tasks: BackgroundTasks
Expand Down
43 changes: 42 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pytest_mock import MockerFixture
from unittest.mock import Mock

from entities.models import FilingPeriodDAO, FilingType
from entities.models import FilingPeriodDAO, FilingType, FilingDAO, FilingTaskStateDAO, FilingTaskState, FilingTaskDAO
from entities.repos import submission_repo as repo

from regtech_api_commons.models.auth import AuthenticatedUser
from starlette.authentication import AuthCredentials, UnauthenticatedUser
Expand Down Expand Up @@ -57,3 +58,43 @@ def get_filing_period_mock(mocker: MockerFixture) -> Mock:
)
]
return mock


@pytest.fixture
def get_filings_mock(mocker: MockerFixture) -> Mock:
mock = mocker.patch("entities.repos.submission_repo.get_period_filings_for_user")
mock.return_value = [
FilingDAO(
id=1,
lei="12345678",
tasks=[
FilingTaskStateDAO(
filing=1,
task=FilingTaskDAO(name="Task-1", task_order=1),
state=FilingTaskState.NOT_STARTED,
user="",
),
FilingTaskStateDAO(
filing=1,
task=FilingTaskDAO(name="Task-2", task_order=2),
state=FilingTaskState.NOT_STARTED,
user="",
),
],
filing_period=1,
institution_snapshot_id="v1",
contact_info="[email protected]",
)
]
return mock


@pytest.fixture
def get_filings_error_mock(mocker: MockerFixture) -> Mock:
mock = mocker.patch(
"entities.repos.submission_repo.get_period_filings_for_user",
side_effect=repo.NoFilingPeriodException(
"There is no Filing Period with name FilingPeriod2025 defined in the database."
),
)
return mock
16 changes: 16 additions & 0 deletions tests/api/routers/test_filing_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ def test_unauthed_get_submissions(
res = client.get("/v1/filing/123456790/filings/1/submissions")
assert res.status_code == 403

def test_get_filings(self, app_fixture: FastAPI, get_filings_mock: Mock):
client = TestClient(app_fixture)
res = client.get("/v1/filing/FilingPeriod2024")
get_filings_mock.assert_called_with(ANY, "FilingPeriod2024")
assert res.status_code == 200
assert len(res.json()) == 1
assert res.json()[0]["lei"] == "12345678"

def test_get_filings_with_error(self, app_fixture: FastAPI, get_filings_error_mock: Mock):
client = TestClient(app_fixture)
response = client.get("/v1/filing/FilingPeriod2025")
assert response.status_code == 500
assert response.json() == {
"detail": "There is no Filing Period with name FilingPeriod2025 defined in the database."
}

async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.submission_repo.get_submissions")
mock.return_value = [
Expand Down
23 changes: 23 additions & 0 deletions tests/entities/repos/test_submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

from entities.engine import engine as entities_engine

from regtech_api_commons.models.auth import AuthenticatedUser


class TestSubmissionRepo:
@pytest.fixture(scope="function", autouse=True)
Expand Down Expand Up @@ -155,10 +157,31 @@ async def test_get_filing(self, query_session: AsyncSession):
res = await repo.get_filing(query_session, filing_id=1)
assert res.id == 1
assert res.lei == "1234567890"
assert len(res.tasks) == 2
assert FilingTaskState.NOT_STARTED in set([t.state for t in res.tasks])

res = await repo.get_filing(query_session, filing_id=2)
assert res.id == 2
assert res.lei == "ABCDEFGHIJ"
assert len(res.tasks) == 2
assert FilingTaskState.NOT_STARTED in set([t.state for t in res.tasks])

async def test_get_period_filings_for_user(self, query_session: AsyncSession, mocker: MockerFixture):
user = AuthenticatedUser.from_claim({"institutions": ["ZYXWVUTSRQP"]})
results = await repo.get_period_filings_for_user(query_session, user, period_name="FilingPeriod2024")
assert len(results) == 0

user = AuthenticatedUser.from_claim({"institutions": ["1234567890", "0987654321"]})
results = await repo.get_period_filings_for_user(query_session, user, period_name="FilingPeriod2024")
assert len(results) == 1
assert results[0].id == 1
assert results[0].lei == "1234567890"
assert len(results[0].tasks) == 2

try:
await repo.get_period_filings_for_user(query_session, user, period_name="FilingPeriod2025")
except repo.NoFilingPeriodException as nfpe:
assert str(nfpe) == "There is no Filing Period with name FilingPeriod2025 defined in the database."

async def test_get_latest_submission(self, query_session: AsyncSession):
res = await repo.get_latest_submission(query_session, filing_id=2)
Expand Down
Loading