diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 8dc06069..213a6da3 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -8,17 +8,9 @@ import json import logging from io import BytesIO -from typing import Mapping, Optional, Sequence, cast +from typing import Mapping, Sequence from cpr_data_access.search_adaptors import VespaSearchAdapter -from cpr_data_access.models.search import Document as DataAccessResponseDocument -from cpr_data_access.models.search import Family as DataAccessResponseFamily -from cpr_data_access.models.search import FilterField as DataAccessFilterField -from cpr_data_access.models.search import Passage as DataAccessResponsePassage -from cpr_data_access.models.search import SearchRequestBody as DataAccessSearchRequest -from cpr_data_access.models.search import SearchResponse as DataAccessSearchResponse -from cpr_data_access.models.search import SortField as DataAccessSortField -from cpr_data_access.models.search import SortOrder as DataAccessSortOrder from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -27,25 +19,18 @@ SearchRequestBody, SearchResponse, SortField, - SortOrder, ) from app.core.browse import BrowseArgs, browse_rds_families -from app.core.config import ( - VESPA_SECRETS_LOCATION, - VESPA_URL, - VESPA_SEARCH_LIMIT, - VESPA_SEARCH_MATCHES_PER_DOC, -) +from app.core.config import VESPA_SECRETS_LOCATION, VESPA_URL from app.core.lookups import get_countries_for_region, get_country_by_slug from app.core.search import ( FilterField, OpenSearchConnection, OpenSearchConfig, OpenSearchQueryConfig, - SearchResponseFamily, - SearchResponseFamilyDocument, - SearchResponseDocumentPassage, + create_vespa_search_params, process_result_into_csv, + process_vespa_search_response, ) from app.db.crud.document import DocumentExtraCache from app.db.session import get_db @@ -66,179 +51,6 @@ search_router = APIRouter() -def _convert_sort_field( - sort_field: Optional[SortField], -) -> Optional[DataAccessSortField]: - if sort_field is None: - return None - - if sort_field == SortField.DATE: - return "date" - if sort_field == SortField.TITLE: - return "name" - - -def _convert_sort_order(sort_order: SortOrder) -> DataAccessSortOrder: - if sort_order == SortOrder.ASCENDING: - return "ascending" - if sort_order == SortOrder.DESCENDING: - return "descending" - - -def _convert_filter_field(filter_field: FilterField) -> DataAccessFilterField: - if filter_field == FilterField.CATEGORY: - return "category" - if filter_field == FilterField.COUNTRY: - return "geography" - if filter_field == FilterField.REGION: - return "geography" - if filter_field == FilterField.LANGUAGE: - return "language" - if filter_field == FilterField.SOURCE: - return "source" - - -def _convert_filters( - db: Session, - keyword_filters: Optional[Mapping[FilterField, Sequence[str]]], -) -> Optional[Mapping[DataAccessFilterField, Sequence[str]]]: - if keyword_filters is None: - return None - - new_keyword_filters = {} - for field, values in keyword_filters.items(): - if field == FilterField.REGION: - new_values = [] - for region in values: - new_values.extend(get_countries_for_region(db, region)) - else: - new_values = values - new_keyword_filters[_convert_filter_field(field)] = new_values - return new_keyword_filters - - -from app.db.models.law_policy import Family, FamilyMetadata, FamilyDocument, FamilyStatus -from app.core.util import to_cdn_url - - -def _process_vespa_search_response_families( - db: Session, vespa_families: Sequence[DataAccessResponseFamily] -) -> Sequence[SearchResponseFamily]: - all_family_ids = [vf.id for vf in vespa_families] - - family_and_family_metadata: Sequence[tuple[Family, FamilyMetadata]] = ( - db.query(Family, FamilyMetadata) - .filter(Family.import_id.in_(all_family_ids)) - .join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id) - .all() - ) # type: ignore - db_family_lookup: Mapping[str, tuple[Family, FamilyMetadata]] = { - str(family.import_id): (family, family_metadata) - for (family, family_metadata) in family_and_family_metadata - } - db_family_document_lookup: Mapping[str, FamilyDocument] = { - str(fd.import_id): fd for (fam, _) in family_and_family_metadata - for fd in fam.family_documents - } - - response_families = [] - response_family = None - - for vespa_family in vespa_families: - db_family_tuple = db_family_lookup.get(vespa_family.id) - if db_family_tuple is None: - _LOGGER.error(f"Could not locate family with import id '{vespa_family.id}'") - continue - if db_family_tuple[0].family_status != FamilyStatus.PUBLISHED: - _LOGGER.debug( - f"Skipping unpublished family with id '{vespa_family.id}' " - "in search results" - ) - db_family = db_family_tuple[0] - db_family_metadata = db_family_tuple[1] - - response_family_lookup = {} - response_document_lookup = {} - - for hit in vespa_family.hits: - family_import_id = hit.family_import_id - if family_import_id is None: - _LOGGER.error("Skipping hit with empty family import id") - continue - - response_family = response_family_lookup.get(family_import_id) - # All hits contain required family info to create response - if response_family is None: - response_family = SearchResponseFamily( - family_slug=hit.family_slug, - family_name=hit.family_name, - family_description=hit.family_description, - family_category=hit.family_category, - family_date=db_family.publication_date, - family_last_updated_date=db_family.last_updated_date, - family_source=hit.family_source, - family_description_match=False, - family_title_match=False, - family_documents=[], - family_geography=hit.family_geography, - family_metadata=cast(dict, db_family_metadata.value), - ) - - if isinstance(hit, DataAccessResponseDocument): - response_family.family_description_match = True - response_family.family_title_match = True - - if isinstance(hit, DataAccessResponsePassage): - document_import_id = hit.document_import_id - if document_import_id is None: - _LOGGER.error("Skipping hit with empty document import id") - continue - - response_document = response_document_lookup.get(document_import_id) - if response_document is None: - db_family_document = db_family_document_lookup.get(document_import_id) - if db_family_document is None: - _LOGGER.error(f"Skipping unknown family document with id '{document_import_id}'") - continue - response_document = SearchResponseFamilyDocument( - document_title=db_family_document.physical_document.title, - document_slug=hit.document_slug, - document_type=db_family_document.document_type, - document_source_url=hit.document_source_url, - document_url=to_cdn_url(hit.document_cdn_object), - document_content_type=hit.document_content_type, - document_passage_matches=[], - ) - response_document_lookup[document_import_id] = response_document - response_family.family_documents.append(response_document) - - response_document.document_passage_matches.append( - SearchResponseDocumentPassage( - text=hit.text_block, - text_block_id=hit.text_block_id, - text_block_page=hit.text_block_page, - text_block_coords=hit.text_block_coords, - ) - ) - - response_families.append(response_family) - response_family = None - - return response_families - - -def _process_vespa_search_response( - db: Session, vespa_search_response: DataAccessSearchResponse -) -> SearchResponse: - # TODO: implement conversion - return SearchResponse( - hits=vespa_search_response.total_hits, - query_time_ms=vespa_search_response.query_time_ms or 0, - total_time_ms=vespa_search_response.total_time_ms or 0, - families=[], - ) - - def _search_request( db: Session, search_body: SearchRequestBody, use_vespa: bool = False ) -> SearchResponse: @@ -256,21 +68,15 @@ def _search_request( ) else: if use_vespa: - data_access_search_body = DataAccessSearchRequest( - query_string=search_body.query_string, - exact_match=search_body.exact_match, - limit=VESPA_SEARCH_LIMIT, - max_hits_per_family=VESPA_SEARCH_MATCHES_PER_DOC, - keyword_filters=_convert_filters(db, search_body.keyword_filters), - year_range=search_body.year_range, - sort_by=_convert_sort_field(search_body.sort_field), - sort_order=_convert_sort_order(search_body.sort_order), - continuation_token=None, # TODO: implement pagination? - ) + data_access_search_params = create_vespa_search_params(db, search_body) data_access_search_response = _VESPA_CONNECTION.search( - request=data_access_search_body + parameters=data_access_search_params + ) + return process_vespa_search_response( + db, + data_access_search_response, + search_body, ) - return _process_vespa_search_response(db, data_access_search_response) else: return _OPENSEARCH_CONNECTION.query_families( search_request_body=search_body, diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index 79c30df4..d6dc83ec 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -71,6 +71,8 @@ class SearchRequestBody(BaseModel): limit: int = 10 # TODO: decide on default offset: int = 0 + continuation_token: Optional[str] = None + class SearchResponseDocumentPassage(BaseModel): """A Document passage match returned by the search API endpoint.""" diff --git a/app/core/search.py b/app/core/search.py index 54ebc8a7..a0aab93f 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -12,6 +12,11 @@ import string from cpr_data_access.embedding import Embedder +from cpr_data_access.models.search import Document as DataAccessResponseDocument +from cpr_data_access.models.search import Family as DataAccessResponseFamily +from cpr_data_access.models.search import Passage as DataAccessResponsePassage +from cpr_data_access.models.search import SearchParameters as DataAccessSearchParams +from cpr_data_access.models.search import SearchResponse as DataAccessSearchResponse from opensearchpy import OpenSearch from opensearchpy import JSONSerializer as jss from sqlalchemy.orm import Session @@ -55,8 +60,11 @@ OPENSEARCH_SSL_WARNINGS, OPENSEARCH_JIT_MAX_DOC_COUNT, PUBLIC_APP_URL, + VESPA_SEARCH_LIMIT, + VESPA_SEARCH_MATCHES_PER_DOC, ) from app.core.util import to_cdn_url +from app.core.lookups import get_countries_for_region from app.db.models.app.users import Organisation from app.db.models.law_policy import ( Family, @@ -1009,3 +1017,225 @@ def process_result_into_csv( csv_result_io.seek(0) return csv_result_io.read() + + +# Vespa search processing functions + +def _convert_sort_field( + sort_field: Optional[SortField], +) -> Optional[str]: + if sort_field is None: + return None + + if sort_field == SortField.DATE: + return "date" + if sort_field == SortField.TITLE: + return "name" + + +def _convert_sort_order(sort_order: SortOrder) -> str: + if sort_order == SortOrder.ASCENDING: + return "ascending" + if sort_order == SortOrder.DESCENDING: + return "descending" + + +def _convert_filter_field(filter_field: FilterField) -> str: + if filter_field == FilterField.CATEGORY: + return "category" + if filter_field == FilterField.COUNTRY: + return "geography" + if filter_field == FilterField.REGION: + return "geography" + if filter_field == FilterField.LANGUAGE: + return "language" + if filter_field == FilterField.SOURCE: + return "source" + + +def _convert_filters( + db: Session, + keyword_filters: Optional[Mapping[FilterField, Sequence[str]]], +) -> Optional[Mapping[str, Sequence[str]]]: + if keyword_filters is None: + return None + + new_keyword_filters = {} + for field, values in keyword_filters.items(): + if field == FilterField.REGION: + new_values = [] + for region in values: + new_values.extend(get_countries_for_region(db, region)) + else: + new_values = values + new_keyword_filters[_convert_filter_field(field)] = new_values + return new_keyword_filters + + +from app.db.models.law_policy import ( + Family, + FamilyMetadata, + FamilyDocument, + FamilyStatus, +) +from app.core.util import to_cdn_url + + +def _process_vespa_search_response_families( + db: Session, vespa_families: Sequence[DataAccessResponseFamily] +) -> Sequence[SearchResponseFamily]: + all_family_ids = [vf.id for vf in vespa_families] + + family_and_family_metadata: Sequence[tuple[Family, FamilyMetadata]] = ( + db.query(Family, FamilyMetadata) + .filter(Family.import_id.in_(all_family_ids)) + .join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id) + .all() + ) # type: ignore + db_family_lookup: Mapping[str, tuple[Family, FamilyMetadata]] = { + str(family.import_id): (family, family_metadata) + for (family, family_metadata) in family_and_family_metadata + } + db_family_document_lookup: Mapping[str, FamilyDocument] = { + str(fd.import_id): fd + for (fam, _) in family_and_family_metadata + for fd in fam.family_documents + } + + response_families = [] + response_family = None + + for vespa_family in vespa_families: + db_family_tuple = db_family_lookup.get(vespa_family.id) + if db_family_tuple is None: + _LOGGER.error(f"Could not locate family with import id '{vespa_family.id}'") + continue + if db_family_tuple[0].family_status != FamilyStatus.PUBLISHED: + _LOGGER.debug( + f"Skipping unpublished family with id '{vespa_family.id}' " + "in search results" + ) + db_family = db_family_tuple[0] + db_family_metadata = db_family_tuple[1] + + response_family_lookup = {} + response_document_lookup = {} + + for hit in vespa_family.hits: + family_import_id = hit.family_import_id + if family_import_id is None: + _LOGGER.error("Skipping hit with empty family import id") + continue + + # Check for all required family/document fields in the hit + if ( + hit.family_slug is None + or hit.document_slug is None + or hit.family_name is None + or hit.family_category is None + or hit.family_source is None + or hit.family_geography is None + ): + _LOGGER.error( + "Skipping hit with empty required family info for import " + f"id: {family_import_id}" + ) + continue + + response_family = response_family_lookup.get(family_import_id) + # All hits contain required family info to create response + if response_family is None: + response_family = SearchResponseFamily( + family_slug=hit.family_slug, + family_name=hit.family_name, + family_description=hit.family_description or "", + family_category=hit.family_category, + family_date=db_family.publication_date, + family_last_updated_date=db_family.last_updated_date, + family_source=hit.family_source, + family_description_match=False, + family_title_match=False, + family_documents=[], + family_geography=hit.family_geography, + family_metadata=cast(dict, db_family_metadata.value), + ) + + if isinstance(hit, DataAccessResponseDocument): + response_family.family_description_match = True + response_family.family_title_match = True + + if isinstance(hit, DataAccessResponsePassage): + document_import_id = hit.document_import_id + if document_import_id is None: + _LOGGER.error("Skipping hit with empty document import id") + continue + + response_document = response_document_lookup.get(document_import_id) + if response_document is None: + db_family_document = db_family_document_lookup.get( + document_import_id + ) + if db_family_document is None: + _LOGGER.error( + "Skipping unknown family document with id " + f"'{document_import_id}'" + ) + continue + + response_document = SearchResponseFamilyDocument( + document_title=str(db_family_document.physical_document.title), + document_slug=hit.document_slug, + document_type=str(db_family_document.document_type), + document_source_url=hit.document_source_url, + document_url=to_cdn_url(hit.document_cdn_object), + document_content_type=hit.document_content_type, + document_passage_matches=[], + ) + response_document_lookup[document_import_id] = response_document + response_family.family_documents.append(response_document) + + response_document.document_passage_matches.append( + SearchResponseDocumentPassage( + text=hit.text_block, + text_block_id=hit.text_block_id, + text_block_page=hit.text_block_page, + text_block_coords=hit.text_block_coords, + ) + ) + + response_families.append(response_family) + response_family = None + + return response_families + + +def process_vespa_search_response( + db: Session, vespa_search_response: DataAccessSearchResponse +) -> SearchResponse: + # TODO: implement conversion + return SearchResponse( + hits=vespa_search_response.total_hits, + query_time_ms=vespa_search_response.query_time_ms or 0, + total_time_ms=vespa_search_response.total_time_ms or 0, + families=_process_vespa_search_response_families( + db, + vespa_search_response.families, + ), + ) + + +def create_vespa_search_params( + db: Session, + search_body: SearchRequestBody + ): + return DataAccessSearchParams( + query_string=search_body.query_string, + exact_match=search_body.exact_match, + limit=VESPA_SEARCH_LIMIT, + max_hits_per_family=VESPA_SEARCH_MATCHES_PER_DOC, + keyword_filters=_convert_filters(db, search_body.keyword_filters), + year_range=search_body.year_range, + sort_by=_convert_sort_field(search_body.sort_field), + sort_order=_convert_sort_order(search_body.sort_order), + continuation_token=None, # TODO: implement pagination? + ) \ No newline at end of file