Skip to content

Commit

Permalink
do query encoding on vespa (#104)
Browse files Browse the repository at this point in the history
* add embedding model to test_vespa services

* add sentence-transformers specific model tweaks to services

* change hybrid search to use vespa embedder

* add experimental flag for new feature and test that it works

* bump minor version
  • Loading branch information
kdutia authored Sep 17, 2024
1 parent 2d67102 commit 23b83a7
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/cpr_sdk/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class SearchParameters(BaseModel):
the search is performed.
"""

experimental_encode_on_vespa: bool = False
"""
Backend change to encode the query string on the Vespa side rather than through
the SDK.
"""

all_results: bool = False
"""
Return all results rather than searching or ranking
Expand Down
2 changes: 1 addition & 1 deletion src/cpr_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_MAJOR = "1"
_MINOR = "4"
_MINOR = "5"
_PATCH = "4"
_SUFFIX = ""

Expand Down
12 changes: 9 additions & 3 deletions src/cpr_sdk/vespa.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,15 @@ def build_vespa_request_body(
vespa_request_body["ranking.profile"] = "hybrid_no_closeness"
else:
vespa_request_body["ranking.profile"] = "hybrid"
vespa_request_body["input.query(query_embedding)"] = embedder.embed(
parameters.query_string, normalize=False, show_progress_bar=False
)
if parameters.experimental_encode_on_vespa:
vespa_request_body[
"input.query(query_embedding)"
] = "embed(msmarco-distilbert-dot-v5, @query_string)"
else:
vespa_request_body["input.query(query_embedding)"] = embedder.embed(
parameters.query_string, normalize=False, show_progress_bar=False
)

return vespa_request_body


Expand Down
6 changes: 6 additions & 0 deletions tests/local_vespa/test_app/services.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
<services version="1.0" xmlns:deploy="vespa" xmlns:preprocess="properties">

<container id="default" version="1.0">
<component id="msmarco-distilbert-dot-v5" type="hugging-face-embedder">
<transformer-model url="https://huggingface.co/onnx-models/msmarco-distilbert-dot-v5-onnx/resolve/main/model.onnx"/>
<tokenizer-model url="https://huggingface.co/onnx-models/msmarco-distilbert-dot-v5-onnx/resolve/main/tokenizer.json"/>
<transformer-token-type-ids/>
<transformer-output>token_embeddings</transformer-output>
</component>
<document-api/>
<search/>
<nodes deploy:environment="dev" count="1">
Expand Down
28 changes: 28 additions & 0 deletions tests/test_search_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,34 @@ def test_vespa_search_adaptor__hybrid(test_vespa):
assert family_name in got_family_names


@pytest.mark.vespa
def test_vespa_search_adaptor__hybrid_encoding_on_vespa(test_vespa):
    """
    Compare hybrid search with SDK-side and Vespa-side query encoding.

    The same query string is searched twice: once with default parameters and
    once with ``experimental_encode_on_vespa=True``. Each response must contain
    the searched-for family, and both responses must agree on ``total_hits``
    and ``families``.
    """
    family_name = "Climate Change Adaptation and Low Emissions Growth Strategy by 2035"
    sdk_params = SearchParameters(query_string=family_name)
    vespa_params = SearchParameters(
        query_string=family_name, experimental_encode_on_vespa=True
    )

    responses = {}
    for label, params in (
        ("encode_in_sdk", sdk_params),
        ("encode_on_vespa", vespa_params),
    ):
        response = vespa_search(test_vespa, params)
        responses[label] = response

        # Loose sanity check: the family we searched for by name should appear
        # among the hits of every response, whichever side did the encoding.
        returned_family_names = [
            hit.family_name for family in response.families for hit in family.hits
        ]
        assert family_name in returned_family_names

    sdk_response = responses["encode_in_sdk"]
    vespa_response = responses["encode_on_vespa"]

    # Both encoding paths are expected to yield the same result set.
    assert sdk_response.total_hits == vespa_response.total_hits
    assert sdk_response.families == vespa_response.families


@pytest.mark.vespa
def test_vespa_search_adaptor__all(test_vespa):
request = SearchParameters(query_string="", all_results=True)
Expand Down

0 comments on commit 23b83a7

Please sign in to comment.