Enable filtering on metadata and corpus (#88)
* Bugfix for the torch version.

* Adding a failing test.

* Updating the version.

* Testing a different query.

* Adding corpus_type_name to the response.

* Add in the filtering option for corpus_type_name

* Removing as filter.

* Reverting change to yql builder.

* Successfully filtering on corpus type name.

* test: Modify fixtures to give wider scenario coverage

* Resolving merge conflict.

* Updating tests to query for multiple corpora.

* Remove typo.

* Adding a test for metadata querying.

* Updating the test data to prefix metadata name with family.

* Reformatting.

* Bugfix for appending empty metadata.

* Bugfix for the corpus_type_name test.

* Refactoring.

* Reverting poetry.lock change.

* Updating SearchParameters to accept a pydantic object for metadata filters to validate the structure.

* Adding a test for the MetadataFilter object.

* Adding a filter for corpus_import_id.

* Integrate geographies filter (#95)

* Adding failing tests.

* Trunk fix.

* Adding in tests for the geographies filters.

* Refactoring the family_geographies test.

* Refactoring.

* Converting the geographies filter from an "and" to an "or" operator and adding a test.

* Updating test to contain two geographies.

---------

Co-authored-by: Mark <[email protected]>

---------

Co-authored-by: Mark <[email protected]>
Co-authored-by: Jesse Claven <[email protected]>
3 people authored Sep 11, 2024
1 parent a4768cb commit 5379ff5
Showing 9 changed files with 1,167 additions and 2,505 deletions.
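As context for the diffs that follow, here is a minimal sketch of the behaviour described in the commit message: metadata filters are now passed to SearchParameters as pydantic MetadataFilter objects, so a malformed filter is rejected when it is constructed rather than when the query is built. The class and module names below come from the diff; the field values are purely illustrative.

from pydantic import ValidationError

from cpr_sdk.models.search import MetadataFilter

# Well-formed filter: both required fields are present.
ok = MetadataFilter(name="family.sector", value="Price")

# Malformed filter: pydantic rejects it before any YQL is assembled.
try:
    MetadataFilter(name="family.sector")  # "value" is missing
except ValidationError as err:
    print(err)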
50 changes: 43 additions & 7 deletions src/cpr_sdk/models/search.py
@@ -1,18 +1,17 @@
import re
from datetime import datetime
from typing import List, Optional, Sequence

from pydantic import (
    AliasChoices,
    BaseModel,
    ConfigDict,
    Field,
    computed_field,
    field_validator,
    model_validator,
)


# Value Lookup Tables
sort_orders = {
"asc": "+",
@@ -38,10 +37,18 @@
ID_PATTERN = re.compile(rf"{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}")


class MetadataFilter(BaseModel):
    """A filter for metadata fields"""

    name: str
    value: str


class Filters(BaseModel):
    """Filterable fields in a search request"""

    family_geography: Sequence[str] = []
    family_geographies: Sequence[str] = []
    family_category: Sequence[str] = []
    document_languages: Sequence[str] = []
    family_source: Sequence[str] = []
@@ -51,7 +58,11 @@ class Filters(BaseModel):
}

    @field_validator(
        "family_geographies",
        "family_geography",
        "family_category",
        "document_languages",
        "family_source",
    )
    def sanitise_filter_inputs(cls, field):
        """Remove problematic characters from filter values"""
@@ -75,14 +86,14 @@ class SearchParameters(BaseModel):

    exact_match: bool = False
    """
    Indicate if the `query_string` should be treated as an exact match when
    the search is performed.
    """

    all_results: bool = False
    """
    Return all results rather than searching or ranking
    Filters can still be applied
    """

@@ -138,10 +149,27 @@ class SearchParameters(BaseModel):
    """
    Use to return the next page of results from a specific search, the next token
    can be found on the response object. It's also possible to get the next page
    of passages by including the family level continuation token first in the
    array followed by the passage level one.
    """

    corpus_type_names: Optional[Sequence[str]] = None
    """
    The name of the corpus that a document belongs to.
    """

    corpus_import_ids: Optional[Sequence[str]] = None
    """
    The import id of the corpus that a document belongs to.
    """

    metadata: Optional[Sequence[MetadataFilter]] = None
    """
    A field and item mapping to search in the metadata field of the documents.
    E.g. [{"name": "family.sector", "value": "Price"}]
    """

    @model_validator(mode="after")
    def validate(self):
        """Validate against mutually exclusive fields"""
@@ -252,12 +280,16 @@ class Hit(BaseModel):
    family_category: Optional[str] = None
    family_publication_ts: Optional[datetime] = None
    family_geography: Optional[str] = None
    family_geographies: Optional[List[str]] = None
    document_import_id: Optional[str] = None
    document_slug: Optional[str] = None
    document_languages: Optional[List[str]] = None
    document_content_type: Optional[str] = None
    document_cdn_object: Optional[str] = None
    document_source_url: Optional[str] = None
    corpus_type_name: Optional[str] = None
    corpus_import_id: Optional[str] = None
    metadata: Optional[Sequence[dict[str, str]]] = None

    @classmethod
    def from_vespa_response(cls, response_hit: dict) -> "Hit":
@@ -311,12 +343,16 @@ def from_vespa_response(cls, response_hit: dict) -> "Document":
            family_category=fields.get("family_category"),
            family_publication_ts=family_publication_ts,
            family_geography=fields.get("family_geography"),
            family_geographies=fields.get("family_geographies", []),
            document_import_id=fields.get("document_import_id"),
            document_slug=fields.get("document_slug"),
            document_languages=fields.get("document_languages", []),
            document_content_type=fields.get("document_content_type"),
            document_cdn_object=fields.get("document_cdn_object"),
            document_source_url=fields.get("document_source_url"),
            corpus_type_name=fields.get("corpus_type_name"),
            corpus_import_id=fields.get("corpus_import_id"),
            metadata=fields.get("metadata"),
        )


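Taken together, the new SearchParameters fields above might be combined as in the following hedged sketch. The class and field names come from the diff (including the assumption that the filters field accepts a Filters object, as the YQL builder code suggests); the corpus type name, corpus import id, and geography codes are invented for illustration.

from cpr_sdk.models.search import Filters, MetadataFilter, SearchParameters

params = SearchParameters(
    query_string="climate adaptation",
    # New: restrict results to particular corpora, by type name and/or import id.
    corpus_type_names=["Laws and Policies"],  # illustrative corpus type name
    corpus_import_ids=["CCLW.corpus.i00000001.n0000"],  # illustrative import id
    # New: match documents whose metadata contains this name/value pair.
    metadata=[MetadataFilter(name="family.sector", value="Price")],
    # New: family_geographies joins the existing family_geography filter.
    filters=Filters(family_geographies=["GBR", "FRA"]),  # illustrative codes
)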
4 changes: 2 additions & 2 deletions src/cpr_sdk/version.py
@@ -1,6 +1,6 @@
_MAJOR = "1"
_MINOR = "3"
_PATCH = "13"
_MINOR = "4"
_PATCH = "0"
_SUFFIX = ""

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
11 changes: 2 additions & 9 deletions src/cpr_sdk/vespa.py
@@ -2,21 +2,15 @@
from typing import Any, List

import yaml
from vespa.exceptions import VespaError
from vespa.io import VespaResponse

from cpr_sdk.embedding import Embedder
from cpr_sdk.exceptions import FetchError
from cpr_sdk.models.search import Family, Hit, SearchParameters, SearchResponse
from cpr_sdk.utils import dig, is_sensitive_query, load_sensitive_query_terms
from cpr_sdk.yql_builder import YQLBuilder


SENSITIVE_QUERY_TERMS = load_sensitive_query_terms()


@@ -115,7 +109,6 @@ def parse_vespa_response(vespa_response: VespaResponse) -> SearchResponse:
    root = vespa_response.json["root"]

    response_families = dig(root, "children", 0, "children", 0, "children", default=[])
    for family in response_families:
        total_passage_hits = dig(family, "fields", "count()")
        family_hits: List[Hit] = []
41 changes: 38 additions & 3 deletions src/cpr_sdk/yql_builder.py
@@ -11,7 +11,7 @@ class YQLBuilder:
"""
select * from sources $SOURCES
where $WHERE_CLAUSE
limit 0
|
$CONTINUATION
all(
@@ -84,6 +84,38 @@ def build_search_term(self) -> str:
)
"""

    def build_metadata_filter(self) -> Optional[str]:
        """Create the part of the query that limits to specific metadata"""
        metadata_filters = []
        if self.params.metadata:
            [
                metadata_filters.append(
                    f"""
                    (
                        metadata contains sameElement(
                            name contains '{metadata.name}',
                            value contains '{metadata.value}'
                        )
                    )
                    """
                )
                for metadata in self.params.metadata
            ]
            return f"({' and '.join(metadata_filters)})"
        return None

    def build_corpus_type_name_filter(self) -> Optional[str]:
        """Create the part of the query that limits to specific corpora"""
        if self.params.corpus_type_names:
            corpora = ", ".join([f"'{c}'" for c in self.params.corpus_type_names])
            return f"(corpus_type_name in({corpora}))"

    def build_corpus_import_ids_filter(self) -> Optional[str]:
        """Create the part of the query that limits to specific corpora import id"""
        if self.params.corpus_import_ids:
            corpora = ", ".join([f"'{c}'" for c in self.params.corpus_import_ids])
            return f"(corpus_import_id in({corpora}))"

    def build_family_filter(self) -> Optional[str]:
        """Create the part of the query that limits to specific families"""
        if self.params.family_ids:
@@ -98,7 +130,7 @@ def build_document_filter(self) -> Optional[str]:
            return f"(document_import_id in({documents}))"
        return None

    def _inclusive_filters(self, filters: Filters, field_name: str) -> Optional[str]:
        values = getattr(filters, field_name)
        query_filters = []
        for value in values:
@@ -128,14 +160,17 @@ def build_where_clause(self) -> str:
        filters.append(self.build_search_term())
        filters.append(self.build_family_filter())
        filters.append(self.build_document_filter())
        filters.append(self.build_corpus_type_name_filter())
        filters.append(self.build_corpus_import_ids_filter())
        filters.append(self.build_metadata_filter())
        if f := self.params.filters:
            filters.append(self._inclusive_filters(f, "family_geographies"))
            filters.append(self._inclusive_filters(f, "family_geography"))
            filters.append(self._inclusive_filters(f, "family_category"))
            filters.append(self._inclusive_filters(f, "document_languages"))
            filters.append(self._inclusive_filters(f, "family_source"))
        filters.append(self.build_year_start_filter())
        filters.append(self.build_year_end_filter())

        return " and ".join([f for f in filters if f])  # Remove empty

    def build_continuation(self) -> str:
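To show where the new fragments end up, here is a rough sketch of driving the builder directly. YQLBuilder, build_where_clause and the fragment shapes come from the diff; the constructor keyword, the parameter values, and the exact whitespace of the output are assumptions for illustration.

from cpr_sdk.models.search import MetadataFilter, SearchParameters
from cpr_sdk.yql_builder import YQLBuilder

params = SearchParameters(
    query_string="flood defences",
    corpus_type_names=["Laws and Policies"],
    metadata=[MetadataFilter(name="family.sector", value="Price")],
)
builder = YQLBuilder(params=params)  # keyword assumed from self.params in the diff

# Each new build_* method returns a parenthesised fragment (or None), and
# build_where_clause joins the non-empty fragments with "and", roughly:
#   (corpus_type_name in('Laws and Policies'))
#   and ((
#       metadata contains sameElement(
#           name contains 'family.sector',
#           value contains 'Price'
#       )
#   ))
print(builder.build_where_clause())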