From bbf1f909578e32b70e900911d34b0db18f30e208 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 8 Apr 2024 16:02:33 +0100 Subject: [PATCH 01/15] Updating dal to cpr_sdk. --- poetry.lock | 89 +++++++++++++++++++++----------------------------- pyproject.toml | 2 +- 2 files changed, 38 insertions(+), 53 deletions(-) diff --git a/poetry.lock b/poetry.lock index d19ba516..8b530bd6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -347,17 +347,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.34.76" +version = "1.34.79" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.76-py3-none-any.whl", hash = "sha256:530a4cea3d40a6bd2f15a368ea395beef1ea6dff4491823bc48bd20c7d4da655"}, - {file = "boto3-1.34.76.tar.gz", hash = "sha256:8c598382e8fb61cfa8f75056197e9b509eb52039ebc291af3b1096241ba2542c"}, + {file = "boto3-1.34.79-py3-none-any.whl", hash = "sha256:265b0b4865e8c07e27abb32a31d2bd9129bb009b1d89ca0783776ec084886123"}, + {file = "boto3-1.34.79.tar.gz", hash = "sha256:139dd2d94eaa0e3213ff37ba7cf4cb2e3823269178fe8f3e33c965f680a9ddde"}, ] [package.dependencies] -botocore = ">=1.34.76,<1.35.0" +botocore = ">=1.34.79,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -366,13 +366,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.76" +version = "1.34.79" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.76-py3-none-any.whl", hash = "sha256:62e45e7374844ee39e86a96fe7f5e973eb5bf3469da028b4e3a8caba0909fb1f"}, - {file = "botocore-1.34.76.tar.gz", hash = "sha256:68be44487a95132fccbc0b836fded4190dae30324f6bf822e1b6efd385ffdc83"}, + {file = "botocore-1.34.79-py3-none-any.whl", hash = "sha256:a42a014d3dbaa9ef123810592af69f9e55b456c5be3ac9efc037325685519e83"}, + {file = "botocore-1.34.79.tar.gz", hash = "sha256:6b59b0f7de219d383a2a633f6718c2600642ebcb707749dc6c67a6a436474b7a"}, ] [package.dependencies] @@ -597,39 +597,35 @@ files = [ ] [[package]] -name = "cpr-data-access" -version = "0.5.8" +name = "cpr-sdk" +version = "1.0.2" description = "" optional = false -python-versions = "^3.9" -files = [] -develop = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "cpr_sdk-1.0.2-py3-none-any.whl", hash = "sha256:fa2693410d21fea75498a5862b9333bfab4d67f1f823ee4811f484e084962379"}, + {file = "cpr_sdk-1.0.2.tar.gz", hash = "sha256:314c3c7ca90ded62754e99287bb914fdcca2e5a7a837077a94ad215dafbae86f"}, +] [package.dependencies] -aws-error-utils = "^2.7.0" -boto3 = "^1.26.16" -datasets = "^2.14.0" -deprecation = "^2.1.0" -langdetect = "^1.0.9" +aws-error-utils = ">=2.7.0,<3.0.0" +boto3 = ">=1.26.16,<2.0.0" +datasets = ">=2.14.0,<3.0.0" +deprecation = ">=2.1.0,<3.0.0" +langdetect = ">=1.0.9,<2.0.0" numpy = ">=1.23.5" -pandas = "^1.5.3" -pydantic = "^2.4.0" -pyvespa = {version = "^0.37.1", optional = true} -pyyaml = {version = "^6.0.1", optional = true} -sentence-transformers = {version = "^2.2.2", optional = true} -torch = {version = "^2.0.0", optional = true} -tqdm = "^4.64.1" +pandas = ">=1.5.3,<2.0.0" +pydantic = ">=2.4.0,<3.0.0" +pyvespa = {version = ">=0.37.1,<0.38.0", optional = true, markers = "extra == \"vespa\""} +pyyaml = {version = ">=6.0.1,<7.0.0", optional = true, markers = "extra == \"vespa\""} +sentence-transformers = {version = ">=2.2.2,<3.0.0", optional = true, markers = "extra == \"vespa\""} +torch = {version = ">=2.0.0,<3.0.0", optional = true, markers = "extra == \"vespa\""} +tqdm = ">=4.64.1,<5.0.0" [package.extras] spacy = ["spacy (>=3.5.1,<4.0.0)"] vespa = ["pyvespa (>=0.37.1,<0.38.0)", "pyyaml (>=6.0.1,<7.0.0)", "sentence-transformers (>=2.2.2,<3.0.0)", "torch (>=2.0.0,<3.0.0)"] -[package.source] -type = "git" -url = "https://github.com/climatepolicyradar/data-access.git" -reference = "v0.5.8" -resolved_reference = "88713808fc33a30b119c39c643faa00ca677fc68" - [[package]] name = "cryptography" version = "42.0.5" @@ -864,13 +860,13 @@ fastapi = ">=0.63.0" [[package]] name = "fastapi-pagination" -version = "0.12.21" +version = "0.12.22" description = "FastAPI pagination" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ - {file = "fastapi_pagination-0.12.21-py3-none-any.whl", hash = "sha256:5715b3dec31f9f9a0df6e08a53d7efe8c185d1fc8b392438d60e15349d7478d1"}, - {file = "fastapi_pagination-0.12.21.tar.gz", hash = "sha256:ba0bd1023ae37cb32946e91b1356f2454809e15393911d68e318f5c7aa6887c4"}, + {file = "fastapi_pagination-0.12.22-py3-none-any.whl", hash = "sha256:ff26ec95b6c5005f99b56267712d80753278f8dd3907b45c7159155f050ce48e"}, + {file = "fastapi_pagination-0.12.22.tar.gz", hash = "sha256:f7d7cffd82a3546e1cd67e28b355b44e9ec62fbfa4e99bf5682cec0daec7907a"}, ] [package.dependencies] @@ -1873,14 +1869,13 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.4.99" +version = "12.4.127" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"}, - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] [[package]] @@ -2604,7 +2599,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2612,15 +2606,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2637,7 +2624,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2645,7 +2631,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3664,13 +3649,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, + {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] [[package]] @@ -4222,4 +4207,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "95df2c1426c351826a146c418bad19330f03fa937a03e12fdbc2651489637d93" +content-hash = "3a19f50e43389305d03161b75b2e141b71508d7a901d853f0c80c5fb2811cb36" diff --git a/pyproject.toml b/pyproject.toml index 00a6e9b1..13e6a011 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ python = "^3.9" Authlib = "^0.15.5" bcrypt = "^3.2.0" boto3 = "^1.26" -cpr-data-access = {git = "https://github.com/climatepolicyradar/data-access.git", tag="v0.5.8", extras = ["vespa"]} +cpr_sdk = { version = "1.0.2", extras = ["vespa"]} fastapi = "^0.104.1" fastapi-health = "^0.4.0" fastapi-pagination = { extras = ["sqlalchemy"], version = "^0.12.19" } From c284e30fad915240f7e18e66e15e304bc07ade11 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 8 Apr 2024 16:20:54 +0100 Subject: [PATCH 02/15] Updating the package name. --- app/api/api_v1/routers/search.py | 4 +-- app/api/api_v1/schemas/search.py | 11 ++++--- app/core/search.py | 50 ++++++++++++++++---------------- tests/conftest.py | 11 ++++--- tests/core/test_search.py | 39 +++++++++++++++---------- 5 files changed, 61 insertions(+), 54 deletions(-) diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 8d4eb64e..76f07582 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -8,8 +8,8 @@ import logging from io import BytesIO -from cpr_data_access.exceptions import QueryError -from cpr_data_access.search_adaptors import VespaSearchAdapter +from cpr_sdk.exceptions import QueryError +from cpr_sdk.search_adaptors import VespaSearchAdapter from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index 98292c83..837720b2 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -1,19 +1,18 @@ from enum import Enum -from typing import List, Mapping, Optional, Sequence - -from pydantic import field_validator, Field, BaseModel, PrivateAttr, model_validator -from typing import Literal +from typing import List, Literal, Mapping, Optional, Sequence +from cpr_sdk.models.search import SearchParameters as DataAccessSearchParameters from db_client.models.dfce import FamilyCategory -from . import CLIMATE_LAWS_MATCH +from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator from typing_extensions import Annotated -from cpr_data_access.models.search import SearchParameters as DataAccessSearchParameters from app.core.config import ( VESPA_SEARCH_LIMIT, VESPA_SEARCH_MATCHES_PER_DOC, ) +from . import CLIMATE_LAWS_MATCH + Coord = tuple[float, float] diff --git a/app/core/search.py b/app/core/search.py index c93cc1c6..8d869eb5 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -5,18 +5,33 @@ from io import StringIO from typing import Any, Mapping, Optional, Sequence, cast -from cpr_data_access.embedding import Embedder -from cpr_data_access.models.search import Document as DataAccessResponseDocument -from cpr_data_access.models.search import Family as DataAccessResponseFamily -from cpr_data_access.models.search import Passage as DataAccessResponsePassage -from cpr_data_access.models.search import SearchResponse as DataAccessSearchResponse -from cpr_data_access.models.search import Filters as DataAccessKeywordFilters -from cpr_data_access.models.search import filter_fields +from cpr_sdk.embedding import Embedder +from cpr_sdk.models.search import Document as DataAccessResponseDocument +from cpr_sdk.models.search import Family as DataAccessResponseFamily +from cpr_sdk.models.search import Filters as DataAccessKeywordFilters +from cpr_sdk.models.search import Passage as DataAccessResponsePassage +from cpr_sdk.models.search import SearchResponse as DataAccessSearchResponse +from cpr_sdk.models.search import filter_fields +from db_client.models.dfce import ( + Collection, + CollectionFamily, + Family, + FamilyDocument, + FamilyMetadata, + Slug, +) +from db_client.models.dfce.family import ( + Corpus, + DocumentStatus, + FamilyCorpus, + FamilyStatus, +) +from db_client.models.organisation import Organisation from sqlalchemy.orm import Session from app.api.api_v1.schemas.search import ( - FilterField, BackendFilterValues, + FilterField, SearchRequestBody, SearchResponse, SearchResponseDocumentPassage, @@ -29,21 +44,6 @@ ) from app.core.lookups import get_countries_for_region, get_countries_for_slugs from app.core.util import to_cdn_url -from db_client.models.organisation import Organisation -from db_client.models.dfce import ( - Collection, - CollectionFamily, - Family, - FamilyDocument, - FamilyMetadata, - Slug, -) -from db_client.models.dfce.family import ( - DocumentStatus, - FamilyStatus, - FamilyCorpus, - Corpus, -) _LOGGER = logging.getLogger(__name__) @@ -330,9 +330,9 @@ def _process_vespa_search_response_families( offset: int, ) -> Sequence[SearchResponseFamily]: """ - Process a list of data access results into a list of SearchResponse Families + Process a list of cpr sdk results into a list of SearchResponse Families - Note: this function requires that results from the data access library are grouped + Note: this function requires that results from the cpr sdk library are grouped by family_import_id. """ vespa_families_to_process = vespa_families[offset : limit + offset] diff --git a/tests/conftest.py b/tests/conftest.py index 8fb19864..bc0e4c2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,18 +3,17 @@ import uuid import pytest -from cpr_data_access.embedding import Embedder -from cpr_data_access.search_adaptors import Vespa, VespaSearchAdapter +from cpr_sdk.embedding import Embedder +from cpr_sdk.search_adaptors import Vespa, VespaSearchAdapter +from db_client import run_migrations +from db_client.models import Base +from db_client.models.organisation import AppUser from fastapi.testclient import TestClient from moto import mock_s3 from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import create_database, database_exists, drop_database -from db_client.models.organisation import AppUser -from db_client.models import Base -from db_client import run_migrations - from app.core import security from app.core.aws import S3Client, get_s3_client from app.db.session import get_db diff --git a/tests/core/test_search.py b/tests/core/test_search.py index 5654d81f..01c2d584 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -1,31 +1,30 @@ import random from dataclasses import dataclass from datetime import datetime -from slugify import slugify from typing import Sequence import pytest -from cpr_data_access.models.search import Filters as DataAccessFilters -from cpr_data_access.models.search import ( +from cpr_sdk.models.search import ( Document as DataAccessDocument, +) +from cpr_sdk.models.search import ( Family as DataAccessFamily, +) +from cpr_sdk.models.search import Filters as DataAccessFilters +from cpr_sdk.models.search import ( Hit as DataAccessHit, +) +from cpr_sdk.models.search import ( Passage as DataAccessPassage, +) +from cpr_sdk.models.search import ( SearchResponse as DataAccessSearchResponse, - filter_fields, ) -from sqlalchemy.orm import Session - -from app.core.config import VESPA_SEARCH_MATCHES_PER_DOC, VESPA_SEARCH_LIMIT -from app.core.search import ( - SearchRequestBody, - create_vespa_search_params, - process_vespa_search_response, - _convert_filters, +from cpr_sdk.models.search import ( + filter_fields, ) - -from db_client.models.document import PhysicalDocument from db_client.models.dfce import ( + DocumentStatus, EventStatus, Family, FamilyCategory, @@ -33,7 +32,17 @@ FamilyEvent, FamilyMetadata, Geography, - DocumentStatus, +) +from db_client.models.document import PhysicalDocument +from slugify import slugify +from sqlalchemy.orm import Session + +from app.core.config import VESPA_SEARCH_LIMIT, VESPA_SEARCH_MATCHES_PER_DOC +from app.core.search import ( + SearchRequestBody, + _convert_filters, + create_vespa_search_params, + process_vespa_search_response, ) From ac85cb30d8797aeddbd4fb54ff9fa36ab43be16d Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 8 Apr 2024 16:30:22 +0100 Subject: [PATCH 03/15] Updating poetry files. --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8b530bd6..f9a284f3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1349,13 +1349,13 @@ files = [ [[package]] name = "joblib" -version = "1.3.2" +version = "1.4.0" description = "Lightweight pipelining with Python functions" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, - {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, + {file = "joblib-1.4.0-py3-none-any.whl", hash = "sha256:42942470d4062537be4d54c83511186da1fc14ba354961a2114da91efa9a4ed7"}, + {file = "joblib-1.4.0.tar.gz", hash = "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 13e6a011..f250f0f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "0.1.0" +version = "1.8.3" description = "" authors = ["CPR-dev-team "] From d9a55de7f959c579821936d3978e60844b799d66 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 09:18:01 +0100 Subject: [PATCH 04/15] Trunk updates. --- app/api/api_v1/routers/search.py | 1 + app/api/api_v1/schemas/search.py | 5 +---- app/core/search.py | 5 +---- pyproject.toml | 16 +++++----------- tests/core/test_search.py | 24 ++++++------------------ 5 files changed, 14 insertions(+), 37 deletions(-) diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 76f07582..700b10c0 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -5,6 +5,7 @@ its input. The individual endpoints will return different responses tailored for the type of document search being performed. """ + import logging from io import BytesIO diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index 837720b2..44cf5069 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -6,10 +6,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator from typing_extensions import Annotated -from app.core.config import ( - VESPA_SEARCH_LIMIT, - VESPA_SEARCH_MATCHES_PER_DOC, -) +from app.core.config import VESPA_SEARCH_LIMIT, VESPA_SEARCH_MATCHES_PER_DOC from . import CLIMATE_LAWS_MATCH diff --git a/app/core/search.py b/app/core/search.py index 8d869eb5..4460415d 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -38,10 +38,7 @@ SearchResponseFamily, SearchResponseFamilyDocument, ) -from app.core.config import ( - INDEX_ENCODER_CACHE_FOLDER, - PUBLIC_APP_URL, -) +from app.core.config import INDEX_ENCODER_CACHE_FOLDER, PUBLIC_APP_URL from app.core.lookups import get_countries_for_region, get_countries_for_slugs from app.core.util import to_cdn_url diff --git a/pyproject.toml b/pyproject.toml index f250f0f7..74945aae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ python = "^3.9" Authlib = "^0.15.5" bcrypt = "^3.2.0" boto3 = "^1.26" -cpr_sdk = { version = "1.0.2", extras = ["vespa"]} +cpr_sdk = { version = "1.0.2", extras = ["vespa"] } fastapi = "^0.104.1" fastapi-health = "^0.4.0" fastapi-pagination = { extras = ["sqlalchemy"], version = "^0.12.19" } @@ -30,7 +30,7 @@ starlette = "^0.27.0" tenacity = "^8.0.1" uvicorn = { extras = ["standard"], version = "^0.20.0" } botocore = "^1.34.19" -db-client = {git = "https://github.com/climatepolicyradar/navigator-db-client.git", tag = "v3.4.0"} +db-client = { git = "https://github.com/climatepolicyradar/navigator-db-client.git", tag = "v3.4.0" } urllib3 = "<2" apscheduler = "^3.10.4" @@ -57,11 +57,7 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] addopts = "-p no:cacheprovider" -markers = [ - "cors", - "search", - "unit", -] +markers = ["cors", "search", "unit"] asyncio_mode = "strict" [tool.pydocstyle] @@ -70,9 +66,7 @@ ignore = """ [tool.pyright] include = ["app", "scripts", "tests"] -exclude = [ - "**/__pycache__", -] +exclude = ["**/__pycache__"] ignore = ["scripts/**/*"] defineConstant = { DEBUG = true } @@ -81,4 +75,4 @@ reportMissingTypeStubs = false pythonVersion = "3.9" pythonPlatform = "Linux" -venv = "backend" \ No newline at end of file +venv = "backend" diff --git a/tests/core/test_search.py b/tests/core/test_search.py index 01c2d584..d32229a7 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -4,25 +4,13 @@ from typing import Sequence import pytest -from cpr_sdk.models.search import ( - Document as DataAccessDocument, -) -from cpr_sdk.models.search import ( - Family as DataAccessFamily, -) +from cpr_sdk.models.search import Document as DataAccessDocument +from cpr_sdk.models.search import Family as DataAccessFamily from cpr_sdk.models.search import Filters as DataAccessFilters -from cpr_sdk.models.search import ( - Hit as DataAccessHit, -) -from cpr_sdk.models.search import ( - Passage as DataAccessPassage, -) -from cpr_sdk.models.search import ( - SearchResponse as DataAccessSearchResponse, -) -from cpr_sdk.models.search import ( - filter_fields, -) +from cpr_sdk.models.search import Hit as DataAccessHit +from cpr_sdk.models.search import Passage as DataAccessPassage +from cpr_sdk.models.search import SearchResponse as DataAccessSearchResponse +from cpr_sdk.models.search import filter_fields from db_client.models.dfce import ( DocumentStatus, EventStatus, From da734f0f3a1467028a067f799f32447f819791ad Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 10:40:44 +0100 Subject: [PATCH 05/15] Updates post trunk check. --- .trunk/trunk.yaml | 2 +- app/api/api_v1/schemas/search.py | 4 ++-- pyproject.toml | 8 ++------ tests/core/test_search.py | 8 ++++++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 08a16ea7..99a7b48a 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -49,7 +49,7 @@ lint: - markdownlint@0.39.0 - osv-scanner@1.7.0 - oxipng@9.0.0 - - pre-commit-hooks@4.5.0: + - pre-commit-hooks@4.6.0: commands: - check-ast - check-case-conflict diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index 44cf5069..d671d254 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -6,10 +6,9 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator from typing_extensions import Annotated +from app.api.api_v1.schemas import CLIMATE_LAWS_MATCH from app.core.config import VESPA_SEARCH_LIMIT, VESPA_SEARCH_MATCHES_PER_DOC -from . import CLIMATE_LAWS_MATCH - Coord = tuple[float, float] @@ -47,6 +46,7 @@ class SearchRequestBody(DataAccessSearchParameters): """The request body expected by the search API endpoint.""" # Query string should be required in backend (its not in dal) + # trunk-ignore(pyright/reportIncompatibleVariableOverride) query_string: str # We need to add `keyword_filters` here because the items recieved from the frontend diff --git a/pyproject.toml b/pyproject.toml index 74945aae..5f86e97a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "navigator_backend" version = "1.8.3" description = "" authors = ["CPR-dev-team "] +packages = [{ include = "app" }, { include = "tests" }] [tool.poetry.dependencies] python = "^3.9" @@ -65,14 +66,9 @@ ignore = """ """ [tool.pyright] -include = ["app", "scripts", "tests"] +include = ["app", "tests"] exclude = ["**/__pycache__"] ignore = ["scripts/**/*"] -defineConstant = { DEBUG = true } - -reportMissingImports = true -reportMissingTypeStubs = false - pythonVersion = "3.9" pythonPlatform = "Linux" venv = "backend" diff --git a/tests/core/test_search.py b/tests/core/test_search.py index d32229a7..123aa59e 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -205,12 +205,16 @@ def test_create_vespa_search_params( search_request_body = SearchRequestBody( query_string=query_string, exact_match=exact_match, - max_passages_per_doc=max_passages, + # The SearchParameters model provides allows this field as an alias for + # max_hits_per_family. + max_passages_per_doc=max_passages, # type: ignore family_ids=family_ids, document_ids=document_ids, keyword_filters=keyword_filters, year_range=year_range, - sort_field=sort_field, + # The SearchParameters model provides allows this field as an alias for + # max_hits_per_family. + sort_field=sort_field, # type: ignore sort_order=sort_order, limit=limit, offset=offset, From 17a9b2340bf91c791d5056e55e66f8d01a3e34f2 Mon Sep 17 00:00:00 2001 From: Katy Baulch <46493669+katybaulch@users.noreply.github.com> Date: Tue, 9 Apr 2024 11:41:43 +0100 Subject: [PATCH 06/15] Ignore all changes from fixing trunk errors PR SHA in git blame --- .git-blame-ignore-revs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 12f83238..34518e92 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -14,3 +14,6 @@ # Ignore LICENSE re-formatting 4e8229e076fec4c5013655a9950187bea9b354df + +# Ignore all trunk auto-fixable errors +4eeb41cc40914cbc7ea254687239849fce5ac6b8 From 865fa654b16b1ac043c846a16b2fcf9d38e717dc Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 12:10:02 +0100 Subject: [PATCH 07/15] Correcting test. --- tests/conftest.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 40c1f572..7c0e5f62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import os -import typing as t import uuid +from typing import Optional import pytest from cpr_sdk.embedding import Embedder @@ -55,13 +55,18 @@ def test_s3_client(s3_document_bucket_names): def test_vespa(): """Connect to local vespa instance""" - def __mocked_init__(self, embedder: t.Optional[Embedder] = None): - self.client = Vespa(url="http://vespatest", port=8080) + def __mocked_init__( + self, + instance_url: str, + cert_directory: Optional[str] = None, + embedder: Optional[Embedder] = None, + ): + self.client = Vespa(url=instance_url, port=8080) self.embedder = embedder or Embedder() VespaSearchAdapter.__init__ = __mocked_init__ - yield VespaSearchAdapter() + yield VespaSearchAdapter(instance_url="http://vespatest") def get_test_db_url() -> str: From 5b12a334364188b25715602c8e6758daa40251a9 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 12:20:28 +0100 Subject: [PATCH 08/15] Removing typing as t. --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7c0e5f62..fa1015ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Optional +from typing import Dict, Optional import pytest from cpr_sdk.embedding import Embedder @@ -280,7 +280,7 @@ def verify_password_mock(first: str, second: str) -> bool: @pytest.fixture def superuser_token_headers( test_client: TestClient, test_superuser, test_password, monkeypatch -) -> t.Dict[str, str]: +) -> Dict[str, str]: monkeypatch.setattr(security, "verify_password", verify_password_mock) login_data = { @@ -297,7 +297,7 @@ def superuser_token_headers( @pytest.fixture def data_superuser_token_headers( data_client: TestClient, data_superuser, test_password, monkeypatch -) -> t.Dict[str, str]: +) -> Dict[str, str]: monkeypatch.setattr(security, "verify_password", verify_password_mock) login_data = { From d5c88585a15e3a651af50c90e504b70dfc21ddf7 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 12:37:39 +0100 Subject: [PATCH 09/15] Attempting trunk fix. --- tests/core/test_search.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/core/test_search.py b/tests/core/test_search.py index 123aa59e..3591b909 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -1,7 +1,7 @@ import random from dataclasses import dataclass from datetime import datetime -from typing import Sequence +from typing import Mapping, Sequence, Union import pytest from cpr_sdk.models.search import Document as DataAccessDocument @@ -240,8 +240,12 @@ def test_create_vespa_search_params( # Test converted data if keyword_filters: + converted_keyword_filters: Union[Mapping[str, Sequence[str]], None] = ( + _convert_filters(data_db, keyword_filters) + ) + assert converted_keyword_filters assert produced_search_parameters.filters == DataAccessFilters( - **_convert_filters(data_db, keyword_filters) + **converted_keyword_filters ) else: assert not produced_search_parameters.keyword_filters From 9aec0ee4a65757526908cb1d11ef2c5d4a83372d Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 12:38:13 +0100 Subject: [PATCH 10/15] Attempting trunk fix via ignore. --- tests/core/test_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_search.py b/tests/core/test_search.py index 3591b909..a772214c 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -245,7 +245,7 @@ def test_create_vespa_search_params( ) assert converted_keyword_filters assert produced_search_parameters.filters == DataAccessFilters( - **converted_keyword_filters + **converted_keyword_filters # type: ignore ) else: assert not produced_search_parameters.keyword_filters From 54234f9c4665a30d9f56ba308ae37e64581f64d7 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 12:45:11 +0100 Subject: [PATCH 11/15] Reverting t import change. --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fa1015ed..d918e60f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import typing as t import uuid from typing import Dict, Optional @@ -280,7 +281,7 @@ def verify_password_mock(first: str, second: str) -> bool: @pytest.fixture def superuser_token_headers( test_client: TestClient, test_superuser, test_password, monkeypatch -) -> Dict[str, str]: +) -> t.Dict[str, str]: monkeypatch.setattr(security, "verify_password", verify_password_mock) login_data = { @@ -297,7 +298,7 @@ def superuser_token_headers( @pytest.fixture def data_superuser_token_headers( data_client: TestClient, data_superuser, test_password, monkeypatch -) -> Dict[str, str]: +) -> t.Dict[str, str]: monkeypatch.setattr(security, "verify_password", verify_password_mock) login_data = { From 1f1cdd8535518cea4babfa0a4d668e9ccefff205 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 13:01:44 +0100 Subject: [PATCH 12/15] Removing Dict import. --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index d918e60f..80004ef6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os import typing as t import uuid -from typing import Dict, Optional +from typing import Optional import pytest from cpr_sdk.embedding import Embedder From a4d9e4f9c3349064c435e34fa62730349f0dfb40 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 13:15:12 +0100 Subject: [PATCH 13/15] Bumping the version. --- poetry.lock | 14 +++++++------- pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index f9a284f3..f8a2fecd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -347,17 +347,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.34.79" +version = "1.34.80" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.79-py3-none-any.whl", hash = "sha256:265b0b4865e8c07e27abb32a31d2bd9129bb009b1d89ca0783776ec084886123"}, - {file = "boto3-1.34.79.tar.gz", hash = "sha256:139dd2d94eaa0e3213ff37ba7cf4cb2e3823269178fe8f3e33c965f680a9ddde"}, + {file = "boto3-1.34.80-py3-none-any.whl", hash = "sha256:bb8f433c04dcdffbd4a802df56c1c30f2be23b1161fd8fb45e4b76c1487ec122"}, + {file = "boto3-1.34.80.tar.gz", hash = "sha256:5627f6ecadb46fc7c9f8c368baf948f1b00a3fd2f8eb1275c254469853ad8fdb"}, ] [package.dependencies] -botocore = ">=1.34.79,<1.35.0" +botocore = ">=1.34.80,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -366,13 +366,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.79" +version = "1.34.80" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.79-py3-none-any.whl", hash = "sha256:a42a014d3dbaa9ef123810592af69f9e55b456c5be3ac9efc037325685519e83"}, - {file = "botocore-1.34.79.tar.gz", hash = "sha256:6b59b0f7de219d383a2a633f6718c2600642ebcb707749dc6c67a6a436474b7a"}, + {file = "botocore-1.34.80-py3-none-any.whl", hash = "sha256:354a00f03faba52acc6f1a84fa4f035d48541633be98ccc24b59dc544f679f8b"}, + {file = "botocore-1.34.80.tar.gz", hash = "sha256:8402262e819f3d46df504bbd781e770858c0130b90f660699f75ef3a63abca5a"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 5f86e97a..498874d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "1.8.3" +version = "1.9.0" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] From c1a3a27c7979c46f6381d9cd0828fe2eb938cb79 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 15:01:45 +0100 Subject: [PATCH 14/15] Updating to CprSdk from DataAccess and updating how we pass params into CprSdkFilters model. --- app/api/api_v1/schemas/search.py | 4 +-- app/core/search.py | 20 ++++++------- tests/core/test_search.py | 51 ++++++++++++++++++++++---------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/app/api/api_v1/schemas/search.py b/app/api/api_v1/schemas/search.py index d671d254..01981d91 100644 --- a/app/api/api_v1/schemas/search.py +++ b/app/api/api_v1/schemas/search.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Literal, Mapping, Optional, Sequence -from cpr_sdk.models.search import SearchParameters as DataAccessSearchParameters +from cpr_sdk.models.search import SearchParameters as CprSdkSearchParameters from db_client.models.dfce import FamilyCategory from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator from typing_extensions import Annotated @@ -42,7 +42,7 @@ class FilterField(str, Enum): BackendKeywordFilter = Optional[Mapping[BackendFilterValues, Sequence[str]]] -class SearchRequestBody(DataAccessSearchParameters): +class SearchRequestBody(CprSdkSearchParameters): """The request body expected by the search API endpoint.""" # Query string should be required in backend (its not in dal) diff --git a/app/core/search.py b/app/core/search.py index 4460415d..2af2b52c 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -6,11 +6,11 @@ from typing import Any, Mapping, Optional, Sequence, cast from cpr_sdk.embedding import Embedder -from cpr_sdk.models.search import Document as DataAccessResponseDocument -from cpr_sdk.models.search import Family as DataAccessResponseFamily -from cpr_sdk.models.search import Filters as DataAccessKeywordFilters -from cpr_sdk.models.search import Passage as DataAccessResponsePassage -from cpr_sdk.models.search import SearchResponse as DataAccessSearchResponse +from cpr_sdk.models.search import Document as CprSdkResponseDocument +from cpr_sdk.models.search import Family as CprSdkResponseFamily +from cpr_sdk.models.search import Filters as CprSdkKeywordFilters +from cpr_sdk.models.search import Passage as CprSdkResponsePassage +from cpr_sdk.models.search import SearchResponse as CprSdkSearchResponse from cpr_sdk.models.search import filter_fields from db_client.models.dfce import ( Collection, @@ -322,7 +322,7 @@ def _convert_filters( def _process_vespa_search_response_families( db: Session, - vespa_families: Sequence[DataAccessResponseFamily], + vespa_families: Sequence[CprSdkResponseFamily], limit: int, offset: int, ) -> Sequence[SearchResponseFamily]: @@ -422,11 +422,11 @@ def _process_vespa_search_response_families( ) response_family_lookup[family_import_id] = response_family - if isinstance(hit, DataAccessResponseDocument): + if isinstance(hit, CprSdkResponseDocument): response_family.family_description_match = True response_family.family_title_match = True - elif isinstance(hit, DataAccessResponsePassage): + elif isinstance(hit, CprSdkResponsePassage): document_import_id = hit.document_import_id if document_import_id is None: _LOGGER.error("Skipping hit with empty document import id") @@ -476,7 +476,7 @@ def _process_vespa_search_response_families( def process_vespa_search_response( db: Session, - vespa_search_response: DataAccessSearchResponse, + vespa_search_response: CprSdkSearchResponse, limit: int, offset: int, ) -> SearchResponse: @@ -505,7 +505,7 @@ def create_vespa_search_params( """Create Vespa search parameters from a F/E search request body""" converted_filters = _convert_filters(db, search_body.keyword_filters) if converted_filters: - search_body.filters = DataAccessKeywordFilters.model_validate(converted_filters) + search_body.filters = CprSdkKeywordFilters.model_validate(converted_filters) else: search_body.filters = None return search_body diff --git a/tests/core/test_search.py b/tests/core/test_search.py index a772214c..e5f5420b 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -4,12 +4,12 @@ from typing import Mapping, Sequence, Union import pytest -from cpr_sdk.models.search import Document as DataAccessDocument -from cpr_sdk.models.search import Family as DataAccessFamily -from cpr_sdk.models.search import Filters as DataAccessFilters -from cpr_sdk.models.search import Hit as DataAccessHit -from cpr_sdk.models.search import Passage as DataAccessPassage -from cpr_sdk.models.search import SearchResponse as DataAccessSearchResponse +from cpr_sdk.models.search import Document as CprSdkDocument +from cpr_sdk.models.search import Family as CprSdkFamily +from cpr_sdk.models.search import Filters as CprSdkFilters +from cpr_sdk.models.search import Hit as CprSdkHit +from cpr_sdk.models.search import Passage as CprSdkPassage +from cpr_sdk.models.search import SearchResponse as CprSdkSearchResponse from cpr_sdk.models.search import filter_fields from db_client.models.dfce import ( DocumentStatus, @@ -238,14 +238,35 @@ def test_create_vespa_search_params( assert produced_search_parameters.query_string == query_string assert produced_search_parameters.exact_match == exact_match + # Test converted data # Test converted data if keyword_filters: converted_keyword_filters: Union[Mapping[str, Sequence[str]], None] = ( _convert_filters(data_db, keyword_filters) ) - assert converted_keyword_filters - assert produced_search_parameters.filters == DataAccessFilters( - **converted_keyword_filters # type: ignore + assert converted_keyword_filters is not None + assert produced_search_parameters.filters is not None + assert produced_search_parameters.filters == CprSdkFilters( + family_geography=( + converted_keyword_filters["family_geography"] + if "family_geography" in converted_keyword_filters.keys() + else [] + ), + family_category=( + converted_keyword_filters["family_category"] + if "family_category" in converted_keyword_filters.keys() + else [] + ), + document_languages=( + converted_keyword_filters["document_languages"] + if "document_languages" in converted_keyword_filters.keys() + else [] + ), + family_source=( + converted_keyword_filters["family_source"] + if "family_source" in converted_keyword_filters.keys() + else [] + ), ) else: assert not produced_search_parameters.keyword_filters @@ -451,7 +472,7 @@ def _generate_coords(): ] -def _generate_search_response_hits(spec: FamSpec) -> Sequence[DataAccessHit]: +def _generate_search_response_hits(spec: FamSpec) -> Sequence[CprSdkHit]: random.seed(spec.random_seed) doc_data = {} hits = [] @@ -462,7 +483,7 @@ def _generate_search_response_hits(spec: FamSpec) -> Sequence[DataAccessHit]: "languages": random.sample(_LANGUAGES, random.randint(1, 3)), } hits.append( - DataAccessDocument( + CprSdkDocument( family_import_id=spec.family_import_id, family_name=spec.family_name, family_description=spec.family_description, @@ -493,7 +514,7 @@ def _generate_search_response_hits(spec: FamSpec) -> Sequence[DataAccessHit]: "languages": random.sample(_LANGUAGES, random.randint(1, 3)), } hits.append( - DataAccessPassage( + CprSdkPassage( family_import_id=spec.family_import_id, family_name=spec.family_name, family_description=spec.family_description, @@ -535,18 +556,18 @@ def _generate_search_response_hits(spec: FamSpec) -> Sequence[DataAccessHit]: return hits -def _generate_search_response(specs: Sequence[FamSpec]) -> DataAccessSearchResponse: +def _generate_search_response(specs: Sequence[FamSpec]) -> CprSdkSearchResponse: families = [] for fam_spec in specs: passage_hits = _generate_search_response_hits(fam_spec) - f = DataAccessFamily( + f = CprSdkFamily( id=fam_spec.family_import_id, hits=passage_hits, total_passage_hits=(len(passage_hits) * 10), ) families.append(f) - return DataAccessSearchResponse( + return CprSdkSearchResponse( total_hits=len(specs), total_family_hits=(len(families) * 4), query_time_ms=87 * len(specs), From 17ec741c96079a3742010364d40eff5a0154bd80 Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 9 Apr 2024 15:02:20 +0100 Subject: [PATCH 15/15] Removing duplicate comment. --- tests/core/test_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_search.py b/tests/core/test_search.py index e5f5420b..19d8cdfb 100644 --- a/tests/core/test_search.py +++ b/tests/core/test_search.py @@ -238,7 +238,6 @@ def test_create_vespa_search_params( assert produced_search_parameters.query_string == query_string assert produced_search_parameters.exact_match == exact_match - # Test converted data # Test converted data if keyword_filters: converted_keyword_filters: Union[Mapping[str, Sequence[str]], None] = (