Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added filing_task and filing_task_state tables #50

Merged
merged 8 commits into from
Feb 7, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""update filing table for filing tasks

Revision ID: 078cbbc69fe5
Revises: 4e8ae26c1a22
Create Date: 2024-01-30 13:15:44.323900

"""
from typing import Sequence, Union

from alembic import op, context
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "078cbbc69fe5"
down_revision: Union[str, None] = "4e8ae26c1a22"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.drop_column("filing", "state")
if "sqlite" not in context.get_context().dialect.name:
op.execute(sa.DDL("DROP TYPE filingstate"))


def downgrade() -> None:
op.add_column(
"filing",
sa.Column(
"state",
sa.Enum(
"FILING_STARTED",
"FILING_INSTITUTION_APPROVED",
"FILING_IN_PROGRESS",
"FILING_COMPLETE",
name="filingstate",
),
),
)
30 changes: 30 additions & 0 deletions db_revisions/versions/4ca961a003e1_create_filing_task_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""create filing task table

Revision ID: 4ca961a003e1
Revises: f30c5c3c7a42
Create Date: 2024-01-30 12:59:15.720135

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "4ca961a003e1"
down_revision: Union[str, None] = "f30c5c3c7a42"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"filing_task",
sa.Column("name", sa.String, primary_key=True),
sa.Column("task_order", sa.INTEGER, nullable=False),
)


def downgrade() -> None:
op.drop_table("filing_task")
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""create filing task state table

Revision ID: 4e8ae26c1a22
Revises: 4ca961a003e1
Create Date: 2024-01-30 13:02:52.041229

"""
from typing import Sequence, Union

from alembic import op, context
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "4e8ae26c1a22"
down_revision: Union[str, None] = "4ca961a003e1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"filing_task_state",
sa.Column("filing", sa.INTEGER, primary_key=True),
sa.Column("task_name", sa.String, primary_key=True),
sa.Column(
"state",
sa.Enum(
"NOT_STARTED",
"IN_PROGRESS",
"COMPLETED",
name="filingtaskstate",
),
),
sa.Column("user", sa.String, nullable=False),
sa.Column("change_timestamp", sa.DateTime, nullable=False),
sa.ForeignKeyConstraint(
["filing"],
["filing.id"],
),
sa.ForeignKeyConstraint(
["task_name"],
["filing_task.name"],
),
)


def downgrade() -> None:
op.drop_table("filing_task_state")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the downgrade remove the custom type just by dropping the table, or do we need to do the drop type thing as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, updated the downgrade to have a drop of the filingtaskstate TYPE. Note that sqlite doesn't have that concept so it checks for the dialect.

if "sqlite" not in context.get_context().dialect.name:
op.execute(sa.DDL("DROP TYPE filingtaskstate"))
12 changes: 8 additions & 4 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
"SubmissionState",
"FilingDAO",
"FilingDTO",
"FilingTaskStateDAO",
"FilingTaskStateDTO",
"FilingTaskDAO",
"FilingTaskDTO",
"FilingPeriodDAO",
"FilingPeriodDTO",
"FilingType",
"FilingState",
"FilingTaskState",
]

from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO
from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO
from .model_enums import FilingType, FilingState, SubmissionState
from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO, FilingTaskStateDAO, FilingTaskDAO
from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO, FilingTaskStateDTO, FilingTaskDTO
from .model_enums import FilingType, FilingTaskState, SubmissionState
33 changes: 27 additions & 6 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .model_enums import FilingType, FilingState, SubmissionState
from .model_enums import FilingType, FilingTaskState, SubmissionState
from datetime import datetime
from typing import Any
from typing import Any, List
from sqlalchemy import Enum as SAEnum
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.types import JSON

Expand Down Expand Up @@ -37,11 +36,33 @@ class FilingPeriodDAO(Base):
filing_type: Mapped[FilingType] = mapped_column(SAEnum(FilingType))


class FilingTaskDAO(Base):
__tablename__ = "filing_task"
name: Mapped[str] = mapped_column(primary_key=True)
task_order: Mapped[int]

def __str__(self):
return f"Name: {self.name}, Order: {self.task_order}"


class FilingTaskStateDAO(Base):
__tablename__ = "filing_task_state"
filing: Mapped[int] = mapped_column(ForeignKey("filing.id"), primary_key=True)
task_name: Mapped[str] = mapped_column(ForeignKey("filing_task.name"), primary_key=True)
task: Mapped[FilingTaskDAO] = relationship(lazy="selectin")
user: Mapped[str]
state: Mapped[FilingTaskState] = mapped_column(SAEnum(FilingTaskState))
change_timestamp: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now())

def __str__(self):
return f"Filing ID: {self.filing}, Task: {self.task}, User: {self.user}, state: {self.state}, Timestamp: {self.change_timestamp}"


class FilingDAO(Base):
__tablename__ = "filing"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
lei: Mapped[str]
state: Mapped[FilingState] = mapped_column(SAEnum(FilingState))
tasks: Mapped[List[FilingTaskStateDAO]] = relationship(lazy="selectin", cascade="all, delete-orphan")
filing_period: Mapped[int] = mapped_column(ForeignKey("filing_period.id"))
institution_snapshot_id: Mapped[str]
contact_info: Mapped[str] = mapped_column(nullable=True)
Expand Down
23 changes: 20 additions & 3 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Dict, Any
from typing import Dict, Any, List
from pydantic import BaseModel, ConfigDict
from .model_enums import FilingType, FilingState, SubmissionState
from .model_enums import FilingType, FilingTaskState, SubmissionState


class SubmissionDTO(BaseModel):
Expand All @@ -16,12 +16,29 @@ class SubmissionDTO(BaseModel):
confirmation_id: str | None = None


class FilingTaskDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

name: str
task_order: int


class FilingTaskStateDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

filing: int
task: FilingTaskDTO
user: str | None = None
state: FilingTaskState
change_timestamp: datetime


class FilingDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int | None = None
lei: str
state: FilingState
tasks: List[FilingTaskStateDTO]
filing_period: int
institution_snapshot_id: str
contact_info: str | None = None
Expand Down
9 changes: 4 additions & 5 deletions src/entities/models/model_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ class SubmissionState(str, Enum):
SUBMISSION_SIGNED = "SUBMISSION_SIGNED"


class FilingState(str, Enum):
FILING_STARTED = "FILING_STARTED"
FILING_INSTITUTION_APPROVED = "FILING_INSTITUTION_APPROVED"
FILING_IN_PROGRESS = "FILING_IN_PROGRESS"
FILING_COMPLETE = "FILING_COMPLETE"
class FilingTaskState(str, Enum):
NOT_STARTED = "NOT_STARTED"
IN_PROGRESS = "IN_PROGRESS"
COMPLETED = "COMPLETED"


class FilingType(str, Enum):
Expand Down
8 changes: 8 additions & 0 deletions src/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FilingPeriodDTO,
FilingDTO,
FilingDAO,
FilingTaskDAO,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,6 +49,13 @@ async def get_filing_period(session: AsyncSession, filing_period_id: int) -> Fil
return await query_helper(session, filing_period_id, FilingPeriodDAO)


async def get_filing_tasks(session: AsyncSession) -> List[FilingTaskDAO]:
async with session.begin():
stmt = select(FilingTaskDAO)
results = await session.scalars(stmt)
return results.all()


async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> SubmissionDAO:
async with session.begin():
new_sub = SubmissionDAO(
Expand Down
57 changes: 39 additions & 18 deletions tests/entities/repos/test_submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
FilingPeriodDTO,
FilingDAO,
FilingDTO,
FilingTaskStateDAO,
FilingTaskDAO,
FilingType,
FilingState,
FilingTaskState,
SubmissionState,
)
from entities.repos import submission_repo as repo
Expand All @@ -29,6 +31,11 @@ async def setup(
):
mocker.patch.object(entities_engine, "SessionLocal", return_value=session_generator)

filing_task_1 = FilingTaskDAO(name="Task-1", task_order=1)
filing_task_2 = FilingTaskDAO(name="Task-2", task_order=2)
transaction_session.add(filing_task_1)
transaction_session.add(filing_task_2)

filing_period = FilingPeriodDAO(
name="FilingPeriod2024",
start_period=datetime.now(),
Expand All @@ -40,13 +47,11 @@ async def setup(

filing1 = FilingDAO(
lei="1234567890",
state=FilingState.FILING_STARTED,
institution_snapshot_id="Snapshot-1",
filing_period=1,
)
filing2 = FilingDAO(
lei="ABCDEFGHIJ",
state=FilingState.FILING_STARTED,
institution_snapshot_id="Snapshot-1",
filing_period=1,
)
Expand Down Expand Up @@ -101,28 +106,44 @@ async def test_get_filing_period(self, query_session: AsyncSession):
assert res.filing_type == FilingType.MANUAL

async def test_add_and_modify_filing(self, transaction_session: AsyncSession):
new_filing = FilingDTO(
lei="12345ABCDE",
state=FilingState.FILING_IN_PROGRESS,
institution_snapshot_id="Snapshot-1",
filing_period=1,
)
new_filing = FilingDTO(lei="12345ABCDE", institution_snapshot_id="Snapshot-1", filing_period=1, tasks=[])
res = await repo.upsert_filing(transaction_session, new_filing)
assert res.id == 3
assert res.lei == "12345ABCDE"
assert res.state == FilingState.FILING_IN_PROGRESS
assert res.institution_snapshot_id == "Snapshot-1"

mod_filing = FilingDTO(
id=3,
lei="12345ABCDE",
state=FilingState.FILING_COMPLETE,
institution_snapshot_id="Snapshot-1",
filing_period=1,
)
mod_filing = FilingDTO(id=3, lei="12345ABCDE", institution_snapshot_id="Snapshot-2", filing_period=1, tasks=[])
res = await repo.upsert_filing(transaction_session, mod_filing)
assert res.id == 3
assert res.lei == "12345ABCDE"
assert res.state == FilingState.FILING_COMPLETE
assert res.institution_snapshot_id == "Snapshot-2"

async def test_get_filing_tasks(self, transaction_session: AsyncSession):
tasks = await repo.get_filing_tasks(transaction_session)
assert len(tasks) == 2
assert tasks[0].name == "Task-1"
assert tasks[1].name == "Task-2"

async def test_add_task_to_filing(self, query_session: AsyncSession, transaction_session: AsyncSession):
filing = await repo.get_filing(query_session, filing_id=1)
task = await query_session.scalar(select(FilingTaskDAO).where(FilingTaskDAO.name == "Task-1"))
filing_task = FilingTaskStateDAO(
filing=filing.id, task=task, user="[email protected]", state=FilingTaskState.IN_PROGRESS
)
filing.tasks = [filing_task]
seconds_now = datetime.utcnow().timestamp()
await repo.upsert_filing(transaction_session, filing)

filing_task_states = (await transaction_session.scalars(select(FilingTaskStateDAO))).all()

assert len(filing_task_states) == 1
assert filing_task_states[0].task.name == "Task-1"
assert filing_task_states[0].filing == 1
assert filing_task_states[0].state == FilingTaskState.IN_PROGRESS
assert filing_task_states[0].user == "[email protected]"
assert filing_task_states[0].change_timestamp.timestamp() == pytest.approx(
seconds_now, abs=1.0
) # allow for possible 1 second difference

async def test_get_filing(self, query_session: AsyncSession):
res = await repo.get_filing_period(query_session, filing_period_id=1)
Expand Down
Loading
Loading