Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating dal to cpr_sdk. #260

Merged
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@

# Ignore LICENSE re-formatting
4e8229e076fec4c5013655a9950187bea9b354df

# Ignore all trunk auto-fixable errors
4eeb41cc40914cbc7ea254687239849fce5ac6b8
4 changes: 2 additions & 2 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import logging
from io import BytesIO

from cpr_data_access.exceptions import QueryError
from cpr_data_access.search_adaptors import VespaSearchAdapter
from cpr_sdk.exceptions import QueryError
from cpr_sdk.search_adaptors import VespaSearchAdapter
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
Expand Down
3 changes: 2 additions & 1 deletion app/api/api_v1/schemas/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Literal, Mapping, Optional, Sequence

from cpr_data_access.models.search import SearchParameters as DataAccessSearchParameters
from cpr_sdk.models.search import SearchParameters as DataAccessSearchParameters
from db_client.models.dfce import FamilyCategory
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
from typing_extensions import Annotated
Expand Down Expand Up @@ -46,6 +46,7 @@ class SearchRequestBody(DataAccessSearchParameters):
"""The request body expected by the search API endpoint."""

# Query string should be required in backend (its not in dal)
# trunk-ignore(pyright/reportIncompatibleVariableOverride)
query_string: str

# We need to add `keyword_filters` here because the items recieved from the frontend
Expand Down
18 changes: 9 additions & 9 deletions app/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from io import StringIO
from typing import Any, Mapping, Optional, Sequence, cast

from cpr_data_access.embedding import Embedder
from cpr_data_access.models.search import Document as DataAccessResponseDocument
from cpr_data_access.models.search import Family as DataAccessResponseFamily
from cpr_data_access.models.search import Filters as DataAccessKeywordFilters
from cpr_data_access.models.search import Passage as DataAccessResponsePassage
from cpr_data_access.models.search import SearchResponse as DataAccessSearchResponse
from cpr_data_access.models.search import filter_fields
from cpr_sdk.embedding import Embedder
from cpr_sdk.models.search import Document as DataAccessResponseDocument
from cpr_sdk.models.search import Family as DataAccessResponseFamily
from cpr_sdk.models.search import Filters as DataAccessKeywordFilters
from cpr_sdk.models.search import Passage as DataAccessResponsePassage
from cpr_sdk.models.search import SearchResponse as DataAccessSearchResponse
from cpr_sdk.models.search import filter_fields
from db_client.models.dfce import (
Collection,
CollectionFamily,
Expand Down Expand Up @@ -327,9 +327,9 @@ def _process_vespa_search_response_families(
offset: int,
) -> Sequence[SearchResponseFamily]:
"""
Process a list of data access results into a list of SearchResponse Families
Process a list of cpr sdk results into a list of SearchResponse Families

Note: this function requires that results from the data access library are grouped
Note: this function requires that results from the cpr sdk library are grouped
by family_import_id.
"""
vespa_families_to_process = vespa_families[offset : limit + offset]
Expand Down
97 changes: 41 additions & 56 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "navigator_backend"
version = "1.8.2"
version = "1.8.3"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand All @@ -10,9 +10,7 @@ python = "^3.9"
Authlib = "^0.15.5"
bcrypt = "^3.2.0"
boto3 = "^1.26"
cpr-data-access = { git = "https://github.com/climatepolicyradar/data-access.git", tag = "v0.5.8", extras = [
"vespa",
] }
cpr_sdk = { version = "1.0.2", extras = ["vespa"] }
fastapi = "^0.104.1"
fastapi-health = "^0.4.0"
fastapi-pagination = { extras = ["sqlalchemy"], version = "^0.12.19" }
Expand Down
21 changes: 13 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import typing as t
import uuid
from typing import Dict, Optional

import pytest
from cpr_data_access.embedding import Embedder
from cpr_data_access.search_adaptors import Vespa, VespaSearchAdapter
from cpr_sdk.embedding import Embedder
from cpr_sdk.search_adaptors import Vespa, VespaSearchAdapter
from db_client import run_migrations
from db_client.models import Base
from db_client.models.organisation import AppUser
Expand Down Expand Up @@ -55,13 +55,18 @@
def test_vespa():
"""Connect to local vespa instance"""

def __mocked_init__(self, embedder: t.Optional[Embedder] = None):
self.client = Vespa(url="http://vespatest", port=8080)
def __mocked_init__(
self,
instance_url: str,
cert_directory: Optional[str] = None,
embedder: Optional[Embedder] = None,
):
self.client = Vespa(url=instance_url, port=8080)
self.embedder = embedder or Embedder()

VespaSearchAdapter.__init__ = __mocked_init__

yield VespaSearchAdapter()
yield VespaSearchAdapter(instance_url="http://vespatest")


def get_test_db_url() -> str:
Expand Down Expand Up @@ -135,7 +140,7 @@

@pytest.fixture
def data_db(scope="function"):
"""

Check failure on line 143 in tests/conftest.py

View workflow job for this annotation

GitHub Actions / Trunk Check

ruff(F821)

[new] Undefined name `t`

Check failure on line 143 in tests/conftest.py

View workflow job for this annotation

GitHub Actions / Trunk Check

pyright(reportUndefinedVariable)

[new] "t" is not defined
Create a fresh test database for each test.

This will populate the db using the alembic migrations.
Expand Down Expand Up @@ -275,7 +280,7 @@
@pytest.fixture
def superuser_token_headers(
test_client: TestClient, test_superuser, test_password, monkeypatch
) -> t.Dict[str, str]:
) -> Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)

login_data = {
Expand All @@ -292,7 +297,7 @@
@pytest.fixture
def data_superuser_token_headers(
data_client: TestClient, data_superuser, test_password, monkeypatch
) -> t.Dict[str, str]:
) -> Dict[str, str]:
monkeypatch.setattr(security, "verify_password", verify_password_mock)

login_data = {
Expand Down
Loading
Loading