Skip to content

Commit

Permalink
WIP: initial vespa search support via cpr_data_access
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel Wright committed Oct 25, 2023
1 parent beb9b23 commit a7bbb5d
Show file tree
Hide file tree
Showing 4 changed files with 923 additions and 546 deletions.
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ RUN mkdir /cpr-backend
WORKDIR /cpr-backend

RUN apt update && \
apt install -y postgresql-client curl \
apt install -y postgresql-client curl git \
&& rm -rf /var/lib/apt/lists/*

# Install pip and poetry
Expand All @@ -18,6 +18,7 @@ COPY poetry.lock pyproject.toml ./
RUN poetry export --with dev \
| grep -v '\--hash' \
| grep -v '^torch' \
| grep -v '^triton' \
| grep -v '^nvidia' \
| sed -e 's/ \\$//' \
| sed -e 's/^[[:alpha:]]\+\[\([[:alpha:]]\+\[[[:alpha:]]\+\]\)\]/\1/' \
Expand All @@ -26,7 +27,7 @@ RUN poetry export --with dev \
# e.g. we need to replace pydocstyle[pydocstyle[toml]] with pydocstyle[toml]

# Install torch-cpu with pip
RUN pip3 install --no-cache "torch==1.13.0+cpu" "torchvision==0.14.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
RUN pip3 install --no-cache "torch==2.0.0+cpu" "torchvision==0.15.1+cpu" -f https://download.pytorch.org/whl/torch_stable.html

# Install application requirements
RUN pip3 install --no-cache -r requirements.txt
Expand Down
12 changes: 8 additions & 4 deletions app/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from typing import Any, Mapping, Optional, Sequence, cast
import string

from cpr_data_access.embedding import Embedder
from opensearchpy import OpenSearch
from opensearchpy import JSONSerializer as jss
from sentence_transformers import SentenceTransformer
from sqlalchemy.orm import Session

from app.api.api_v1.schemas.search import (
Expand Down Expand Up @@ -73,10 +73,10 @@

_LOGGER = logging.getLogger(__name__)

_ENCODER = SentenceTransformer(
model_name_or_path=OPENSEARCH_INDEX_ENCODER,
_ENCODER = Embedder(
cache_folder=os.environ.get("INDEX_ENCODER_CACHE_FOLDER", "/models"),
)

# Map a sort field type to the document key used by OpenSearch
_SORT_FIELD_MAP: Mapping[SortField, str] = {
SortField.DATE: "document_date",
Expand Down Expand Up @@ -410,7 +410,11 @@ def with_semantic_query(self, query_string: str, knn: bool):

_LOGGER.info(f"Starting embeddings generation for '{query_string}'")
start_generation = time.time_ns()
embedding = _ENCODER.encode(query_string, show_progress_bar=False)
embedding = _ENCODER.embed(
query_string,
normalize=False,
show_progress_bar=False,
)
end_generation = time.time_ns()
embeddings_generation_time = round((end_generation - start_generation) / 1e6)
_LOGGER.info(
Expand Down
Loading

0 comments on commit a7bbb5d

Please sign in to comment.