From d3e01f06ac12f391cbdf60eec976275a92b35ea1 Mon Sep 17 00:00:00 2001 From: diversemix Date: Tue, 3 Oct 2023 10:59:38 +0100 Subject: [PATCH] add test and tidy --- app/db/models/app/counters.py | 18 +++++++++++++++--- tests/unit/app/models/test_counters.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 tests/unit/app/models/test_counters.py diff --git a/app/db/models/app/counters.py b/app/db/models/app/counters.py index 8c00c06c..7ce59704 100644 --- a/app/db/models/app/counters.py +++ b/app/db/models/app/counters.py @@ -70,7 +70,11 @@ class EntityCounter(Base): counter = sa.Column(sa.Integer, default=0) def get_next_count(self) -> str: - """Gets the next counter value""" + """ + Gets the next counter value and updates the row. + + :return str: The next counter value. + """ try: db = object_session(self) cmd = self._get_and_increment.bindparams(id=self.id) @@ -81,8 +85,16 @@ def get_next_count(self) -> str: _LOGGER.exception(f"When generating counter for {self.prefix}") raise - def get_import_id(self, entity: CountedEntity) -> str: - """gets an import id""" + def create_import_id(self, entity: CountedEntity) -> str: + """ + Creates a unique import id. + + This uses the n-value of zero to conform to existing format. + + :param CountedEntity entity: The entity you want counted + :raises RuntimeError: raised when the prefix is not an organisation. + :return str: The fully formatted import_id + """ # Validation prefix_ok = ( self.prefix == ORGANISATION_CCLW or self.prefix == ORGANISATION_UNFCCC diff --git a/tests/unit/app/models/test_counters.py b/tests/unit/app/models/test_counters.py new file mode 100644 index 00000000..9893b6b7 --- /dev/null +++ b/tests/unit/app/models/test_counters.py @@ -0,0 +1,24 @@ +from app.data_migrations import populate_counters +from app.db.models.app.counters import CountedEntity, EntityCounter + + +def test_import_id_generation(test_db): + populate_counters(test_db) + rows = test_db.query(EntityCounter).count() + assert rows > 0 + + row: EntityCounter = ( + test_db.query(EntityCounter).filter(EntityCounter.prefix == "CCLW").one() + ) + assert row is not None + + assert row.prefix == "CCLW" + assert row.counter == 0 + + import_id = row.create_import_id(CountedEntity.Family) + assert import_id == "CCLW.family.i00000001.n0001" + + row: EntityCounter = ( + test_db.query(EntityCounter).filter(EntityCounter.prefix == "CCLW").one() + ) + assert row.counter == 1