def increment_pages(self):
    """Shift every passage-match page number from 0-indexed to 1-indexed.

    Stored text block pages are 0-indexed; PDF viewers are 1-indexed, so
    each non-None ``text_block_page`` is incremented by one, in place.
    A page value of 0 is valid and must be incremented too — only a
    missing (``None``) page is left untouched.

    :return: ``self``, so the call can be chained fluently, e.g.
        ``query_families(...).increment_pages()``.
    """
    for family in self.families:
        for family_document in family.family_documents:
            for passage_match in family_document.document_passage_matches:
                # `is not None` replaces the original
                # `page or page == 0` check: for an Optional[int] the two
                # are equivalent, and the explicit narrowing also removes
                # the need for a `# type: ignore` on the increment.
                if passage_match.text_block_page is not None:
                    # The loop variables are the same objects held by
                    # `self`, so mutating them here updates the response
                    # directly — no index bookkeeping required.
                    passage_match.text_block_page += 1
    return self
def test_search_response() -> None:
    """
    Test that instantiating Search Response objects is done correctly.

    Particularly exercises ``SearchResponse.increment_pages``: 0-indexed
    text block pages (including page 0) must be bumped to 1-indexed, in
    place, while missing (``None``) pages are left untouched.
    """
    search_response = SearchResponse(
        hits=1,
        query_time_ms=1,
        total_time_ms=1,
        families=[
            SearchResponseFamily(
                family_slug="example_slug",
                family_name="Example Family",
                family_description="This is an example family",
                family_category="Example Category",
                # Any ISO-ish date string is fine; the value itself is
                # irrelevant to the page-increment behaviour under test.
                family_date=str(datetime.now()),
                family_last_updated_date=str(datetime.now()),
                family_source="Example Source",
                family_geography="Example Geography",
                family_metadata={"key1": "value1", "key2": "value2"},
                family_title_match=True,
                family_description_match=False,
                family_documents=[
                    SearchResponseFamilyDocument(
                        document_title="Document Title",
                        document_slug="Document Slug",
                        document_type="Executive",
                        document_source_url="https://cdn.example.com/file.pdf",
                        document_url=None,
                        document_content_type="application/pdf",
                        document_passage_matches=[
                            # Page 0 is deliberately included: it is a
                            # valid page and must still be incremented.
                            SearchResponseDocumentPassage(
                                text="Example",
                                text_block_id="p_0_b_0",
                                text_block_page=0,
                                text_block_coords=None,
                            ),
                            SearchResponseDocumentPassage(
                                text="Example",
                                text_block_id="p_1_b_0",
                                text_block_page=1,
                                text_block_coords=None,
                            ),
                            SearchResponseDocumentPassage(
                                text="Example",
                                text_block_id="p_1_b_2",
                                text_block_page=1,
                                text_block_coords=None,
                            ),
                        ],
                    )
                ],
            )
        ],
    )

    # NOTE: the loop variable binds a passage object, not a page number,
    # so it is named `passage` (the original `page` was misleading).
    first_document_initial_pages = [
        passage.text_block_page
        for passage in search_response.families[0]
        .family_documents[0]
        .document_passage_matches
    ]

    search_response_incremented = search_response.increment_pages()

    first_document_incremented_pages = [
        passage.text_block_page
        for passage in search_response_incremented.families[0]
        .family_documents[0]
        .document_passage_matches
    ]

    assert len(first_document_initial_pages) == len(first_document_incremented_pages)
    assert first_document_initial_pages != first_document_incremented_pages

    # None means "no page known" and must survive unchanged; every real
    # page number is shifted from 0-indexed to 1-indexed.
    expected_pages = [
        page if page is None else page + 1 for page in first_document_initial_pages
    ]
    assert expected_pages == first_document_incremented_pages