Skip to content

Commit

Permalink
Use ValidationError to raise validation errors
Browse files Browse the repository at this point in the history
These where not getting caught on backend api requests, leading to
undocumented 500 errors. Note that pydantic takes value errors and raises
`ValidationError`
  • Loading branch information
olaughter committed Jun 25, 2024
1 parent a438bb9 commit c4a2a3c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 38 deletions.
30 changes: 14 additions & 16 deletions src/cpr_sdk/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
)


from cpr_sdk.exceptions import QueryError

# Value Lookup Tables
sort_orders = {
"asc": "+",
Expand Down Expand Up @@ -74,13 +72,13 @@ class SearchParameters(BaseModel):
A string representation of the search to be performed.
For example: 'Adaptation strategy'"
"""

exact_match: bool = False
"""
Indicate if the `query_string` should be treated as an exact match when
the search is performed.
"""

all_results: bool = False
"""
Return all results rather than searching or ranking
Expand All @@ -96,12 +94,12 @@ class SearchParameters(BaseModel):
Refers to the maximum number of results to return from the "
query result.
"""

max_hits_per_family: int = Field(
validation_alias=AliasChoices("max_passages_per_doc", "max_hits_per_family"),
default=10,
ge=0,
le=500
le=500,
)
"""
The maximum number of matched passages to be returned for a "
Expand All @@ -110,7 +108,7 @@ class SearchParameters(BaseModel):

family_ids: Optional[Sequence[str]] = None
"""Optionally limit a search to a specific set of family ids."""

document_ids: Optional[Sequence[str]] = None
"""Optionally limit a search to a specific set of document ids."""

Expand All @@ -124,7 +122,7 @@ class SearchParameters(BaseModel):
search between. These are inclusive and can be null. Example:
[null, 2010] will return all documents return in or before 2010.
"""

sort_by: Optional[str] = Field(
validation_alias=AliasChoices("sort_field", "sort_by"), default=None
)
Expand All @@ -148,9 +146,9 @@ class SearchParameters(BaseModel):
def validate(self):
"""Validate against mutually exclusive fields"""
if self.exact_match and self.all_results:
raise QueryError("`exact_match` and `all_results` are mutually exclusive")
raise ValueError("`exact_match` and `all_results` are mutually exclusive")
if self.documents_only and not self.all_results:
raise QueryError(
raise ValueError(
"`documents_only` requires `all_results`, other queries are not supported"
)
return self
Expand All @@ -165,9 +163,9 @@ def continuation_tokens_must_be_upper_strings(cls, continuation_tokens):
if token == "":
continue
if not token.isalpha():
raise QueryError(f"Expected continuation tokens to be letters: {token}")
raise ValueError(f"Expected continuation tokens to be letters: {token}")
if not token.isupper():
raise QueryError(
raise ValueError(
f"Expected continuation tokens to be uppercase: {token}"
)
return continuation_tokens
Expand All @@ -193,7 +191,7 @@ def ids_must_fit_pattern(cls, ids):
if ids:
for _id in ids:
if not re.fullmatch(ID_PATTERN, _id):
raise QueryError(f"id seems invalid: {_id}")
raise ValueError(f"id seems invalid: {_id}")
return ids

@field_validator("year_range")
Expand All @@ -202,7 +200,7 @@ def year_range_must_be_valid(cls, year_range):
if year_range is not None:
if year_range[0] is not None and year_range[1] is not None:
if year_range[0] > year_range[1]:
raise QueryError(
raise ValueError(
"The first supplied year must be less than or equal to the "
f"second supplied year. Received: {year_range}"
)
Expand All @@ -213,7 +211,7 @@ def sort_by_must_be_valid(cls, sort_by):
"""Validate that the sort field is valid."""
if sort_by is not None:
if sort_by not in sort_fields:
raise QueryError(
raise ValueError(
f"Invalid sort field: {sort_by}. sort_by must be one of: "
f"{list(sort_fields.keys())}"
)
Expand All @@ -223,7 +221,7 @@ def sort_by_must_be_valid(cls, sort_by):
def sort_order_must_be_valid(cls, sort_order):
"""Validate that the sort order is valid."""
if sort_order not in sort_orders:
raise QueryError(
raise ValueError(
f"Invalid sort order: {sort_order}. sort_order must be one of: "
f"{sort_orders}"
)
Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ def test_vespa():
adaptor = VespaSearchAdapter(
instance_url=VESPA_TEST_SEARCH_URL, cert_directory=tmpdir_cert_dir
)

yield adaptor

yield adaptor


@pytest.fixture()
Expand Down
20 changes: 19 additions & 1 deletion tests/test_search_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_vespa_search_adaptor__sensitive(test_vespa):
(1, 100),
(2, 1),
(2, 5),
(3, 1000),
(3, 500),
],
)
@pytest.mark.vespa
Expand All @@ -203,6 +203,24 @@ def test_vespa_search_adaptor__limits(test_vespa, family_limit, max_hits_per_fam
assert len(fam.hits) <= max_hits_per_family


@pytest.mark.parametrize(
"family_limit, max_hits_per_family",
[
(501, 5),
(3, 501),
],
)
@pytest.mark.vespa
def test_vespa_search_adaptor__limits__errors(family_limit, max_hits_per_family):
with pytest.raises(ValueError):
SearchParameters(
query_string="the",
family_ids=[],
limit=family_limit,
max_hits_per_family=max_hits_per_family,
)


@pytest.mark.vespa
def test_vespa_search_adaptor__continuation_tokens__families(test_vespa):
query_string = "the"
Expand Down
35 changes: 16 additions & 19 deletions tests/test_search_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pytest
from cpr_sdk.embedding import Embedder
from cpr_sdk.exceptions import QueryError
from cpr_sdk.models.search import (
Filters,
SearchParameters,
Expand Down Expand Up @@ -57,9 +56,8 @@ def test_whether_an_empty_query_string_does_all_result_search():

def test_wether_documents_only_without_all_results_raises_error():
q = "Search"
with pytest.raises(QueryError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
SearchParameters(query_string=q, documents_only=True)
assert "Failed to build query" in str(excinfo.value)
assert "`documents_only` requires `all_results`" in str(excinfo.value)

# They should be fine otherwise:
Expand All @@ -71,9 +69,8 @@ def test_wether_documents_only_without_all_results_raises_error():

def test_wether_combining_all_results_and_exact_match_raises_error():
q = "Search"
with pytest.raises(QueryError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
SearchParameters(query_string=q, exact_match=True, all_results=True)
assert "Failed to build query" in str(excinfo.value)
assert "`exact_match` and `all_results`" in str(excinfo.value)

# They should be fine independently:
Expand Down Expand Up @@ -102,8 +99,8 @@ def test_whether_valid_year_ranges_are_accepted(year_range):
assert isinstance(params, SearchParameters)


def test_whether_an_invalid_year_range_ranges_raises_a_queryerror():
with pytest.raises(QueryError) as excinfo:
def test_whether_an_invalid_year_range_ranges_raises_a_validation_error():
with pytest.raises(ValidationError) as excinfo:
SearchParameters(query_string="test", year_range=(2023, 2000))
assert (
"The first supplied year must be less than or equal to the second supplied year"
Expand All @@ -129,8 +126,8 @@ def test_whether_valid_family_ids_are_accepted():
"UNFCCC.family.i00000003.n000.11",
],
)
def test_whether_an_invalid_family_id_raises_a_queryerror(bad_id):
with pytest.raises(QueryError) as excinfo:
def test_whether_an_invalid_family_id_raises_a_validation_error(bad_id):
with pytest.raises(ValidationError) as excinfo:
SearchParameters(
query_string="test",
family_ids=("CCLW.family.i00000003.n0000", bad_id),
Expand All @@ -157,8 +154,8 @@ def test_whether_valid_document_ids_are_accepted():
"UNFCCC.doc.i00000003",
],
)
def test_whether_an_invalid_document_id_raises_a_queryerror(bad_id):
with pytest.raises(QueryError) as excinfo:
def test_whether_an_invalid_document_id_raises_a_validation_error(bad_id):
with pytest.raises(ValidationError) as excinfo:
SearchParameters(
query_string="test",
document_ids=(bad_id, "CCLW.document.i00000004.n0000"),
Expand All @@ -174,8 +171,8 @@ def test_whether_valid_sort_fields_are_accepted(field):
assert isinstance(params, SearchParameters)


def test_whether_an_invalid_sort_field_raises_a_queryerror():
with pytest.raises(QueryError) as excinfo:
def test_whether_an_invalid_sort_field_raises_a_validation_error():
with pytest.raises(ValidationError) as excinfo:
SearchParameters(query_string="test", sort_by="invalid_field")
assert "sort_by must be one of" in str(excinfo.value)

Expand All @@ -186,8 +183,8 @@ def test_whether_valid_sort_orders_are_accepted(order):
assert isinstance(params, SearchParameters)


def test_whether_an_invalid_sort_order_raises_a_queryerror():
with pytest.raises(QueryError) as excinfo:
def test_whether_an_invalid_sort_order_raises_a_validation_error():
with pytest.raises(ValidationError) as excinfo:
SearchParameters(query_string="test", sort_order="invalid_order")
assert "sort_order must be one of" in str(excinfo.value)

Expand Down Expand Up @@ -249,10 +246,10 @@ def test_whether_an_invalid_filter_fields_value_fixes_it_silently(
(
(["", None], ValidationError),
([123], ValidationError),
(["123"], QueryError),
(["!@$"], QueryError),
(["lower"], QueryError),
(["", "lower"], QueryError),
(["123"], ValidationError),
(["!@$"], ValidationError),
(["lower"], ValidationError),
(["", "lower"], ValidationError),
),
)
def test_continuation_tokens__bad(tokens, error):
Expand Down

0 comments on commit c4a2a3c

Please sign in to comment.