diff --git a/db_revisions/versions/078cbbc69fe5_update_filing_table_for_filing_tasks.py b/db_revisions/versions/078cbbc69fe5_update_filing_table_for_filing_tasks.py new file mode 100644 index 00000000..fd6df75c --- /dev/null +++ b/db_revisions/versions/078cbbc69fe5_update_filing_table_for_filing_tasks.py @@ -0,0 +1,40 @@ +"""update filing table for filing tasks + +Revision ID: 078cbbc69fe5 +Revises: 4e8ae26c1a22 +Create Date: 2024-01-30 13:15:44.323900 + +""" +from typing import Sequence, Union + +from alembic import op, context +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "078cbbc69fe5" +down_revision: Union[str, None] = "4e8ae26c1a22" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.drop_column("filing", "state") + if "sqlite" not in context.get_context().dialect.name: + op.execute(sa.DDL("DROP TYPE filingstate")) + + +def downgrade() -> None: + op.add_column( + "filing", + sa.Column( + "state", + sa.Enum( + "FILING_STARTED", + "FILING_INSTITUTION_APPROVED", + "FILING_IN_PROGRESS", + "FILING_COMPLETE", + name="filingstate", + ), + ), + ) diff --git a/db_revisions/versions/4ca961a003e1_create_filing_task_table.py b/db_revisions/versions/4ca961a003e1_create_filing_task_table.py new file mode 100644 index 00000000..7173902b --- /dev/null +++ b/db_revisions/versions/4ca961a003e1_create_filing_task_table.py @@ -0,0 +1,30 @@ +"""create filing task table + +Revision ID: 4ca961a003e1 +Revises: f30c5c3c7a42 +Create Date: 2024-01-30 12:59:15.720135 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "4ca961a003e1" +down_revision: Union[str, None] = "f30c5c3c7a42" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "filing_task", + sa.Column("name", sa.String, primary_key=True), + sa.Column("task_order", sa.INTEGER, nullable=False), + ) + + +def downgrade() -> None: + op.drop_table("filing_task") diff --git a/db_revisions/versions/4e8ae26c1a22_create_filing_task_state_table.py b/db_revisions/versions/4e8ae26c1a22_create_filing_task_state_table.py new file mode 100644 index 00000000..c7a4c616 --- /dev/null +++ b/db_revisions/versions/4e8ae26c1a22_create_filing_task_state_table.py @@ -0,0 +1,51 @@ +"""create filing task state table + +Revision ID: 4e8ae26c1a22 +Revises: 4ca961a003e1 +Create Date: 2024-01-30 13:02:52.041229 + +""" +from typing import Sequence, Union + +from alembic import op, context +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "4e8ae26c1a22" +down_revision: Union[str, None] = "4ca961a003e1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "filing_task_state", + sa.Column("filing", sa.INTEGER, primary_key=True), + sa.Column("task_name", sa.String, primary_key=True), + sa.Column( + "state", + sa.Enum( + "NOT_STARTED", + "IN_PROGRESS", + "COMPLETED", + name="filingtaskstate", + ), + ), + sa.Column("user", sa.String, nullable=False), + sa.Column("change_timestamp", sa.DateTime, nullable=False), + sa.ForeignKeyConstraint( + ["filing"], + ["filing.id"], + ), + sa.ForeignKeyConstraint( + ["task_name"], + ["filing_task.name"], + ), + ) + + +def downgrade() -> None: + op.drop_table("filing_task_state") + if "sqlite" not in context.get_context().dialect.name: + op.execute(sa.DDL("DROP TYPE filingtaskstate")) diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index 1dec549a..f7fd6163 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -5,12 +5,16 @@ "SubmissionState", "FilingDAO", "FilingDTO", + "FilingTaskStateDAO", + "FilingTaskStateDTO", + "FilingTaskDAO", + "FilingTaskDTO", "FilingPeriodDAO", "FilingPeriodDTO", "FilingType", - "FilingState", + "FilingTaskState", ] -from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO -from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO -from .model_enums import FilingType, FilingState, SubmissionState +from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO, FilingTaskStateDAO, FilingTaskDAO +from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO, FilingTaskStateDTO, FilingTaskDTO +from .model_enums import FilingType, FilingTaskState, SubmissionState diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index fb75ad22..35e1658a 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -1,10 +1,9 @@ -from .model_enums import FilingType, FilingState, SubmissionState +from .model_enums import FilingType, FilingTaskState, SubmissionState from datetime import datetime -from typing import Any +from typing import Any, List from sqlalchemy import Enum as SAEnum -from sqlalchemy import ForeignKey -from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import ForeignKey, func +from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.types import JSON @@ -37,11 +36,33 @@ class FilingPeriodDAO(Base): filing_type: Mapped[FilingType] = mapped_column(SAEnum(FilingType)) +class FilingTaskDAO(Base): + __tablename__ = "filing_task" + name: Mapped[str] = mapped_column(primary_key=True) + task_order: Mapped[int] + + def __str__(self): + return f"Name: {self.name}, Order: {self.task_order}" + + +class FilingTaskStateDAO(Base): + __tablename__ = "filing_task_state" + filing: Mapped[int] = mapped_column(ForeignKey("filing.id"), primary_key=True) + task_name: Mapped[str] = mapped_column(ForeignKey("filing_task.name"), primary_key=True) + task: Mapped[FilingTaskDAO] = relationship(lazy="selectin") + user: Mapped[str] + state: Mapped[FilingTaskState] = mapped_column(SAEnum(FilingTaskState)) + change_timestamp: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now()) + + def __str__(self): + return f"Filing ID: {self.filing}, Task: {self.task}, User: {self.user}, state: {self.state}, Timestamp: {self.change_timestamp}" + + class FilingDAO(Base): __tablename__ = "filing" id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) lei: Mapped[str] - state: Mapped[FilingState] = mapped_column(SAEnum(FilingState)) + tasks: Mapped[List[FilingTaskStateDAO]] = relationship(lazy="selectin", cascade="all, delete-orphan") filing_period: Mapped[int] = mapped_column(ForeignKey("filing_period.id")) institution_snapshot_id: Mapped[str] contact_info: Mapped[str] = mapped_column(nullable=True) diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 664ecd72..9aab035c 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Dict, Any +from typing import Dict, Any, List from pydantic import BaseModel, ConfigDict -from .model_enums import FilingType, FilingState, SubmissionState +from .model_enums import FilingType, FilingTaskState, SubmissionState class SubmissionDTO(BaseModel): @@ -16,12 +16,29 @@ class SubmissionDTO(BaseModel): confirmation_id: str | None = None +class FilingTaskDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + name: str + task_order: int + + +class FilingTaskStateDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + filing: int + task: FilingTaskDTO + user: str | None = None + state: FilingTaskState + change_timestamp: datetime + + class FilingDTO(BaseModel): model_config = ConfigDict(from_attributes=True) id: int | None = None lei: str - state: FilingState + tasks: List[FilingTaskStateDTO] filing_period: int institution_snapshot_id: str contact_info: str | None = None diff --git a/src/entities/models/model_enums.py b/src/entities/models/model_enums.py index a198480f..e8ad5b71 100644 --- a/src/entities/models/model_enums.py +++ b/src/entities/models/model_enums.py @@ -10,11 +10,10 @@ class SubmissionState(str, Enum): SUBMISSION_SIGNED = "SUBMISSION_SIGNED" -class FilingState(str, Enum): - FILING_STARTED = "FILING_STARTED" - FILING_INSTITUTION_APPROVED = "FILING_INSTITUTION_APPROVED" - FILING_IN_PROGRESS = "FILING_IN_PROGRESS" - FILING_COMPLETE = "FILING_COMPLETE" +class FilingTaskState(str, Enum): + NOT_STARTED = "NOT_STARTED" + IN_PROGRESS = "IN_PROGRESS" + COMPLETED = "COMPLETED" class FilingType(str, Enum): diff --git a/src/entities/repos/submission_repo.py b/src/entities/repos/submission_repo.py index b3d351ea..a8647188 100644 --- a/src/entities/repos/submission_repo.py +++ b/src/entities/repos/submission_repo.py @@ -13,6 +13,7 @@ FilingPeriodDTO, FilingDTO, FilingDAO, + FilingTaskDAO, ) logger = logging.getLogger(__name__) @@ -48,6 +49,13 @@ async def get_filing_period(session: AsyncSession, filing_period_id: int) -> Fil return await query_helper(session, filing_period_id, FilingPeriodDAO) +async def get_filing_tasks(session: AsyncSession) -> List[FilingTaskDAO]: + async with session.begin(): + stmt = select(FilingTaskDAO) + results = await session.scalars(stmt) + return results.all() + + async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> SubmissionDAO: async with session.begin(): new_sub = SubmissionDAO( diff --git a/tests/entities/repos/test_submission_repo.py b/tests/entities/repos/test_submission_repo.py index 2361447e..cedd02f3 100644 --- a/tests/entities/repos/test_submission_repo.py +++ b/tests/entities/repos/test_submission_repo.py @@ -12,8 +12,10 @@ FilingPeriodDTO, FilingDAO, FilingDTO, + FilingTaskStateDAO, + FilingTaskDAO, FilingType, - FilingState, + FilingTaskState, SubmissionState, ) from entities.repos import submission_repo as repo @@ -29,6 +31,11 @@ async def setup( ): mocker.patch.object(entities_engine, "SessionLocal", return_value=session_generator) + filing_task_1 = FilingTaskDAO(name="Task-1", task_order=1) + filing_task_2 = FilingTaskDAO(name="Task-2", task_order=2) + transaction_session.add(filing_task_1) + transaction_session.add(filing_task_2) + filing_period = FilingPeriodDAO( name="FilingPeriod2024", start_period=datetime.now(), @@ -40,13 +47,11 @@ async def setup( filing1 = FilingDAO( lei="1234567890", - state=FilingState.FILING_STARTED, institution_snapshot_id="Snapshot-1", filing_period=1, ) filing2 = FilingDAO( lei="ABCDEFGHIJ", - state=FilingState.FILING_STARTED, institution_snapshot_id="Snapshot-1", filing_period=1, ) @@ -101,28 +106,44 @@ async def test_get_filing_period(self, query_session: AsyncSession): assert res.filing_type == FilingType.MANUAL async def test_add_and_modify_filing(self, transaction_session: AsyncSession): - new_filing = FilingDTO( - lei="12345ABCDE", - state=FilingState.FILING_IN_PROGRESS, - institution_snapshot_id="Snapshot-1", - filing_period=1, - ) + new_filing = FilingDTO(lei="12345ABCDE", institution_snapshot_id="Snapshot-1", filing_period=1, tasks=[]) res = await repo.upsert_filing(transaction_session, new_filing) assert res.id == 3 assert res.lei == "12345ABCDE" - assert res.state == FilingState.FILING_IN_PROGRESS + assert res.institution_snapshot_id == "Snapshot-1" - mod_filing = FilingDTO( - id=3, - lei="12345ABCDE", - state=FilingState.FILING_COMPLETE, - institution_snapshot_id="Snapshot-1", - filing_period=1, - ) + mod_filing = FilingDTO(id=3, lei="12345ABCDE", institution_snapshot_id="Snapshot-2", filing_period=1, tasks=[]) res = await repo.upsert_filing(transaction_session, mod_filing) assert res.id == 3 assert res.lei == "12345ABCDE" - assert res.state == FilingState.FILING_COMPLETE + assert res.institution_snapshot_id == "Snapshot-2" + + async def test_get_filing_tasks(self, transaction_session: AsyncSession): + tasks = await repo.get_filing_tasks(transaction_session) + assert len(tasks) == 2 + assert tasks[0].name == "Task-1" + assert tasks[1].name == "Task-2" + + async def test_add_task_to_filing(self, query_session: AsyncSession, transaction_session: AsyncSession): + filing = await repo.get_filing(query_session, filing_id=1) + task = await query_session.scalar(select(FilingTaskDAO).where(FilingTaskDAO.name == "Task-1")) + filing_task = FilingTaskStateDAO( + filing=filing.id, task=task, user="test@cfpb.gov", state=FilingTaskState.IN_PROGRESS + ) + filing.tasks = [filing_task] + seconds_now = datetime.utcnow().timestamp() + await repo.upsert_filing(transaction_session, filing) + + filing_task_states = (await transaction_session.scalars(select(FilingTaskStateDAO))).all() + + assert len(filing_task_states) == 1 + assert filing_task_states[0].task.name == "Task-1" + assert filing_task_states[0].filing == 1 + assert filing_task_states[0].state == FilingTaskState.IN_PROGRESS + assert filing_task_states[0].user == "test@cfpb.gov" + assert filing_task_states[0].change_timestamp.timestamp() == pytest.approx( + seconds_now, abs=1.0 + ) # allow for possible 1 second difference async def test_get_filing(self, query_session: AsyncSession): res = await repo.get_filing_period(query_session, filing_period_id=1) diff --git a/tests/migrations/test_migrations.py b/tests/migrations/test_migrations.py index d6147071..94a9de71 100644 --- a/tests/migrations/test_migrations.py +++ b/tests/migrations/test_migrations.py @@ -10,6 +10,37 @@ from pytest_alembic import MigrationContext +def test_migrations_up_to_078cbbc69fe5(alembic_runner: MigrationContext, alembic_engine: Engine): + alembic_runner.migrate_up_to("078cbbc69fe5") + + inspector = sqlalchemy.inspect(alembic_engine) + tables = inspector.get_table_names() + + assert "filing_task" in tables + assert {"name", "task_order"} == set([c["name"] for c in inspector.get_columns("filing_task")]) + + assert "filing_task_state" in tables + assert {"filing", "task_name", "state", "user", "change_timestamp"} == set( + [c["name"] for c in inspector.get_columns("filing_task_state")] + ) + + filing_state_fk1 = inspector.get_foreign_keys("filing_task_state")[0] + assert ( + "filing" in filing_state_fk1["constrained_columns"] + and "filing" == filing_state_fk1["referred_table"] + and "id" in filing_state_fk1["referred_columns"] + ) + + filing_state_fk2 = inspector.get_foreign_keys("filing_task_state")[1] + assert ( + "task_name" in filing_state_fk2["constrained_columns"] + and "filing_task" == filing_state_fk2["referred_table"] + and "name" in filing_state_fk2["referred_columns"] + ) + + assert "state" not in set([c["name"] for c in inspector.get_columns("filing")]) + + def test_migrations(alembic_runner: MigrationContext, alembic_engine: Engine): alembic_runner.migrate_up_to("f30c5c3c7a42")