diff --git a/src/cpr_sdk/models/search.py b/src/cpr_sdk/models/search.py index 57518fd..c9ccc5f 100644 --- a/src/cpr_sdk/models/search.py +++ b/src/cpr_sdk/models/search.py @@ -13,8 +13,6 @@ ) -from cpr_sdk.exceptions import QueryError - # Value Lookup Tables sort_orders = { "asc": "+", @@ -70,36 +68,87 @@ class SearchParameters(BaseModel): """Parameters for a search request""" query_string: Optional[str] = "" + """ + A string representation of the search to be performed. + For example: 'Adaptation strategy' + """ + exact_match: bool = False + """ + Indicate if the `query_string` should be treated as an exact match when + the search is performed. + """ + all_results: bool = False + """ + Return all results rather than searching or ranking + + Filters can still be applied + """ + documents_only: bool = False - limit: int = Field(ge=0, default=100) + """Ignores passages in search when true.""" + + limit: int = Field(ge=0, default=100, le=500) + """ + Refers to the maximum number of results to return from the + query result. + """ + max_hits_per_family: int = Field( validation_alias=AliasChoices("max_passages_per_doc", "max_hits_per_family"), default=10, ge=0, + le=500, ) + """ + The maximum number of matched passages to be returned for a + single document. + """ family_ids: Optional[Sequence[str]] = None + """Optionally limit a search to a specific set of family ids.""" + document_ids: Optional[Sequence[str]] = None + """Optionally limit a search to a specific set of document ids.""" filters: Optional[Filters] = None + """Filter results to matching filter items.""" + year_range: Optional[tuple[Optional[int], Optional[int]]] = None + """ + The years to search between. Containing exactly two values, + which can be null or an integer representing the years to + search between. These are inclusive and can be null. Example: + [null, 2010] will return all documents in or before 2010.
+ """ sort_by: Optional[str] = Field( validation_alias=AliasChoices("sort_field", "sort_by"), default=None ) + """The field to sort by can be chosen from `date` or `title`.""" + sort_order: str = "descending" + """ + The order of the results according to the `sort_field`, can be chosen from + ascending (use “asc”) or descending (use “desc”). + """ continuation_tokens: Optional[Sequence[str]] = None + """ + Use to return the next page of results from a specific search, the next token + can be found on the response object. It's also possible to get the next page + of passages by including the family level continuation token first in the + array followed by the passage level one. + """ @model_validator(mode="after") def validate(self): """Validate against mutually exclusive fields""" if self.exact_match and self.all_results: - raise QueryError("`exact_match` and `all_results` are mutually exclusive") + raise ValueError("`exact_match` and `all_results` are mutually exclusive") if self.documents_only and not self.all_results: - raise QueryError( + raise ValueError( "`documents_only` requires `all_results`, other queries are not supported" ) return self @@ -114,9 +163,9 @@ def continuation_tokens_must_be_upper_strings(cls, continuation_tokens): if token == "": continue if not token.isalpha(): - raise QueryError(f"Expected continuation tokens to be letters: {token}") + raise ValueError(f"Expected continuation tokens to be letters: {token}") if not token.isupper(): - raise QueryError( + raise ValueError( f"Expected continuation tokens to be uppercase: {token}" ) return continuation_tokens @@ -142,7 +191,7 @@ def ids_must_fit_pattern(cls, ids): if ids: for _id in ids: if not re.fullmatch(ID_PATTERN, _id): - raise QueryError(f"id seems invalid: {_id}") + raise ValueError(f"id seems invalid: {_id}") return ids @field_validator("year_range") @@ -151,7 +200,7 @@ def year_range_must_be_valid(cls, year_range): if year_range is not None: if year_range[0] is not None and 
year_range[1] is not None: if year_range[0] > year_range[1]: - raise QueryError( + raise ValueError( "The first supplied year must be less than or equal to the " f"second supplied year. Received: {year_range}" ) @@ -162,7 +211,7 @@ def sort_by_must_be_valid(cls, sort_by): """Validate that the sort field is valid.""" if sort_by is not None: if sort_by not in sort_fields: - raise QueryError( + raise ValueError( f"Invalid sort field: {sort_by}. sort_by must be one of: " f"{list(sort_fields.keys())}" ) @@ -172,7 +221,7 @@ def sort_by_must_be_valid(cls, sort_by): def sort_order_must_be_valid(cls, sort_order): """Validate that the sort order is valid.""" if sort_order not in sort_orders: - raise QueryError( + raise ValueError( f"Invalid sort order: {sort_order}. sort_order must be one of: " f"{sort_orders}" ) diff --git a/src/cpr_sdk/search_adaptors.py b/src/cpr_sdk/search_adaptors.py index 0216cc6..187f7c5 100644 --- a/src/cpr_sdk/search_adaptors.py +++ b/src/cpr_sdk/search_adaptors.py @@ -1,4 +1,5 @@ """Adaptors for searching CPR data""" + import time from abc import ABC from pathlib import Path diff --git a/src/cpr_sdk/version.py b/src/cpr_sdk/version.py index bc6a02f..c4b1f49 100644 --- a/src/cpr_sdk/version.py +++ b/src/cpr_sdk/version.py @@ -1,6 +1,6 @@ _MAJOR = "1" _MINOR = "1" -_PATCH = "5" +_PATCH = "6" _SUFFIX = "" VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) diff --git a/tests/conftest.py b/tests/conftest.py index 8ec45ba..ef55d29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,18 +6,25 @@ import boto3 from moto import mock_aws +from cpr_sdk.search_adaptors import VespaSearchAdapter + VESPA_TEST_SEARCH_URL = "http://localhost:8080" @pytest.fixture() -def fake_vespa_credentials(): - with tempfile.TemporaryDirectory() as tmpdir: - with open(Path(tmpdir) / "cert.pem", "w"): +def test_vespa(): + """Vespa adapter pointing to test url and using empty cert files""" + with tempfile.TemporaryDirectory() as tmpdir_cert_dir: + with 
open(Path(tmpdir_cert_dir) / "cert.pem", "w"): pass - with open(Path(tmpdir) / "key.pem", "w"): + with open(Path(tmpdir_cert_dir) / "key.pem", "w"): pass - yield tmpdir + adaptor = VespaSearchAdapter( + instance_url=VESPA_TEST_SEARCH_URL, cert_directory=tmpdir_cert_dir + ) + + yield adaptor @pytest.fixture() diff --git a/tests/test_search_adaptors.py b/tests/test_search_adaptors.py index e6954fa..68a5253 100644 --- a/tests/test_search_adaptors.py +++ b/tests/test_search_adaptors.py @@ -1,24 +1,23 @@ -from unittest.mock import patch from timeit import timeit from typing import Mapping +from unittest.mock import patch import pytest -from conftest import VESPA_TEST_SEARCH_URL + from cpr_sdk.models.search import ( + Document, + Passage, SearchParameters, SearchResponse, sort_fields, - Document, - Passage, ) from cpr_sdk.search_adaptors import VespaSearchAdapter -def vespa_search(cert_directory: str, request: SearchParameters) -> SearchResponse: +def vespa_search( + adaptor: VespaSearchAdapter, request: SearchParameters +) -> SearchResponse: try: - adaptor = VespaSearchAdapter( - instance_url=VESPA_TEST_SEARCH_URL, cert_directory=cert_directory - ) response = adaptor.search(request) except Exception as e: pytest.fail(f"Vespa query failed. 
{e.__class__.__name__}: {e}") @@ -26,10 +25,10 @@ def vespa_search(cert_directory: str, request: SearchParameters) -> SearchRespon def profile_search( - fake_vespa_credentials, params: Mapping[str, str], n: int = 25 + test_vespa: VespaSearchAdapter, params: Mapping[str, str], n: int = 25 ) -> float: t = timeit( - lambda: vespa_search(fake_vespa_credentials, SearchParameters(**params)), + lambda: vespa_search(test_vespa, SearchParameters(**params)), number=n, ) avg_ms = (t / n) * 1000 @@ -37,9 +36,9 @@ def profile_search( @pytest.mark.vespa -def test_vespa_search_adaptor__works(fake_vespa_credentials): +def test_vespa_search_adaptor__works(test_vespa): request = SearchParameters(query_string="the") - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert len(response.families) == response.total_family_hits == 3 assert response.query_time_ms < response.total_time_ms @@ -59,10 +58,10 @@ def test_vespa_search_adaptor__works(fake_vespa_credentials): ), ) @pytest.mark.vespa -def test_vespa_search_adaptor__is_fast_enough(fake_vespa_credentials, params): +def test_vespa_search_adaptor__is_fast_enough(test_vespa, params): MAX_SPEED_MS = 850 - avg_ms = profile_search(fake_vespa_credentials, params=params) + avg_ms = profile_search(test_vespa, params=params) assert avg_ms <= MAX_SPEED_MS @@ -76,9 +75,9 @@ def test_vespa_search_adaptor__is_fast_enough(fake_vespa_credentials, params): ["CCLW.family.4934.0"], ], ) -def test_vespa_search_adaptor__family_ids(fake_vespa_credentials, family_ids): +def test_vespa_search_adaptor__family_ids(test_vespa, family_ids): request = SearchParameters(query_string="the", family_ids=family_ids) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) got_family_ids = [f.id for f in response.families] assert sorted(got_family_ids) == sorted(family_ids) @@ -93,9 +92,9 @@ def test_vespa_search_adaptor__family_ids(fake_vespa_credentials, 
family_ids): ["CCLW.executive.4934.1571"], ], ) -def test_vespa_search_adaptor__document_ids(fake_vespa_credentials, document_ids): +def test_vespa_search_adaptor__document_ids(test_vespa, document_ids): request = SearchParameters(query_string="the", document_ids=document_ids) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) # As passages are returned we need to collect and deduplicate them to get id list got_document_ids = [] @@ -108,20 +107,20 @@ def test_vespa_search_adaptor__document_ids(fake_vespa_credentials, document_ids @pytest.mark.vespa -def test_vespa_search_adaptor__bad_query_string_still_works(fake_vespa_credentials): +def test_vespa_search_adaptor__bad_query_string_still_works(test_vespa): family_name = ' Bad " query/ ' request = SearchParameters(query_string=family_name) try: - vespa_search(fake_vespa_credentials, request) + vespa_search(test_vespa, request) except Exception as e: assert False, f"failed with: {e}" @pytest.mark.vespa -def test_vespa_search_adaptor__hybrid(fake_vespa_credentials): +def test_vespa_search_adaptor__hybrid(test_vespa): family_name = "Climate Change Adaptation and Low Emissions Growth Strategy by 2035" request = SearchParameters(query_string=family_name) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) # Was the family searched for in the results. 
# Note that this is a fairly loose test @@ -133,9 +132,9 @@ def test_vespa_search_adaptor__hybrid(fake_vespa_credentials): @pytest.mark.vespa -def test_vespa_search_adaptor__all(fake_vespa_credentials): +def test_vespa_search_adaptor__all(test_vespa): request = SearchParameters(query_string="", all_results=True) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert len(response.families) == response.total_family_hits # Filtering should still work @@ -143,16 +142,16 @@ def test_vespa_search_adaptor__all(fake_vespa_credentials): request = SearchParameters( query_string="", all_results=True, family_ids=[family_id] ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert len(response.families) == 1 assert response.families[0].id == family_id @pytest.mark.vespa -def test_vespa_search_adaptor__exact(fake_vespa_credentials): +def test_vespa_search_adaptor__exact(test_vespa): query_string = "Environmental Strategy for 2014-2023" request = SearchParameters(query_string=query_string, exact_match=True) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) got_family_names = [] for fam in response.families: for doc in fam.hits: @@ -165,15 +164,15 @@ def test_vespa_search_adaptor__exact(fake_vespa_credentials): # Conversely we'd expect nothing if the query string isnt present query_string = "no such string as this can be found in the test documents" request = SearchParameters(query_string=query_string, exact_match=True) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert len(response.families) == 0 @pytest.mark.vespa @patch("cpr_sdk.vespa.SENSITIVE_QUERY_TERMS", {"Government"}) -def test_vespa_search_adaptor__sensitive(fake_vespa_credentials): +def test_vespa_search_adaptor__sensitive(test_vespa): request = SearchParameters(query_string="Government") - 
response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) # Without being too prescriptive, we'd expect something back for this assert len(response.families) > 0 @@ -186,28 +185,44 @@ def test_vespa_search_adaptor__sensitive(fake_vespa_credentials): (1, 100), (2, 1), (2, 5), - (3, 1000), + (3, 500), ], ) @pytest.mark.vespa -def test_vespa_search_adaptor__limits( - fake_vespa_credentials, family_limit, max_hits_per_family -): +def test_vespa_search_adaptor__limits(test_vespa, family_limit, max_hits_per_family): request = SearchParameters( query_string="the", family_ids=[], limit=family_limit, max_hits_per_family=max_hits_per_family, ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert len(response.families) == family_limit for fam in response.families: assert len(fam.hits) <= max_hits_per_family +@pytest.mark.parametrize( + "family_limit, max_hits_per_family", + [ + (501, 5), + (3, 501), + ], +) +@pytest.mark.vespa +def test_vespa_search_adaptor__limits__errors(family_limit, max_hits_per_family): + with pytest.raises(ValueError): + SearchParameters( + query_string="the", + family_ids=[], + limit=family_limit, + max_hits_per_family=max_hits_per_family, + ) + + @pytest.mark.vespa -def test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credentials): +def test_vespa_search_adaptor__continuation_tokens__families(test_vespa): query_string = "the" limit = 2 max_hits_per_family = 3 @@ -218,7 +233,7 @@ def test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credenti limit=limit, max_hits_per_family=max_hits_per_family, ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) first_family_ids = [f.id for f in response.families] family_continuation = response.continuation_token assert len(response.families) == 2 @@ -231,7 +246,7 @@ def 
test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credenti max_hits_per_family=max_hits_per_family, continuation_tokens=[family_continuation], ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) prev_family_continuation = response.prev_continuation_token assert len(response.families) == 1 assert response.total_family_hits == 3 @@ -249,13 +264,13 @@ def test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credenti max_hits_per_family=max_hits_per_family, continuation_tokens=[prev_family_continuation], ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) prev_family_ids = [f.id for f in response.families] assert prev_family_ids == first_family_ids @pytest.mark.vespa -def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credentials): +def test_vespa_search_adaptor__continuation_tokens__passages(test_vespa): query_string = "the" limit = 1 max_hits_per_family = 10 @@ -266,7 +281,7 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti limit=limit, max_hits_per_family=max_hits_per_family, ) - initial_response = vespa_search(fake_vespa_credentials, request) + initial_response = vespa_search(test_vespa, request) # Collect family & hits for comparison later initial_family_id = initial_response.families[0].id @@ -282,7 +297,7 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti max_hits_per_family=max_hits_per_family, continuation_tokens=[this_continuation, passage_continuation], ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) prev_passage_continuation = response.families[0].prev_continuation_token # Family should not have changed @@ -299,7 +314,7 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti max_hits_per_family=max_hits_per_family, 
continuation_tokens=[this_continuation, prev_passage_continuation], ) - response = vespa_search(fake_vespa_credentials, request) + response = vespa_search(test_vespa, request) assert response.families[0].id == initial_family_id prev_passages = sorted([h.text_block_id for h in response.families[0].hits]) assert sorted(prev_passages) != sorted(new_passages) @@ -308,7 +323,7 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti @pytest.mark.vespa def test_vespa_search_adaptor__continuation_tokens__families_and_passages( - fake_vespa_credentials, + test_vespa, ): query_string = "the" limit = 1 @@ -320,7 +335,7 @@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( limit=limit, max_hits_per_family=max_hits_per_family, ) - response_one = vespa_search(fake_vespa_credentials, request_one) + response_one = vespa_search(test_vespa, request_one) # Increment Families request_two = SearchParameters( @@ -329,7 +344,7 @@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( max_hits_per_family=max_hits_per_family, continuation_tokens=[response_one.continuation_token], ) - response_two = vespa_search(fake_vespa_credentials, request_two) + response_two = vespa_search(test_vespa, request_two) # Then Increment Passages Twice @@ -342,7 +357,7 @@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( response_two.families[0].continuation_token, ], ) - response_three = vespa_search(fake_vespa_credentials, request_three) + response_three = vespa_search(test_vespa, request_three) request_four = SearchParameters( query_string=query_string, @@ -353,7 +368,7 @@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( response_three.families[0].continuation_token, ], ) - response_four = vespa_search(fake_vespa_credentials, request_four) + response_four = vespa_search(test_vespa, request_four) # All of these should have different passages from each other assert ( @@ -366,13 +381,13 
@@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( @pytest.mark.parametrize("sort_by", sort_fields.keys()) @pytest.mark.vespa -def test_vespa_search_adapter_sorting(fake_vespa_credentials, sort_by): +def test_vespa_search_adapter_sorting(test_vespa, sort_by): ascend = vespa_search( - fake_vespa_credentials, + test_vespa, SearchParameters(query_string="the", sort_by=sort_by, sort_order="ascending"), ) descend = vespa_search( - fake_vespa_credentials, + test_vespa, SearchParameters(query_string="the", sort_by=sort_by, sort_order="descending"), ) @@ -380,9 +395,9 @@ def test_vespa_search_adapter_sorting(fake_vespa_credentials, sort_by): @pytest.mark.vespa -def test_vespa_search_no_passages_search(fake_vespa_credentials): +def test_vespa_search_no_passages_search(test_vespa): no_passages = vespa_search( - fake_vespa_credentials, + test_vespa, SearchParameters(all_results=True, documents_only=True), ) for family in no_passages.families: @@ -390,7 +405,7 @@ def test_vespa_search_no_passages_search(fake_vespa_credentials): assert isinstance(hit, Document) with_passages = vespa_search( - fake_vespa_credentials, + test_vespa, SearchParameters(all_results=True), ) found_a_passage = False diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index bc86cd6..7d3fc56 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -3,7 +3,6 @@ import pytest from cpr_sdk.embedding import Embedder -from cpr_sdk.exceptions import QueryError from cpr_sdk.models.search import ( Filters, SearchParameters, @@ -57,9 +56,8 @@ def test_whether_an_empty_query_string_does_all_result_search(): def test_wether_documents_only_without_all_results_raises_error(): q = "Search" - with pytest.raises(QueryError) as excinfo: + with pytest.raises(ValidationError) as excinfo: SearchParameters(query_string=q, documents_only=True) - assert "Failed to build query" in str(excinfo.value) assert "`documents_only` requires `all_results`" in 
str(excinfo.value) # They should be fine otherwise: @@ -71,9 +69,8 @@ def test_wether_documents_only_without_all_results_raises_error(): def test_wether_combining_all_results_and_exact_match_raises_error(): q = "Search" - with pytest.raises(QueryError) as excinfo: + with pytest.raises(ValidationError) as excinfo: SearchParameters(query_string=q, exact_match=True, all_results=True) - assert "Failed to build query" in str(excinfo.value) assert "`exact_match` and `all_results`" in str(excinfo.value) # They should be fine independently: @@ -102,8 +99,8 @@ def test_whether_valid_year_ranges_are_accepted(year_range): assert isinstance(params, SearchParameters) -def test_whether_an_invalid_year_range_ranges_raises_a_queryerror(): - with pytest.raises(QueryError) as excinfo: +def test_whether_an_invalid_year_range_ranges_raises_a_validation_error(): + with pytest.raises(ValidationError) as excinfo: SearchParameters(query_string="test", year_range=(2023, 2000)) assert ( "The first supplied year must be less than or equal to the second supplied year" @@ -129,8 +126,8 @@ def test_whether_valid_family_ids_are_accepted(): "UNFCCC.family.i00000003.n000.11", ], ) -def test_whether_an_invalid_family_id_raises_a_queryerror(bad_id): - with pytest.raises(QueryError) as excinfo: +def test_whether_an_invalid_family_id_raises_a_validation_error(bad_id): + with pytest.raises(ValidationError) as excinfo: SearchParameters( query_string="test", family_ids=("CCLW.family.i00000003.n0000", bad_id), @@ -157,8 +154,8 @@ def test_whether_valid_document_ids_are_accepted(): "UNFCCC.doc.i00000003", ], ) -def test_whether_an_invalid_document_id_raises_a_queryerror(bad_id): - with pytest.raises(QueryError) as excinfo: +def test_whether_an_invalid_document_id_raises_a_validation_error(bad_id): + with pytest.raises(ValidationError) as excinfo: SearchParameters( query_string="test", document_ids=(bad_id, "CCLW.document.i00000004.n0000"), @@ -174,8 +171,8 @@ def 
test_whether_valid_sort_fields_are_accepted(field): assert isinstance(params, SearchParameters) -def test_whether_an_invalid_sort_field_raises_a_queryerror(): - with pytest.raises(QueryError) as excinfo: +def test_whether_an_invalid_sort_field_raises_a_validation_error(): + with pytest.raises(ValidationError) as excinfo: SearchParameters(query_string="test", sort_by="invalid_field") assert "sort_by must be one of" in str(excinfo.value) @@ -186,8 +183,8 @@ def test_whether_valid_sort_orders_are_accepted(order): assert isinstance(params, SearchParameters) -def test_whether_an_invalid_sort_order_raises_a_queryerror(): - with pytest.raises(QueryError) as excinfo: +def test_whether_an_invalid_sort_order_raises_a_validation_error(): + with pytest.raises(ValidationError) as excinfo: SearchParameters(query_string="test", sort_order="invalid_order") assert "sort_order must be one of" in str(excinfo.value) @@ -249,10 +246,10 @@ def test_whether_an_invalid_filter_fields_value_fixes_it_silently( ( (["", None], ValidationError), ([123], ValidationError), - (["123"], QueryError), - (["!@$"], QueryError), - (["lower"], QueryError), - (["", "lower"], QueryError), + (["123"], ValidationError), + (["!@$"], ValidationError), + (["lower"], ValidationError), + (["", "lower"], ValidationError), ), ) def test_continuation_tokens__bad(tokens, error):