Skip to content

Commit

Permalink
Merge pull request #194 from spraakbanken/fix-sql-entries
Browse files Browse the repository at this point in the history
Avoid name clashes in SqlEntryRepository
  • Loading branch information
kod-kristoff authored Mar 28, 2022
2 parents 0447b17 + f35f158 commit c52e123
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 15 deletions.
8 changes: 5 additions & 3 deletions karp/lex_infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
)
from karp.lex_infrastructure.repositories import (
SqlEntryUowRepositoryUnitOfWork,
SqlEntryUowCreator,
SqlEntryUowV1Creator,
SqlEntryUowV2Creator,
SqlResourceUnitOfWork,
)

Expand Down Expand Up @@ -102,8 +103,9 @@ def resources_uow(
@injector.multiprovider
def entry_uow_creator_map(self) -> Dict[str, EntryUnitOfWorkCreator]:
return {
'default': SqlEntryUowCreator,
SqlEntryUowCreator.repository_type: SqlEntryUowCreator,
'default': SqlEntryUowV2Creator,
SqlEntryUowV1Creator.repository_type: SqlEntryUowV1Creator,
SqlEntryUowV2Creator.repository_type: SqlEntryUowV2Creator,
}


Expand Down
2 changes: 1 addition & 1 deletion karp/lex_infrastructure/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .sql_entry_uows import SqlEntryUowRepository, SqlEntryUowRepositoryUnitOfWork
from .sql_entries import SqlEntryUowCreator
from .sql_entries import SqlEntryUowV1Creator, SqlEntryUowV2Creator
from .sql_resources import SqlResourceRepository, SqlResourceUnitOfWork
57 changes: 46 additions & 11 deletions karp/lex_infrastructure/repositories/sql_entries.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""SQL repositories for entries."""
import inspect
import logging
import typing
from typing import Dict, List, Optional, Tuple
from uuid import UUID
from typing import Dict, List, Optional, Generic, TypeVar

import injector
import regex
import sqlalchemy as sa
from sqlalchemy import sql
from sqlalchemy.orm import sessionmaker
import logging
import ulid

from karp.foundation.value_objects import UniqueId
from karp.foundation.events import EventBus
Expand Down Expand Up @@ -190,7 +188,7 @@ def _save(self, entry: Entry):
{
'entry_by_entry_id': entry_by_entry_id,
'entry_by_entity_id': entry_by_entity_id,
'entry': entry.dict(),
'entry': entry.dict(),
}
)
raise RuntimeError(f'entry = {entry.dict()}')
Expand Down Expand Up @@ -564,7 +562,7 @@ class SqlEntryUnitOfWork(
SqlUnitOfWork,
repositories.EntryUnitOfWork,
):
repository_type: str = 'sql_entries_v1'
repository_type: str = 'sql_entries_base'

def __init__(
self,
Expand All @@ -585,12 +583,15 @@ def _begin(self):
self._session = self.session_factory()
if self._entries is None:
self._entries = SqlEntryRepository.from_dict(
name=self.name,
name=self.table_name(),
resource_config=self.config,
session=self._session
)
return self

def table_name(self) -> str:
return self.name

@property
def repo(self) -> SqlEntryRepository:
if self._entries is None:
Expand All @@ -606,6 +607,20 @@ def collect_new_events(self) -> typing.Iterable:
return super().collect_new_events()
else:
return []


class SqlEntryUnitOfWorkV1(SqlEntryUnitOfWork):
repository_type: str = 'sql_entries_v1'


class SqlEntryUnitOfWorkV2(SqlEntryUnitOfWork):
repository_type: str = 'sql_entries_v2'

def table_name(self) -> str:
u = ulid.from_uuid(self.entity_id)
random_part = u.randomness().str
return f"{self.name}_{random_part}"

# ===== Value objects =====
# class SqlEntryRepositorySettings(EntryRepositorySettings):
# def __init__(self, *, table_name: str, config: Dict):
Expand All @@ -623,9 +638,11 @@ def collect_new_events(self) -> typing.Iterable:
# runtime_table_name, history_model, settings.config
# )
# return SqlEntryRepository(history_model, runtime_model, settings.config)
SqlEntryUowType = TypeVar('SqlEntryUowType', bound=SqlEntryUnitOfWork)

class SqlEntryUowCreator:
repository_type: str = SqlEntryUnitOfWork.repository_type

class SqlEntryUowCreator(Generic[SqlEntryUowType]):
repository_type: str = "repository_type"

@injector.inject
def __init__(
Expand All @@ -646,9 +663,9 @@ def __call__(
user: str,
message: str,
timestamp: float,
) -> SqlEntryUnitOfWork:
) -> SqlEntryUowType:
if entity_id not in self.cache:
self.cache[entity_id] = SqlEntryUnitOfWork(
self.cache[entity_id] = self._create_uow(
entity_id=entity_id,
name=name,
config=config,
Expand All @@ -660,3 +677,21 @@ def __call__(
event_bus=self.event_bus,
)
return self.cache[entity_id]


class SqlEntryUowV1Creator(
SqlEntryUowCreator[SqlEntryUnitOfWorkV1]
):
repository_type: str = "sql_entries_v1"

def _create_uow(self, **kwargs) -> SqlEntryUnitOfWorkV1:
return SqlEntryUnitOfWorkV1(**kwargs)


class SqlEntryUowV2Creator(
SqlEntryUowCreator[SqlEntryUnitOfWorkV2]
):
repository_type: str = "sql_entries_v2"

def _create_uow(self, **kwargs) -> SqlEntryUnitOfWorkV2:
return SqlEntryUnitOfWorkV2(**kwargs)
89 changes: 89 additions & 0 deletions karp/tests/integration/test_sql_entries_uow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from unittest import mock

import pytest
import ulid

from karp.foundation.events import EventBus
from karp import lex
from karp.lex_infrastructure import SqlEntryUowV1Creator, SqlEntryUowV2Creator
from karp.tests.unit.lex import factories


@pytest.fixture
def example_uow() -> lex.CreateEntryRepository:
return factories.CreateEntryRepositoryFactory()


@pytest.fixture
def sql_entry_uow_v1_creator(sqlite_session_factory) -> SqlEntryUowV1Creator:
return SqlEntryUowV1Creator(
event_bus=mock.Mock(spec=EventBus),
session_factory=sqlite_session_factory,
)


@pytest.fixture
def sql_entry_uow_v2_creator(sqlite_session_factory) -> SqlEntryUowV2Creator:
return SqlEntryUowV2Creator(
event_bus=mock.Mock(spec=EventBus),
session_factory=sqlite_session_factory,
)


class TestSqlEntryUowV1:
def test_creator_repository_type(
self,
sql_entry_uow_v1_creator: SqlEntryUowV1Creator,
):
assert sql_entry_uow_v1_creator.repository_type == 'sql_entries_v1'

def test_uow_repository_type(
self,
sql_entry_uow_v1_creator: SqlEntryUowV1Creator,
example_uow: lex.CreateEntryRepository,
):
entry_uow = sql_entry_uow_v1_creator(
**example_uow.dict(exclude={'repository_type'})
)
assert entry_uow.repository_type == 'sql_entries_v1'

def test_repo_table_name(
self,
sql_entry_uow_v1_creator: SqlEntryUowV1Creator,
example_uow: lex.CreateEntryRepository,
):
entry_uow = sql_entry_uow_v1_creator(
**example_uow.dict(exclude={'repository_type'})
)
with entry_uow as uw:
assert uw.repo.history_model.__tablename__ == example_uow.name


class TestSqlEntryUowV2:
def test_creator_repository_type(
self,
sql_entry_uow_v2_creator: SqlEntryUowV2Creator,
):
assert sql_entry_uow_v2_creator.repository_type == 'sql_entries_v2'

def test_uow_repository_type(
self,
sql_entry_uow_v2_creator: SqlEntryUowV2Creator,
example_uow: lex.CreateEntryRepository,
):
entry_uow = sql_entry_uow_v2_creator(
**example_uow.dict(exclude={'repository_type'})
)
assert entry_uow.repository_type == 'sql_entries_v2'

def test_repo_table_name(
self,
sql_entry_uow_v2_creator: SqlEntryUowV2Creator,
example_uow: lex.CreateEntryRepository,
):
entry_uow = sql_entry_uow_v2_creator(
**example_uow.dict(exclude={'repository_type'})
)
random_part = ulid.from_uuid(entry_uow.entity_id).randomness().str
with entry_uow as uw:
assert uw.repo.history_model.__tablename__ == f'{example_uow.name}_{random_part}'

0 comments on commit c52e123

Please sign in to comment.