diff --git a/alembic/versions/0019_add_entity_counters.py b/alembic/versions/0019_add_entity_counters.py new file mode 100644 index 00000000..5d1a5fe2 --- /dev/null +++ b/alembic/versions/0019_add_entity_counters.py @@ -0,0 +1,37 @@ +"""add entity counters + +Revision ID: 0019 +Revises: 0018 +Create Date: 2023-10-02 11:32:43.825217 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0019' +down_revision = '0018' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('entity_counter', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('description', sa.String(), nullable=False), + sa.Column('prefix', sa.String(), nullable=False), + sa.Column('counter', sa.Integer(), nullable=True), + sa.CheckConstraint("prefix IN ('CCLW','UNFCCC')", name=op.f('ck_entity_counter__prefix_allowed_orgs')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_entity_counter')), + sa.UniqueConstraint('prefix', name=op.f('uq_entity_counter__prefix')) + ) + # ### end Alembic commands ### + + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('entity_counter') + # ### end Alembic commands ### diff --git a/app/api/api_v1/routers/cclw_ingest.py b/app/api/api_v1/routers/cclw_ingest.py index 04391037..2f8467f3 100644 --- a/app/api/api_v1/routers/cclw_ingest.py +++ b/app/api/api_v1/routers/cclw_ingest.py @@ -46,6 +46,7 @@ write_documents_to_s3, write_ingest_results_to_s3, ) +from app.db.models.app import ORGANISATION_CCLW from app.db.session import get_db _LOGGER = logging.getLogger(__name__) @@ -63,7 +64,7 @@ def _start_ingest( context = None # TODO: add a way for a user to monitor progress of the ingest try: - context = initialise_context(db, "CCLW") + context = initialise_context(db, ORGANISATION_CCLW) document_ingestor = get_cclw_document_ingestor(db, context) read(documents_file_contents, context, CCLWDocumentIngestRow, document_ingestor) event_ingestor = get_event_ingestor(db) @@ -135,7 +136,7 @@ def validate_law_policy( ) try: - context = initialise_context(db, "CCLW") + context = initialise_context(db, ORGANISATION_CCLW) except Exception as e: _LOGGER.exception( "Failed to create ingest context", extra={"props": {"errors": str(e)}} @@ -206,7 +207,7 @@ def ingest_law_policy( ) try: - context = initialise_context(db, "CCLW") + context = initialise_context(db, ORGANISATION_CCLW) except Exception as e: _LOGGER.exception( "Failed to create ingest context", extra={"props": {"errors": str(e)}} diff --git a/app/api/api_v1/routers/unfccc_ingest.py b/app/api/api_v1/routers/unfccc_ingest.py index 755b3d7d..c598fa23 100644 --- a/app/api/api_v1/routers/unfccc_ingest.py +++ b/app/api/api_v1/routers/unfccc_ingest.py @@ -46,6 +46,7 @@ write_documents_to_s3, write_ingest_results_to_s3, ) +from app.db.models.app import ORGANISATION_UNFCCC from app.db.session import get_db _LOGGER = logging.getLogger(__name__) @@ -63,7 +64,7 @@ def start_unfccc_ingest( context = None # TODO: add a way for a user to monitor progress of the ingest try: - context = initialise_context(db, "UNFCCC") + context = initialise_context(db, ORGANISATION_UNFCCC) # First the collections.... collection_ingestor = get_collection_ingestor(db) read( @@ -149,7 +150,7 @@ def validate_unfccc_law_policy( ) try: - context = initialise_context(db, "UNFCCC") + context = initialise_context(db, ORGANISATION_UNFCCC) except Exception as e: _LOGGER.exception( "Failed to create ingest context", extra={"props": {"errors": str(e)}} @@ -230,7 +231,7 @@ def ingest_unfccc_law_policy( ) try: - context = initialise_context(db, "UNFCCC") + context = initialise_context(db, ORGANISATION_UNFCCC) except Exception as e: _LOGGER.exception( "Failed to create ingest context", extra={"props": {"errors": str(e)}} diff --git a/app/core/ingestion/processor.py b/app/core/ingestion/processor.py index 251a6d21..a8990945 100644 --- a/app/core/ingestion/processor.py +++ b/app/core/ingestion/processor.py @@ -35,6 +35,7 @@ validate_cclw_document_row, validate_unfccc_document_row, ) +from app.db.models.app import ORGANISATION_CCLW, ORGANISATION_UNFCCC from app.db.models.app.users import Organisation from app.db.models.law_policy.geography import GEO_INTERNATIONAL, GEO_NONE @@ -235,11 +236,11 @@ def initialise_context(db: Session, org_name: str) -> IngestContext: """ with db.begin(): organisation = db.query(Organisation).filter_by(name=org_name).one() - if org_name == "CCLW": + if org_name == ORGANISATION_CCLW: return CCLWIngestContext( org_name=org_name, org_id=cast(int, organisation.id), results=[] ) - if org_name == "UNFCCC": + if org_name == ORGANISATION_UNFCCC: return UNFCCCIngestContext( org_name=org_name, org_id=cast(int, organisation.id), results=[] ) @@ -366,9 +367,9 @@ def unfccc_process(context: IngestContext, row: UNFCCCDocumentIngestRow) -> None row=row, ) - if context.org_name == "CCLW": + if context.org_name == ORGANISATION_CCLW: return cclw_process - elif context.org_name == "UNFCCC": + elif context.org_name == ORGANISATION_UNFCCC: return unfccc_process raise ValueError(f"Unknown org {context.org_name} for validation.") diff --git a/app/core/ingestion/utils.py b/app/core/ingestion/utils.py index 48dd5bd6..697f910d 100644 --- a/app/core/ingestion/utils.py +++ b/app/core/ingestion/utils.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import enum from typing import Any, Callable, Optional, TypeVar, cast +from app.db.models.app import ORGANISATION_CCLW, ORGANISATION_UNFCCC from app.db.session import AnyModel from sqlalchemy.orm import Session @@ -229,7 +230,7 @@ class UNFCCCIngestContext(IngestContext): consistency_validator: ConsistencyValidator download_urls: dict[str, str] # import_id -> url - def __init__(self, org_name="UNFCCC", org_id=2, results=None): + def __init__(self, org_name=ORGANISATION_UNFCCC, org_id=2, results=None): self.collection_ids_defined = [] self.collection_ids_referenced = [] self.consistency_validator = ConsistencyValidator() @@ -245,7 +246,7 @@ class CCLWIngestContext(IngestContext): consistency_validator: ConsistencyValidator - def __init__(self, org_name="CCLW", org_id=1, results=None): + def __init__(self, org_name=ORGANISATION_CCLW, org_id=1, results=None): self.consistency_validator = ConsistencyValidator() self.org_name = org_name self.org_id = org_id diff --git a/app/data_migrations/__init__.py b/app/data_migrations/__init__.py index cbda1411..b04ef6ed 100644 --- a/app/data_migrations/__init__.py +++ b/app/data_migrations/__init__.py @@ -7,3 +7,4 @@ from .populate_geography import populate_geography from .populate_language import populate_language from .populate_taxonomy import populate_taxonomy +from .populate_counters import populate_counters diff --git a/app/data_migrations/populate_counters.py b/app/data_migrations/populate_counters.py new file mode 100644 index 00000000..9a8cace7 --- /dev/null +++ b/app/data_migrations/populate_counters.py @@ -0,0 +1,23 @@ +from sqlalchemy.orm import Session + +from app.db.models.app.counters import ( + ORGANISATION_CCLW, + ORGANISATION_UNFCCC, + EntityCounter, +) + + +def populate_counters(db: Session): + n_rows = db.query(EntityCounter).count() + if n_rows == 0: + db.add( + EntityCounter( + prefix=ORGANISATION_CCLW, description="Counter for CCLW entities" + ) + ) + db.add( + EntityCounter( + prefix=ORGANISATION_UNFCCC, description="Counter for UNFCCC entities" + ) + ) + db.commit() diff --git a/app/data_migrations/populate_taxonomy.py b/app/data_migrations/populate_taxonomy.py index 680a884c..f0d05b8d 100644 --- a/app/data_migrations/populate_taxonomy.py +++ b/app/data_migrations/populate_taxonomy.py @@ -2,6 +2,7 @@ from sqlalchemy.orm import Session from app.data_migrations.taxonomy_cclw import get_cclw_taxonomy from app.data_migrations.taxonomy_unf3c import get_unf3c_taxonomy +from app.db.models.app import ORGANISATION_CCLW, ORGANISATION_UNFCCC from app.db.models.app.users import Organisation from app.db.models.law_policy.metadata import MetadataOrganisation, MetadataTaxonomy @@ -53,14 +54,14 @@ def populate_org_taxonomy( def populate_taxonomy(db: Session) -> None: populate_org_taxonomy( db, - org_name="CCLW", + org_name=ORGANISATION_CCLW, org_type="Academic", description="Climate Change Laws of the World", fn_get_taxonomy=get_cclw_taxonomy, ) populate_org_taxonomy( db, - org_name="UNFCCC", + org_name=ORGANISATION_UNFCCC, org_type="UN", description="United Nations Framework Convention on Climate Change", fn_get_taxonomy=get_unf3c_taxonomy, diff --git a/app/db/models/app/__init__.py b/app/db/models/app/__init__.py index a8f2ebed..530c0803 100644 --- a/app/db/models/app/__init__.py +++ b/app/db/models/app/__init__.py @@ -1 +1,2 @@ from .users import AppUser, OrganisationUser, Organisation +from .counters import EntityCounter, ORGANISATION_CCLW, ORGANISATION_UNFCCC diff --git a/app/db/models/app/counters.py b/app/db/models/app/counters.py new file mode 100644 index 00000000..7ce59704 --- /dev/null +++ b/app/db/models/app/counters.py @@ -0,0 +1,107 @@ +""" +Schema for counters. + +The following section includes the necessary schema for maintaining the counts +of different entity types. These are scoped per "data source" - however the +concept of "data source" is not yet implemented, see PDCT-431. +""" +import logging +from enum import Enum +import sqlalchemy as sa +from sqlalchemy.sql import text +from app.db.session import Base +from sqlalchemy.orm.session import object_session + + +_LOGGER = logging.getLogger(__name__) + +# +# DO NOT ADD TO THIS LIST BELOW +# +# NOTE: These need to change when we introduce "Data source" (PDCT-431) +ORGANISATION_CCLW = "CCLW" +ORGANISATION_UNFCCC = "UNFCCC" + + +class CountedEntity(str, Enum): + """Entities that are to be counted.""" + + Collection = "collection" + Family = "family" + Document = "document" + Event = "event" + + +class EntityCounter(Base): + """ + A list of entity counters per organisation name. + + NOTE: There is no foreign key, as this is expected to change + when we introduce data sources (PDCT-431). So at this time a + FK to the new datasource table should be introduced. + + This is used for generating import_ids in the following format: + + ... + + """ + + __tablename__ = "entity_counter" + __table_args__ = ( + sa.CheckConstraint( + "prefix IN ('CCLW','UNFCCC')", + name="prefix_allowed_orgs", + ), + ) + + _get_and_increment = text( + """ + WITH updated AS ( + UPDATE entity_counter SET counter = counter + 1 + WHERE id = :id RETURNING counter + ) + SELECT counter FROM updated; + """ + ) + + id = sa.Column(sa.Integer, primary_key=True) + description = sa.Column(sa.String, nullable=False, default="") + prefix = sa.Column(sa.String, unique=True, nullable=False) # Organisation.name + counter = sa.Column(sa.Integer, default=0) + + def get_next_count(self) -> str: + """ + Gets the next counter value and updates the row. + + :return str: The next counter value. + """ + try: + db = object_session(self) + cmd = self._get_and_increment.bindparams(id=self.id) + value = db.execute(cmd).scalar() + db.commit() + return value + except: + _LOGGER.exception(f"When generating counter for {self.prefix}") + raise + + def create_import_id(self, entity: CountedEntity) -> str: + """ + Creates a unique import id. + + This uses the n-value of zero to conform to existing format. + + :param CountedEntity entity: The entity you want counted + :raises RuntimeError: raised when the prefix is not an organisation. + :return str: The fully formatted import_id + """ + # Validation + prefix_ok = ( + self.prefix == ORGANISATION_CCLW or self.prefix == ORGANISATION_UNFCCC + ) + if not prefix_ok: + raise RuntimeError("Prefix is not a known organisation!") + n = 0 # The fourth quad is historical + i_value = str(self.get_next_count()).zfill(8) + n_value = str(n).zfill(4) + return f"{self.prefix}.{entity.value}.i{i_value}.n{n_value}" diff --git a/app/initial_data.py b/app/initial_data.py index 64344184..255793f7 100644 --- a/app/initial_data.py +++ b/app/initial_data.py @@ -13,6 +13,7 @@ from app.db.session import SessionLocal from app.data_migrations import ( + populate_counters, populate_document_type, populate_document_role, populate_document_variant, @@ -33,6 +34,7 @@ def run_data_migrations(db): populate_geography(db) populate_language(db) populate_taxonomy(db) + populate_counters(db) db.flush() # Geography data is used by geo-stats so flush diff --git a/tests/unit/app/models/test_counters.py b/tests/unit/app/models/test_counters.py new file mode 100644 index 00000000..1f555276 --- /dev/null +++ b/tests/unit/app/models/test_counters.py @@ -0,0 +1,24 @@ +from app.data_migrations import populate_counters +from app.db.models.app.counters import CountedEntity, EntityCounter + + +def test_import_id_generation(test_db): + populate_counters(test_db) + rows = test_db.query(EntityCounter).count() + assert rows > 0 + + row: EntityCounter = ( + test_db.query(EntityCounter).filter(EntityCounter.prefix == "CCLW").one() + ) + assert row is not None + + assert row.prefix == "CCLW" + assert row.counter == 0 + + import_id = row.create_import_id(CountedEntity.Family) + assert import_id == "CCLW.family.i00000001.n0000" + + row: EntityCounter = ( + test_db.query(EntityCounter).filter(EntityCounter.prefix == "CCLW").one() + ) + assert row.counter == 1