Skip to content

Commit

Permalink
Vespa processing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel Wright committed Oct 31, 2023
1 parent 09770fb commit c618ee4
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 54 deletions.
2 changes: 2 additions & 0 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from app.core.config import VESPA_SECRETS_LOCATION, VESPA_URL
from app.core.lookups import get_countries_for_region, get_country_by_slug
from app.core.search import (
ENCODER,
FilterField,
OpenSearchConnection,
OpenSearchConfig,
Expand All @@ -46,6 +47,7 @@
_VESPA_CONNECTION = VespaSearchAdapter(
instance_url=VESPA_URL,
cert_directory=VESPA_SECRETS_LOCATION,
embedder=ENCODER,
)

search_router = APIRouter()
Expand Down
27 changes: 17 additions & 10 deletions app/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import string

from cpr_data_access.embedding import Embedder
from cpr_data_access.models.search import Document as DataAccessResponseDocument
from cpr_data_access.models.search import Family as DataAccessResponseFamily
from cpr_data_access.models.search import Passage as DataAccessResponsePassage
from cpr_data_access.models.search import SearchParameters as DataAccessSearchParams
from cpr_data_access.models.search import SearchResponse as DataAccessSearchResponse
from cpr_data_access.models.search import (
Document as DataAccessResponseDocument,
Family as DataAccessResponseFamily,
Passage as DataAccessResponsePassage,
SearchParameters as DataAccessSearchParams,
SearchResponse as DataAccessSearchResponse,
)
from opensearchpy import OpenSearch
from opensearchpy import JSONSerializer as jss
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -81,7 +83,7 @@

_LOGGER = logging.getLogger(__name__)

_ENCODER = Embedder(cache_folder=INDEX_ENCODER_CACHE_FOLDER)
ENCODER = Embedder(cache_folder=INDEX_ENCODER_CACHE_FOLDER)

# Map a sort field type to the document key used by OpenSearch
_SORT_FIELD_MAP: Mapping[SortField, str] = {
Expand Down Expand Up @@ -413,7 +415,7 @@ def with_semantic_query(self, query_string: str, knn: bool):

_LOGGER.info(f"Starting embeddings generation for '{query_string}'")
start_generation = time.time_ns()
embedding = _ENCODER.embed(
embedding = ENCODER.embed(
query_string,
normalize=False,
show_progress_bar=False,
Expand Down Expand Up @@ -1126,6 +1128,7 @@ def _process_vespa_search_response_families(
if db_family_tuple is None:
_LOGGER.error(f"Could not locate family with import id '{vespa_family.id}'")
continue
# TODO: filter UNPUBLISHED docs?
if db_family_tuple[0].family_status != FamilyStatus.PUBLISHED:
_LOGGER.debug(
f"Skipping unpublished family with id '{vespa_family.id}' "
Expand Down Expand Up @@ -1166,21 +1169,22 @@ def _process_vespa_search_response_families(
family_name=hit.family_name,
family_description=hit.family_description or "",
family_category=hit.family_category,
family_date=db_family.publication_date,
family_last_updated_date=db_family.last_updated_date,
family_date=db_family.published_date.isoformat(),
family_last_updated_date=db_family.last_updated_date.isoformat(),
family_source=hit.family_source,
family_description_match=False,
family_title_match=False,
family_documents=[],
family_geography=hit.family_geography,
family_metadata=cast(dict, db_family_metadata.value),
)
response_family_lookup[family_import_id] = response_family

if isinstance(hit, DataAccessResponseDocument):
response_family.family_description_match = True
response_family.family_title_match = True

if isinstance(hit, DataAccessResponsePassage):
elif isinstance(hit, DataAccessResponsePassage):
document_import_id = hit.document_import_id
if document_import_id is None:
_LOGGER.error("Skipping hit with empty document import id")
Expand Down Expand Up @@ -1219,6 +1223,9 @@ def _process_vespa_search_response_families(
)
)

else:
_LOGGER.error(f"Unknown hit type: {type(hit)}")

response_families.append(response_family)
response_family = None

Expand Down
32 changes: 16 additions & 16 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ alembic = "^1.7.6"
Authlib = "^0.15.5"
bcrypt = "^3.2.0"
boto3 = "^1.26"
cpr-data-access = {git = "https://github.com/climatepolicyradar/data-access.git", tag = "v0.2.2"}
cpr-data-access = {git = "https://github.com/climatepolicyradar/data-access.git", tag = "v0.2.3"}
fastapi = "^0.89.0"
fastapi-health = "^0.4.0"
fastapi-pagination = { extras = ["sqlalchemy"], version = "^0.9.1" }
Expand Down
Loading

0 comments on commit c618ee4

Please sign in to comment.