diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 6882f1d5..bb026e55 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -56,9 +56,9 @@ def _search_request(db: Session, search_body: SearchRequestBody) -> SearchRespon search_body.documents_only = True search_body.exact_match = False try: - data_access_search_params = create_vespa_search_params(db, search_body) - data_access_search_response = _VESPA_CONNECTION.search( - parameters=data_access_search_params + cpr_sdk_search_params = create_vespa_search_params(db, search_body) + cpr_sdk_search_response = _VESPA_CONNECTION.search( + parameters=cpr_sdk_search_params ) except QueryError as e: raise HTTPException( @@ -67,7 +67,7 @@ def _search_request(db: Session, search_body: SearchRequestBody) -> SearchRespon ) return process_vespa_search_response( db, - data_access_search_response, + cpr_sdk_search_response, limit=search_body.page_size, offset=search_body.offset, ).increment_pages() diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index bef5c6a8..530ccaac 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -251,6 +251,11 @@ class SearchResponseFamily(BaseModel): The geographical location of the family in ISO 3166-1 alpha-3 """ + family_geographies: List[str] + """ + The geographical locations of the family in ISO 3166-1 alpha-3 + """ + family_metadata: dict """ An object if metadata for the family, the schema will change given the family_source diff --git a/app/core/browse.py b/app/core/browse.py index 065edf93..e0f7288a 100644 --- a/app/core/browse.py +++ b/app/core/browse.py @@ -57,6 +57,7 @@ def to_search_response_family( family_last_updated_date=family_last_updated_date, family_source=cast(str, organisation.name), family_geography=geography_value, + family_geographies=[row.value for row in family.geographies], family_title_match=False, family_description_match=False, # ↓ Stuff we don't currently use for browse ↓ diff --git a/app/core/search.py b/app/core/search.py index c495c23c..6ba68d84 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -273,9 +273,9 @@ def _convert_filter_field(filter_field: str) -> Optional[str]: if filter_field == FilterField.CATEGORY: return filter_fields["category"] if filter_field == FilterField.COUNTRY: - return filter_fields["geography"] + return filter_fields["geographies"] if filter_field == FilterField.REGION: - return filter_fields["geography"] + return filter_fields["geographies"] if filter_field == FilterField.LANGUAGE: return filter_fields["language"] if filter_field == FilterField.SOURCE: @@ -310,7 +310,7 @@ def _convert_filters( new_keyword_filters[new_field] = new_values # Regions and countries filters should only include the overlap - geo_field = filter_fields["geography"] + geo_field = filter_fields["geographies"] if regions and countries: values = list(set(countries).intersection(regions)) if values: @@ -326,6 +326,7 @@ def _convert_filters( return None +# TODO: Add a test for this function def _process_vespa_search_response_families( db: Session, vespa_families: Sequence[CprSdkResponseFamily], @@ -341,6 +342,7 @@ def _process_vespa_search_response_families( vespa_families_to_process = vespa_families[offset : limit + offset] all_response_family_ids = [vf.id for vf in vespa_families_to_process] + # TODO: Potential disparity between what's in postgres and vespa family_and_family_metadata: Sequence[tuple[Family, FamilyMetadata]] = ( db.query(Family, FamilyMetadata) .filter(Family.import_id.in_(all_response_family_ids)) @@ -391,6 +393,7 @@ def _process_vespa_search_response_families( or hit.family_category is None or hit.family_source is None or hit.family_geography is None + or hit.family_geographies is None ): _LOGGER.error( "Skipping hit with empty required family info for import " @@ -424,6 +427,7 @@ def _process_vespa_search_response_families( prev_continuation_token=vespa_family.prev_continuation_token, family_documents=[], family_geography=hit.family_geography, + family_geographies=hit.family_geographies, family_metadata=cast(dict, db_family_metadata.value), ) response_family_lookup[family_import_id] = response_family @@ -478,7 +482,6 @@ def _process_vespa_search_response_families( response_families.append(response_family) response_family = None - return response_families diff --git a/poetry.lock b/poetry.lock index d6b315f0..b1ba25c2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -687,35 +687,35 @@ files = [ [[package]] name = "cpr-sdk" -version = "1.3.11" +version = "1.4.2" description = "" optional = false -python-versions = "<4.0,>=3.9" +python-versions = "<4.0,>=3.10" files = [ - {file = "cpr_sdk-1.3.11-py3-none-any.whl", hash = "sha256:4bcf0a172d17daaac37bd877af55c36db249c5fe6a439b840b4433326e6a3e2e"}, - {file = "cpr_sdk-1.3.11.tar.gz", hash = "sha256:df662c5947fae6af3cb5899cab3545747e7ac342abc64746ece564e0f4d16bac"}, + {file = "cpr_sdk-1.4.2-py3-none-any.whl", hash = "sha256:2dc20393c9623bb0f5d8fdde388a4c12b969df6115f083378e6628e1bed44ec4"}, + {file = "cpr_sdk-1.4.2.tar.gz", hash = "sha256:bfa7742c3e57c860b89da1c1406712b8aa4b17d6ae0f9b96030ecbd8d1cc5e32"}, ] [package.dependencies] aws-error-utils = ">=2.7.0,<3.0.0" -boto3 = ">=1.34.153,<2.0.0" -datasets = ">=2.19.2,<3.0.0" +boto3 = ">=1.35.14,<2.0.0" +datasets = ">=2.19.1,<3.0.0" deprecation = ">=2.1.0,<3.0.0" flatten-dict = ">=0.4.2,<0.5.0" langdetect = ">=1.0.9,<2.0.0" -numpy = ">=1.23.5" +numpy = ">=1.26.4,<2.0.0" pandas = ">=2.2.2,<3.0.0" poetry = ">=1.8.3,<2.0.0" -pydantic = ">=2.8.2,<3.0.0" +pydantic = ">=2.9.1,<3.0.0" pyvespa = {version = ">=0.45.0,<0.46.0", optional = true, markers = "extra == \"vespa\""} pyyaml = {version = ">=6.0.2,<7.0.0", optional = true, markers = "extra == \"vespa\""} sentence-transformers = {version = ">=2.2.2,<3.0.0", optional = true, markers = "extra == \"vespa\""} -torch = {version = ">=2.0.0,<3.0.0", optional = true, markers = "extra == \"vespa\""} -tqdm = ">=4.64.1,<5.0.0" +torch = {version = ">=2.0.0,<=2.2.2", optional = true, markers = "extra == \"vespa\""} +tqdm = ">=4.66.5,<5.0.0" [package.extras] spacy = ["spacy (>=3.5.1,<4.0.0)"] -vespa = ["pyvespa (>=0.45.0,<0.46.0)", "pyyaml (>=6.0.2,<7.0.0)", "sentence-transformers (>=2.2.2,<3.0.0)", "torch (>=2.0.0,<3.0.0)"] +vespa = ["pyvespa (>=0.45.0,<0.46.0)", "pyyaml (>=6.0.2,<7.0.0)", "sentence-transformers (>=2.2.2,<3.0.0)", "torch (>=2.0.0,<=2.2.2)"] [[package]] name = "crashtest" @@ -4853,4 +4853,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e577074f10b94b494aa291891e3b3fc4d29174e2cda7a09520b8094f06b5e210" +content-hash = "42467c6a6c180bb77ac30af1a721b5cf56cb8b777a32b5e63884322c10e131c0" diff --git a/pyproject.toml b/pyproject.toml index 11a21afc..9408bbec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "1.14.20" +version = "1.15.0" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] @@ -10,7 +10,7 @@ python = "^3.10" Authlib = "^0.15.5" bcrypt = "^3.2.0" boto3 = "^1.26" -cpr_sdk = { version = "1.3.11", extras = ["vespa"] } +cpr_sdk = { version = "1.4.2", extras = ["vespa"] } fastapi = "^0.104.1" fastapi-health = "^0.4.0" fastapi-pagination = { extras = ["sqlalchemy"], version = "^0.12.19" } diff --git a/tests/search/test_search.py b/tests/search/test_search.py index 332d5cdb..2a539c08 100644 --- a/tests/search/test_search.py +++ b/tests/search/test_search.py @@ -8,6 +8,7 @@ from cpr_sdk.models.search import Family as CprSdkFamily from cpr_sdk.models.search import Filters as CprSdkFilters from cpr_sdk.models.search import Hit as CprSdkHit +from cpr_sdk.models.search import MetadataFilter from cpr_sdk.models.search import Passage as CprSdkPassage from cpr_sdk.models.search import SearchResponse as CprSdkSearchResponse from cpr_sdk.models.search import filter_fields @@ -39,7 +40,7 @@ ( "query_string,exact_match,year_range,sort_field,sort_order," "keyword_filters,max_passages,page_size,offset,continuation_tokens," - "family_ids,document_ids" + "family_ids,document_ids,metadata_filters,corpus_type_names,corpus_import_ids" ), [ ( @@ -55,6 +56,9 @@ None, None, None, + None, + None, + None, ), ( "world", @@ -69,6 +73,9 @@ ["ABC"], None, None, + None, + None, + None, ), ( "hello", @@ -83,6 +90,9 @@ None, None, None, + None, + None, + None, ), ( "world", @@ -100,6 +110,9 @@ ["ABC"], None, None, + None, + None, + None, ), ( "hello", @@ -114,6 +127,9 @@ None, None, None, + None, + None, + None, ), ( "world", @@ -128,6 +144,9 @@ ["ABC", "ADDD"], None, None, + None, + None, + None, ), ( "hello", @@ -142,6 +161,9 @@ None, None, ["CCLW.document.1.0"], + None, + None, + None, ), ( "world", @@ -156,6 +178,9 @@ ["ABC"], ["CCLW.executive.1.0"], None, + None, + None, + None, ), ( "hello", @@ -170,6 +195,9 @@ None, ["CCLW.executive.1.0"], ["CCLW.document.1.0", "CCLW.document.2.0"], + None, + None, + None, ), ( "world", @@ -184,6 +212,63 @@ ["ABC"], ["CCLW.executive.1.0", "CCLW.executive.2.0"], ["CCLW.document.1.0", "CCLW.document.2.0"], + None, + None, + None, + ), + ( + "world", + True, + None, + None, + "desc", + None, + 10, + 10, + 0, + None, + None, + None, + [ + {"name": "family.sector", "value": "Price"}, + {"name": "family.topic", "value": "Mitigation"}, + ], + None, + None, + ), + ( + "world", + True, + None, + None, + "desc", + None, + 10, + 10, + 0, + None, + None, + None, + None, + ["UNFCCC Submissions", "Laws and Policies"], + None, + ), + ( + "world", + True, + None, + None, + "desc", + None, + 10, + 10, + 0, + None, + None, + None, + None, + None, + ["CCLW.corpus.1.0", "CCLW.corpus.2.0"], ), ], ) @@ -201,6 +286,9 @@ def test_create_vespa_search_params( continuation_tokens, family_ids, document_ids, + metadata_filters, + corpus_type_names, + corpus_import_ids, ): search_request_body = SearchRequestBody( query_string=query_string, @@ -215,6 +303,13 @@ def test_create_vespa_search_params( page_size=page_size, offset=offset, continuation_tokens=continuation_tokens, + corpus_type_names=corpus_type_names, + corpus_import_ids=corpus_import_ids, + metadata=( + [MetadataFilter.model_validate(mdata) for mdata in metadata_filters] + if metadata_filters is not None + else [] + ), ) # First step, just make sure we can create a validated pydantic model @@ -256,7 +351,13 @@ def test_create_vespa_search_params( if "family_source" in converted_keyword_filters.keys() else [] ), + family_geographies=( + converted_keyword_filters["family_geographies"] + if "family_geographies" in converted_keyword_filters.keys() + else [] + ), ) + else: assert not produced_search_parameters.keyword_filters assert not produced_search_parameters.filters @@ -264,13 +365,24 @@ def test_create_vespa_search_params( assert produced_search_parameters.sort_by == sort_field assert produced_search_parameters.sort_order == sort_order + assert produced_search_parameters.metadata == ( + [ + MetadataFilter.model_validate({"name": mdata.name, "value": mdata.value}) + for mdata in produced_search_parameters.metadata + ] + if produced_search_parameters.metadata is not None + else [] + ) + assert corpus_type_names == produced_search_parameters.corpus_type_names + assert corpus_import_ids == produced_search_parameters.corpus_import_ids + @pytest.mark.search @pytest.mark.parametrize( ( "exact_match,year_range,sort_field,sort_order," "keyword_filters,max_passages,page_size,offset,continuation_tokens," - "family_ids,document_ids" + "family_ids,document_ids,metadata_filters,corpus_type_names,corpus_import_ids" ), [ ( @@ -285,6 +397,12 @@ def test_create_vespa_search_params( ["ABC"], None, None, + [ + {"name": "family.sector", "value": "Price"}, + {"name": "family.topic", "value": "Mitigation"}, + ], + None, + None, ), ( False, @@ -301,6 +419,9 @@ def test_create_vespa_search_params( ["ABC"], ["CCLW.document.1.0", "CCLW.document.2.0"], None, + None, + ["UNFCCC Submissions", "Laws and Policies"], + None, ), ( False, @@ -317,6 +438,9 @@ def test_create_vespa_search_params( ["ABC"], ["CCLW.executive.1.0", "CCLW.executive.2.0"], ["CCLW.document.1.0", "CCLW.document.2.0"], + None, + None, + ["CCLW.corpus.1.0", "CCLW.corpus.2.0"], ), ], ) @@ -332,6 +456,9 @@ def test_create_browse_request_params( continuation_tokens, family_ids, document_ids, + metadata_filters, + corpus_type_names, + corpus_import_ids, ): SearchRequestBody( query_string="", @@ -346,6 +473,13 @@ def test_create_browse_request_params( page_size=page_size, offset=offset, continuation_tokens=continuation_tokens, + corpus_type_names=corpus_type_names, + corpus_import_ids=corpus_import_ids, + metadata=( + [MetadataFilter.model_validate(mdata) for mdata in metadata_filters] + if metadata_filters is not None + else [] + ), ) @@ -368,50 +502,50 @@ def test_create_browse_request_params( }, None, ), - ({"regions": ["north-america"]}, {"family_geography": ["CAN", "USA"]}), + ({"regions": ["north-america"]}, {"family_geographies": ["CAN", "USA"]}), ( { "regions": ["north-america"], "countries": ["not-a-country"], }, - {"family_geography": ["CAN", "USA"]}, + {"family_geographies": ["CAN", "USA"]}, ), ( {"regions": ["north-america"], "countries": ["canada"]}, - {"family_geography": ["CAN"]}, + {"family_geographies": ["CAN"]}, ), - ({"countries": ["cambodia"]}, {"family_geography": ["KHM"]}), + ({"countries": ["cambodia"]}, {"family_geographies": ["KHM"]}), ({"countries": ["this-is-not-valid"]}, None), ( {"countries": ["france", "germany"]}, - {"family_geography": ["FRA", "DEU"]}, + {"family_geographies": ["FRA", "DEU"]}, ), ( {"countries": ["cambodia"], "categories": ["Executive"]}, - {"family_category": ["Executive"], "family_geography": ["KHM"]}, + {"family_category": ["Executive"], "family_geographies": ["KHM"]}, ), ( {"countries": ["cambodia"], "languages": ["english"]}, - {"document_languages": ["english"], "family_geography": ["KHM"]}, + {"document_languages": ["english"], "family_geographies": ["KHM"]}, ), ( {"countries": ["cambodia"], "sources": ["CCLW"]}, - {"family_source": ["CCLW"], "family_geography": ["KHM"]}, + {"family_source": ["CCLW"], "family_geographies": ["KHM"]}, ), ( { "regions": ["north-america"], "categories": ["Executive"], }, - {"family_category": ["Executive"], "family_geography": ["CAN", "USA"]}, + {"family_category": ["Executive"], "family_geographies": ["CAN", "USA"]}, ), ( {"regions": ["north-america"], "languages": ["english"]}, - {"document_languages": ["english"], "family_geography": ["CAN", "USA"]}, + {"document_languages": ["english"], "family_geographies": ["CAN", "USA"]}, ), ( {"regions": ["north-america"], "sources": ["CCLW"]}, - {"family_source": ["CCLW"], "family_geography": ["CAN", "USA"]}, + {"family_source": ["CCLW"], "family_geographies": ["CAN", "USA"]}, ), ({"categories": ["Executive"]}, {"family_category": ["Executive"]}), ({"languages": ["english"]}, {"document_languages": ["english"]}), @@ -448,6 +582,7 @@ class FamSpec: family_category: str family_ts: str family_geo: str + family_geos: list[str] family_metadata: dict[str, list[str]] description_hit: bool @@ -489,6 +624,7 @@ def _generate_search_response_hits(spec: FamSpec) -> Sequence[CprSdkHit]: family_category=spec.family_category, family_publication_ts=datetime.fromisoformat(spec.family_ts), family_geography=spec.family_geo, + family_geographies=spec.family_geos, document_cdn_object=( f"{spec.family_import_id}/{slugify(spec.family_name)}" f"_{document_number}" @@ -520,6 +656,7 @@ def _generate_search_response_hits(spec: FamSpec) -> Sequence[CprSdkHit]: family_category=spec.family_category, family_publication_ts=datetime.fromisoformat(spec.family_ts), family_geography=spec.family_geo, + family_geographies=spec.family_geos, document_cdn_object=( f"{spec.family_import_id}/{slugify(spec.family_name)}" f"_{document_number}" @@ -585,6 +722,7 @@ def _generate_search_response(specs: Sequence[FamSpec]) -> CprSdkSearchResponse: family_category="Executive", family_ts="2023-12-12", family_geo="france", + family_geos=["france"], family_metadata={"keyword": ["Spacial Planning"]}, description_hit=True, family_document_count=1, @@ -599,6 +737,7 @@ def _generate_search_response(specs: Sequence[FamSpec]) -> CprSdkSearchResponse: family_category="Legislative", family_ts="2022-12-25", family_geo="spain", + family_geos=["spain"], family_metadata={"sector": ["Urban", "Transportation"], "keyword": ["Hydrogen"]}, description_hit=False, family_document_count=3, @@ -613,6 +752,7 @@ def _generate_search_response(specs: Sequence[FamSpec]) -> CprSdkSearchResponse: family_category="UNFCCC", family_ts="2019-01-01", family_geo="ukraine", + family_geos=["ukraine"], family_metadata={"author_type": ["Non-Party"], "author": ["Anyone"]}, description_hit=True, family_document_count=5, @@ -627,6 +767,7 @@ def _generate_search_response(specs: Sequence[FamSpec]) -> CprSdkSearchResponse: family_category="UNFCCC", family_ts="2010-03-14", family_geo="norway", + family_geos=["norway"], family_metadata={"author_type": ["Party"], "author": ["Anyone Else"]}, description_hit=False, family_document_count=2, @@ -645,17 +786,15 @@ def populate_data_db(db: Session, fam_specs: Sequence[FamSpec]) -> None: family_category=FamilyCategory(fam_spec.family_category), ) db.add(family) - db.add( - FamilyGeography( - family_import_id=fam_spec.family_import_id, - geography_id=( - db.query(Geography) - .filter(Geography.slug == fam_spec.family_geo) - .one() - .id - ), + for fam_geo in fam_spec.family_geos: + db.add( + FamilyGeography( + family_import_id=fam_spec.family_import_id, + geography_id=( + db.query(Geography).filter(Geography.slug == fam_geo).one().id + ), + ) ) - ) family_event = FamilyEvent( import_id=f"{fam_spec.family_source}.event.{fam_spec.family_import_id.split('.')[2]}.0", title="Published", diff --git a/tests/search/vespa/setup_search_tests.py b/tests/search/vespa/setup_search_tests.py index 93a308d4..6cbcd458 100644 --- a/tests/search/vespa/setup_search_tests.py +++ b/tests/search/vespa/setup_search_tests.py @@ -1,5 +1,6 @@ import json import random +from collections import defaultdict from datetime import datetime from pathlib import Path from typing import Iterable, Mapping, Optional, Sequence @@ -72,7 +73,9 @@ def _fixture_docs() -> Iterable[tuple[VespaFixture, VespaFixture]]: yield doc, family -def _populate_db_families(db: Session, max_docs: int = VESPA_FIXTURE_COUNT) -> None: +def _populate_db_families( + db: Session, max_docs: int = VESPA_FIXTURE_COUNT, deterministic_metadata=False +) -> None: """ Sets up the database using fixtures @@ -84,7 +87,10 @@ def _populate_db_families(db: Session, max_docs: int = VESPA_FIXTURE_COUNT) -> N if doc["fields"]["family_document_ref"] not in seen_family_ids: _create_family(db, family) _create_family_event(db, family) - _create_family_metadata(db, family) + if not deterministic_metadata: + _create_family_metadata(db, family) + else: + _create_family_metadata_deterministic(db, family) seen_family_ids.append(doc["fields"]["family_document_ref"]) _create_document(db, doc, family) if count == max_docs: @@ -191,6 +197,24 @@ def _create_family_metadata(db: Session, family: VespaFixture): db.commit() +def _create_family_metadata_deterministic(db: Session, family: VespaFixture): + metadata_values = defaultdict(list) + family_metadata = family["fields"]["metadata"] + if not family_metadata: + return + for metadata in family_metadata: + name = metadata["name"] # type: ignore + value = metadata["value"] # type: ignore + metadata_values[name].append(value) + + family_metadata = FamilyMetadata( + family_import_id=family["fields"]["family_import_id"], + value=metadata_values, + ) + db.add(family_metadata) + db.commit() + + def _create_document( db: Session, doc: VespaFixture, diff --git a/tests/search/vespa/test_range_and_keyword_filters_search.py b/tests/search/vespa/test_range_and_keyword_filters_search.py index 6bf0520d..9e63bb7a 100644 --- a/tests/search/vespa/test_range_and_keyword_filters_search.py +++ b/tests/search/vespa/test_range_and_keyword_filters_search.py @@ -13,7 +13,7 @@ @pytest.mark.search @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) -def test_keyword_country_filters( +def test_keyword_country_filters__geography( label, query, test_vespa, data_client, data_db, monkeypatch ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) @@ -28,6 +28,7 @@ def test_keyword_country_filters( assert len(families) == VESPA_FIXTURE_COUNT for family in families: + assert family["family_geography"] in family["family_geographies"] country_code = family["family_geography"] country_slug = get_country_slug_from_country_code(data_db, country_code) @@ -41,6 +42,39 @@ def test_keyword_country_filters( assert family["family_slug"] in filtered_family_slugs +@pytest.mark.search +@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) +def test_keyword_country_filters__geographies( + label, query, test_vespa, data_client, data_db, monkeypatch +): + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + base_params = {"query_string": query} + + # Get all documents and iterate over their country codes to confirm that each are + # the specific one that is returned in the query (as they each have a unique + # country code) + all_body = _make_search_request(data_client, params=base_params) + families = [f for f in all_body["families"]] + assert len(families) == VESPA_FIXTURE_COUNT + + for family in families: + assert family["family_geography"] in family["family_geographies"] + for country_code in family["family_geographies"]: + country_slug = get_country_slug_from_country_code(data_db, country_code) + + params = { + **base_params, + **{"keyword_filters": {"countries": [country_slug]}}, + } + body_with_filters = _make_search_request(data_client, params=params) + filtered_family_slugs = [ + f["family_slug"] for f in body_with_filters["families"] + ] + assert len(filtered_family_slugs) == 1 + assert family["family_slug"] in filtered_family_slugs + + @pytest.mark.search @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_region_filters( diff --git a/tests/search/vespa/test_vespa_search.py b/tests/search/vespa/test_vespa_search.py index e34fbbd1..5c03fdc4 100644 --- a/tests/search/vespa/test_vespa_search.py +++ b/tests/search/vespa/test_vespa_search.py @@ -84,6 +84,7 @@ def test_no_doc_if_in_postgres_but_not_vespa( "document_slug": "aslug", "family_description": "", "family_geography": "CAN", + "family_geographies": ["CAN"], "family_publication_ts": "2011-08-01T00:00:00+00:00", "family_import_id": "CCLW.family.111.0", }, @@ -205,3 +206,44 @@ def test_search_with_deleted_docs(test_vespa, monkeypatch, data_client, data_db) all_deleted_count = len(all_deleted_body["families"]) assert start_family_count > one_deleted_count > all_deleted_count assert len(all_deleted_body["families"]) == 0 + + +@pytest.mark.search +@pytest.mark.parametrize( + "label,query,metadata_filters", + [ + ("search", "the", [{"name": "sector", "value": "Price"}]), + ( + "browse", + "", + [ + {"name": "topic", "value": "Mitigation"}, + {"name": "instrument", "value": "Capacity building"}, + ], + ), + ], +) +def test_metadata_filter( + label, query, metadata_filters, test_vespa, data_db, monkeypatch, data_client +): + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + + _populate_db_families(data_db, deterministic_metadata=True) + + response = data_client.post( + SEARCH_ENDPOINT, + json={ + "query_string": query, + "metadata": metadata_filters, + }, + ) + assert response.status_code == 200 + assert len(response.json()["families"]) > 0 + + for metadata_filter in metadata_filters: + for f in response.json()["families"]: + assert metadata_filter["name"] in f["family_metadata"] + assert ( + metadata_filter["value"] + in f["family_metadata"][metadata_filter["name"]] + ) diff --git a/tests/unit/app/schemas/test_schemas.py b/tests/unit/app/schemas/test_schemas.py index 6dfb21d9..677f587f 100644 --- a/tests/unit/app/schemas/test_schemas.py +++ b/tests/unit/app/schemas/test_schemas.py @@ -132,6 +132,7 @@ def test_search_response() -> None: ), # You can replace this with an actual date string family_source="Example Source", family_geography="Example Geography", + family_geographies=["Example Geography"], family_metadata={"key1": "value1", "key2": "value2"}, family_title_match=True, family_description_match=False,