Skip to content

Commit

Permalink
Refactor processing of tags
Browse files (browse the repository at this point in the history)
  • Loading branch information
florimondmanca committed Oct 19, 2022
1 parent 77940b2 commit 5768147
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tools/import_catalog.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,7 @@
from server.application.organizations.queries import GetOrganizationBySiret
from server.application.tags.queries import GetAllTags
from server.config.di import bootstrap, resolve
from server.domain.common.types import id_factory
from server.domain.common.types import ID, id_factory
from server.domain.datasets.entities import DataFormat, UpdateFrequency
from server.domain.organizations.types import Siret
from server.seedwork.application.messages import MessageBus
Expand Down Expand Up @@ -126,25 +126,29 @@ def _map_last_updated_at(value: Optional[str], config: Config) -> Optional[dt.da


def _map_tag_ids(
value: Optional[str], tags_by_name: Mapping[str, dict], tags_to_create: List[dict]
value: Optional[str],
existing_tag_ids_by_name: Mapping[str, ID],
tags_to_create: List[dict],
) -> List[str]:
if not value:
return []

tag_ids = []

# Split and normalize tag names. For example:
# "périmètre délimité des abords (PDA), urbanisme; géolocalisation"
# -> {"périmètre délimité des abords (PDA)", "urbanisme", "géolocalisation"}
# -> {"périmètre délimité des abords (PDA)", "urbanisme", "géolocalisation"}
cleaned_names = set(name.strip() for name in value.replace(";", ",").split(","))

for name in cleaned_names:
try:
tag = tags_by_name[name]
tag_id = existing_tag_ids_by_name[name]
except KeyError:
tag = {"id": str(id_factory()), "params": {"name": name}}
tag_id = id_factory()
tag = {"id": str(tag_id), "params": {"name": name}}
tags_to_create.append(tag)

tag_ids.append(str(tag["id"]))
tag_ids.append(str(tag_id))

return tag_ids

Expand Down Expand Up @@ -187,7 +191,7 @@ async def main(config_path: Path, out_path: Path) -> int:
expected_extra_fields = {field.name for field in catalog.extra_fields}

tags = await bus.execute(GetAllTags())
tags_by_name = {tag.name: tag.dict() for tag in tags}
existing_tag_ids_by_name = {tag.name: tag.id for tag in tags}

with config.input_csv.path.open(encoding=config.input_csv.encoding) as f:
reader = csv.DictReader(f, delimiter=config.input_csv.delimiter)
Expand Down Expand Up @@ -258,7 +262,7 @@ async def main(config_path: Path, out_path: Path) -> int:
params["url"] = row["url"] or None
params["license"] = row["licence"] or None
params["tag_ids"] = _map_tag_ids(
row["mots_cles"] or None, tags_by_name, tags_to_create
row["mots_cles"] or None, existing_tag_ids_by_name, tags_to_create
)

params["extra_field_values"] = _map_extra_field_values(
Expand Down

0 comments on commit 5768147

Please sign in to comment.