Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pdct 859 500 docs appear in a search #44

Merged
merged 5 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions src/cpr_sdk/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
)


from cpr_sdk.exceptions import QueryError

# Value Lookup Tables
sort_orders = {
"asc": "+",
Expand Down Expand Up @@ -70,36 +68,87 @@ class SearchParameters(BaseModel):
"""Parameters for a search request"""

query_string: Optional[str] = ""
"""
A string representation of the search to be performed.
For example: 'Adaptation strategy'
"""

exact_match: bool = False
"""
Indicate if the `query_string` should be treated as an exact match when
the search is performed.
"""

all_results: bool = False
"""
Return all results rather than searching or ranking

Filters can still be applied
"""

documents_only: bool = False
limit: int = Field(ge=0, default=100)
"""Ignores passages in search when true."""

limit: int = Field(ge=0, default=100, le=500)
"""
Refers to the maximum number of results to return from the
query result.
"""

max_hits_per_family: int = Field(
validation_alias=AliasChoices("max_passages_per_doc", "max_hits_per_family"),
default=10,
ge=0,
le=500,
)
"""
The maximum number of matched passages to be returned for a
single document.
"""

family_ids: Optional[Sequence[str]] = None
"""Optionally limit a search to a specific set of family ids."""

document_ids: Optional[Sequence[str]] = None
"""Optionally limit a search to a specific set of document ids."""

filters: Optional[Filters] = None
"""Filter results to matching filter items."""

year_range: Optional[tuple[Optional[int], Optional[int]]] = None
"""
The years to search between. Contains exactly two values,
each of which can be null or an integer representing a year.
The bounds are inclusive. Example: [null, 2010] will return
all documents from 2010 or earlier.
"""

sort_by: Optional[str] = Field(
validation_alias=AliasChoices("sort_field", "sort_by"), default=None
)
"""The field to sort by; can be chosen from `date` or `title`."""

sort_order: str = "descending"
"""
The order of the results according to the `sort_field`, can be chosen from
ascending (use “asc”) or descending (use “desc”).
"""

continuation_tokens: Optional[Sequence[str]] = None
"""
Use to return the next page of results from a specific search, the next token
can be found on the response object. It's also possible to get the next page
of passages by including the family level continuation token first in the
array followed by the passage level one.
"""

@model_validator(mode="after")
def validate(self):
"""Validate against mutually exclusive fields"""
if self.exact_match and self.all_results:
raise QueryError("`exact_match` and `all_results` are mutually exclusive")
raise ValueError("`exact_match` and `all_results` are mutually exclusive")
if self.documents_only and not self.all_results:
raise QueryError(
raise ValueError(
"`documents_only` requires `all_results`, other queries are not supported"
)
return self
Expand All @@ -114,9 +163,9 @@ def continuation_tokens_must_be_upper_strings(cls, continuation_tokens):
if token == "":
continue
if not token.isalpha():
raise QueryError(f"Expected continuation tokens to be letters: {token}")
raise ValueError(f"Expected continuation tokens to be letters: {token}")
if not token.isupper():
raise QueryError(
raise ValueError(
f"Expected continuation tokens to be uppercase: {token}"
)
return continuation_tokens
Expand All @@ -142,7 +191,7 @@ def ids_must_fit_pattern(cls, ids):
if ids:
for _id in ids:
if not re.fullmatch(ID_PATTERN, _id):
raise QueryError(f"id seems invalid: {_id}")
raise ValueError(f"id seems invalid: {_id}")
return ids

@field_validator("year_range")
Expand All @@ -151,7 +200,7 @@ def year_range_must_be_valid(cls, year_range):
if year_range is not None:
if year_range[0] is not None and year_range[1] is not None:
if year_range[0] > year_range[1]:
raise QueryError(
raise ValueError(
"The first supplied year must be less than or equal to the "
f"second supplied year. Received: {year_range}"
)
Expand All @@ -162,7 +211,7 @@ def sort_by_must_be_valid(cls, sort_by):
"""Validate that the sort field is valid."""
if sort_by is not None:
if sort_by not in sort_fields:
raise QueryError(
raise ValueError(
f"Invalid sort field: {sort_by}. sort_by must be one of: "
f"{list(sort_fields.keys())}"
)
Expand All @@ -172,7 +221,7 @@ def sort_by_must_be_valid(cls, sort_by):
def sort_order_must_be_valid(cls, sort_order):
"""Validate that the sort order is valid."""
if sort_order not in sort_orders:
raise QueryError(
raise ValueError(
f"Invalid sort order: {sort_order}. sort_order must be one of: "
f"{sort_orders}"
)
Expand Down
1 change: 1 addition & 0 deletions src/cpr_sdk/search_adaptors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adaptors for searching CPR data"""

import time
from abc import ABC
from pathlib import Path
Expand Down
2 changes: 1 addition & 1 deletion src/cpr_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_MAJOR = "1"
_MINOR = "1"
_PATCH = "5"
_PATCH = "6"
_SUFFIX = ""

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
import boto3
from moto import mock_aws

from cpr_sdk.search_adaptors import VespaSearchAdapter


VESPA_TEST_SEARCH_URL = "http://localhost:8080"


@pytest.fixture()
def fake_vespa_credentials():
with tempfile.TemporaryDirectory() as tmpdir:
with open(Path(tmpdir) / "cert.pem", "w"):
def test_vespa():
"""Vespa adapter pointing to test url and using empty cert files"""
with tempfile.TemporaryDirectory() as tmpdir_cert_dir:
with open(Path(tmpdir_cert_dir) / "cert.pem", "w"):
pass
with open(Path(tmpdir) / "key.pem", "w"):
with open(Path(tmpdir_cert_dir) / "key.pem", "w"):
pass
yield tmpdir
adaptor = VespaSearchAdapter(
instance_url=VESPA_TEST_SEARCH_URL, cert_directory=tmpdir_cert_dir
)

yield adaptor


@pytest.fixture()
Expand Down
Loading
Loading