diff --git a/db_revisions/env.py b/db_revisions/env.py index 722aa94f..6b8f4d77 100644 --- a/db_revisions/env.py +++ b/db_revisions/env.py @@ -1,3 +1,5 @@ +import os + from logging.config import fileConfig from sqlalchemy import engine_from_config diff --git a/pyproject.toml b/pyproject.toml index da691ae5..d976444d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ addopts = [ "--strict-markers", "-rfE", ] +env = [ + "INST_DB_SCHEMA=main" +] testpaths = ["tests"] [tool.black] diff --git a/src/entities/engine/__init__.py b/src/entities/engine/__init__.py new file mode 100644 index 00000000..1fa81932 --- /dev/null +++ b/src/entities/engine/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["get_session"] + +from .engine import get_session diff --git a/src/entities/engine/engine.py b/src/entities/engine/engine.py new file mode 100644 index 00000000..9a436825 --- /dev/null +++ b/src/entities/engine/engine.py @@ -0,0 +1,20 @@ +from sqlalchemy.ext.asyncio import ( + create_async_engine, + async_sessionmaker, + async_scoped_session, +) +from asyncio import current_task +from config import settings + +engine = create_async_engine(settings.inst_conn.unicode_string(), echo=True).execution_options( + schema_translate_map={None: settings.inst_db_schema} +) +SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) + + +async def get_session(): + session = SessionLocal() + try: + yield session + finally: + await session.close() diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py new file mode 100644 index 00000000..27994b25 --- /dev/null +++ b/src/entities/models/__init__.py @@ -0,0 +1,12 @@ +__all__ = [ + "Base", + "SubmissionDAO", + "ValidationResultDAO", + "RecordDAO", + "RecordDTO", + "ValidationResultDTO", + "SubmissionDTO", +] + +from .dao import Base, SubmissionDAO, ValidationResultDAO, RecordDAO +from .dto import RecordDTO, ValidationResultDTO, SubmissionDTO diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py new file mode 100644 index 00000000..d2ef7875 --- /dev/null +++ b/src/entities/models/dao.py @@ -0,0 +1,49 @@ +from datetime import datetime +from typing import get_args, List, Any, Literal +from sqlalchemy import ForeignKey, func, Enum +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.types import JSON + +Severity = Literal["error", "warning"] + + +class Base(AsyncAttrs, DeclarativeBase): + pass + + +class AuditMixin(object): + event_time: Mapped[datetime] = mapped_column(server_default=func.now()) + + +class SubmissionDAO(AuditMixin, Base): + __tablename__ = "submission" + submission_id: Mapped[str] = mapped_column(index=True, primary_key=True) + submitter: Mapped[str] + lei: Mapped[str] + results: Mapped[List["ValidationResultDAO"]] = relationship(back_populates="submission") + json_dump: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=True) + + def __str__(self): + return f"Submission ID: {self.submission_id}, Submitter: {self.submitter}, LEI: {self.lei}" + + +class ValidationResultDAO(AuditMixin, Base): + __tablename__ = "validation_result" + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + submission_id: Mapped[str] = mapped_column(ForeignKey("submission.submission_id")) + submission: Mapped["SubmissionDAO"] = relationship(back_populates="results") # if we care about bidirectional + validation_id: Mapped[str] + field_name: Mapped[str] + severity: Mapped[Severity] = mapped_column(Enum(*get_args(Severity))) + records: Mapped[List["RecordDAO"]] = relationship(back_populates="result") + + +class RecordDAO(AuditMixin, Base): + __tablename__ = "validation_result_record" + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + result_id: Mapped[str] = mapped_column(ForeignKey("validation_result.id")) + result: Mapped["ValidationResultDAO"] = relationship(back_populates="records") # if we care about bidirectional + record: Mapped[int] + data: Mapped[str] diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py new file mode 100644 index 00000000..f66bb8cc --- /dev/null +++ b/src/entities/models/dto.py @@ -0,0 +1,27 @@ +from typing import List +from pydantic import BaseModel, ConfigDict + + +class RecordDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + record: int + data: str + + +class ValidationResultDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + validation_id: str + field_name: str + severity: str + records: List[RecordDTO] = [] + + +class SubmissionDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + submission_id: str + lei: str + submitter: str + results: List[ValidationResultDTO] = [] diff --git a/src/entities/repos/__init__.py b/src/entities/repos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/entities/repos/submission_repo.py b/src/entities/repos/submission_repo.py new file mode 100644 index 00000000..2bb45d7a --- /dev/null +++ b/src/entities/repos/submission_repo.py @@ -0,0 +1,44 @@ +from sqlalchemy import select +from sqlalchemy.orm import joinedload +from sqlalchemy.ext.asyncio import AsyncSession + +import pandas as pd +from entities.models import SubmissionDAO, ValidationResultDAO, RecordDAO + + +async def get_submission(session: AsyncSession, submission_id: str) -> SubmissionDAO: + async with session.begin(): + stmt = ( + select(SubmissionDAO) + .options(joinedload(SubmissionDAO.results).joinedload(ValidationResultDAO.records)) + .filter(SubmissionDAO.submission_id == submission_id) + ) + return await session.scalar(stmt) + +# I was thinking this would be called after calling data_validator.create_schemas.validate() +# which returns a boolean, DataFrame tuple. The DataFrame represents the results of validation. +# Not sure if we'll already have the submission info in a DTO at this time (from the endpoint call) +# so we may be able to change the submission_id, submitter, and lei into an object versus individual +# data fields. +async def add_submission( + session: AsyncSession, submission_id: str, submitter: str, lei: str, results: pd.DataFrame +) -> SubmissionDAO: + async with session.begin(): + findings_by_v_id_df = results.reset_index().set_index(["validation_id"]) + submission = SubmissionDAO(submission_id=submission_id, submitter=submitter, lei=lei) + validation_results = [] + for v_id_idx, v_id_df in findings_by_v_id_df.groupby(by="validation_id"): + v_head = v_id_df.iloc[0] + result = ValidationResultDAO( + validation_id=v_id_idx, field_name=v_head.at["field_name"], severity=v_head.at["validation_severity"] + ) + records = [] + for rec_no, rec_df in v_id_df.iterrows(): + record = RecordDAO(record=rec_df.at["record_no"], data=rec_df.at["field_value"]) + records.append(record) + result.records = records + validation_results.append(result) + submission.results = validation_results + session.add(submission) + + return submission diff --git a/tests/entities/conftest.py b/tests/entities/conftest.py new file mode 100644 index 00000000..fc0633d0 --- /dev/null +++ b/tests/entities/conftest.py @@ -0,0 +1,61 @@ +import asyncio +import pytest + +from asyncio import current_task +from sqlalchemy.ext.asyncio import ( + create_async_engine, + AsyncEngine, + async_scoped_session, + async_sessionmaker, +) +from entities.models import Base + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + +@pytest.fixture(scope="session") +def engine(): + return create_async_engine("sqlite+aiosqlite://") + + +@pytest.fixture(scope="function", autouse=True) +async def setup_db( + request: pytest.FixtureRequest, + engine: AsyncEngine, + event_loop: asyncio.AbstractEventLoop, +): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + def teardown(): + async def td(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + event_loop.run_until_complete(td()) + + request.addfinalizer(teardown) + + +@pytest.fixture(scope="function") +async def transaction_session(session_generator: async_scoped_session): + async with session_generator() as session: + yield session + + +@pytest.fixture(scope="function") +async def query_session(session_generator: async_scoped_session): + async with session_generator() as session: + yield session + + +@pytest.fixture(scope="function") +def session_generator(engine: AsyncEngine): + return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) diff --git a/tests/entities/repos/test_submission_repo.py b/tests/entities/repos/test_submission_repo.py new file mode 100644 index 00000000..48c56fec --- /dev/null +++ b/tests/entities/repos/test_submission_repo.py @@ -0,0 +1,95 @@ +import pandas as pd +import pytest + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities.models import SubmissionDAO, ValidationResultDAO, RecordDAO +from entities.repos import submission_repo as repo + + +class TestSubmissionRepo: + @pytest.fixture(scope="function", autouse=True) + async def setup( + self, + transaction_session: AsyncSession, + ): + submission = SubmissionDAO(submission_id="12345", submitter="test@cfpb.gov", lei="1234567890ABCDEFGHIJ") + results = [] + result1 = ValidationResultDAO(validation_id="E0123", field_name="uid", severity="error") + records = [] + record1a = RecordDAO(record=1, data="empty") + records.append(record1a) + result1.records = records + results.append(result1) + submission.results = results + + transaction_session.add(submission) + await transaction_session.commit() + + async def test_get_submission(self, query_session: AsyncSession): + res = await repo.get_submission(query_session, submission_id="12345") + assert res.submission_id == "12345" + assert res.submitter == "test@cfpb.gov" + assert res.lei == "1234567890ABCDEFGHIJ" + assert len(res.results) == 1 + assert len(res.results[0].records) == 1 + assert res.results[0].validation_id == "E0123" + assert res.results[0].records[0].data == "empty" + + async def test_add_submission(self, transaction_session: AsyncSession): + df_columns = [ + "record_no", + "field_name", + "field_value", + "validation_severity", + "validation_id", + "validation_name", + "validation_desc", + ] + df_data = [ + [ + 0, + "uid", + "BADUID0", + "error", + "E0001", + "id.invalid_text_length", + "'Unique identifier' must be at least 21 characters in length.", + ], + [ + 0, + "uid", + "BADTEXTLENGTH", + "error", + "E0100", + "ct_credit_product_ff.invalid_text_length", + "'Free-form text field for other credit products' must not exceed 300 characters in length.", + ], + [ + 1, + "uid", + "BADUID1", + "error", + "E0001", + "id.invalid_text_length", + "'Unique identifier' must be at least 21 characters in length.", + ], + ] + error_df = pd.DataFrame(df_data, columns=df_columns) + print(f"Data Frame: {error_df}") + res = await repo.add_submission( + transaction_session, + submission_id="12346", + submitter="test@cfpb.gov", + lei="1234567890ABCDEFGHIJ", + results=error_df, + ) + assert res.submission_id == "12346" + assert res.submitter == "test@cfpb.gov" + assert res.lei == "1234567890ABCDEFGHIJ" + assert len(res.results) == 2 # Two error codes, 3 records total + assert len(res.results[0].records) == 2 + assert len(res.results[1].records) == 1 + assert res.results[0].validation_id == "E0001" + assert res.results[1].validation_id == "E0100" + assert res.results[0].records[0].data == "BADUID0" diff --git a/tests/migrations/test_migrations.py b/tests/migrations/test_migrations.py new file mode 100644 index 00000000..29682fb4 --- /dev/null +++ b/tests/migrations/test_migrations.py @@ -0,0 +1,29 @@ +from pytest_alembic.tests import ( + test_single_head_revision, + test_up_down_consistency, + test_upgrade, +) + +import sqlalchemy +from sqlalchemy.engine import Engine + +from pytest_alembic import MigrationContext + + +def test_migrations(alembic_runner: MigrationContext, alembic_engine: Engine): + alembic_runner.migrate_up_to("af1ba24f831a") + + inspector = sqlalchemy.inspect(alembic_engine) + tables = inspector.get_table_names() + assert "submission" in tables + assert {"submission_id", "submitter", "lei", "json_dump", "event_time"} == set([c["name"] for c in inspector.get_columns("submission")]) + + assert "validation_result" in tables + assert {"id", "submission_id", "validation_id", "field_name", "severity", "event_time"} == set([c["name"] for c in inspector.get_columns("validation_result")]) + vr_fk = inspector.get_foreign_keys("validation_result")[0] + assert "submission_id" in vr_fk["constrained_columns"] and "submission" == vr_fk["referred_table"] and "submission_id" in vr_fk["referred_columns"] + + assert "validation_result_record" in tables + assert {"id", "result_id", "record", "data", "event_time"} == set([c["name"] for c in inspector.get_columns("validation_result_record")]) + vrr_fk = inspector.get_foreign_keys("validation_result_record")[0] + assert "result_id" in vrr_fk["constrained_columns"] and "validation_result" == vrr_fk["referred_table"] and "id" in vrr_fk["referred_columns"] \ No newline at end of file