From 23b83a75b339e670b3e6a2446e4c14f2b2f8cad4 Mon Sep 17 00:00:00 2001 From: Kalyan Dutia Date: Tue, 17 Sep 2024 14:47:23 +0100 Subject: [PATCH] do query encoding on vespa (#104) * add embedding model to test_vespa services * add sentence-transformers specific model tweaks to services * change hybrid search to use vespa embedder * add experimental flag for new feature and test that it works * bump minor version --- src/cpr_sdk/models/search.py | 6 ++++++ src/cpr_sdk/version.py | 2 +- src/cpr_sdk/vespa.py | 12 ++++++++--- tests/local_vespa/test_app/services.xml | 6 ++++++ tests/test_search_adaptors.py | 28 +++++++++++++++++++++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/cpr_sdk/models/search.py b/src/cpr_sdk/models/search.py index 14b721a..3e0c89f 100644 --- a/src/cpr_sdk/models/search.py +++ b/src/cpr_sdk/models/search.py @@ -91,6 +91,12 @@ class SearchParameters(BaseModel): the search is performed. """ + experimental_encode_on_vespa: bool = False + """ + Backend change to encode the query string on the Vespa side rather than through + the SDK. + """ + all_results: bool = False """ Return all results rather than searching or ranking diff --git a/src/cpr_sdk/version.py b/src/cpr_sdk/version.py index 9393761..fd459d6 100644 --- a/src/cpr_sdk/version.py +++ b/src/cpr_sdk/version.py @@ -1,5 +1,5 @@ _MAJOR = "1" -_MINOR = "4" +_MINOR = "5" _PATCH = "4" _SUFFIX = "" diff --git a/src/cpr_sdk/vespa.py b/src/cpr_sdk/vespa.py index 7e3ba9e..87ab0d2 100644 --- a/src/cpr_sdk/vespa.py +++ b/src/cpr_sdk/vespa.py @@ -103,9 +103,15 @@ def build_vespa_request_body( vespa_request_body["ranking.profile"] = "hybrid_no_closeness" else: vespa_request_body["ranking.profile"] = "hybrid" - vespa_request_body["input.query(query_embedding)"] = embedder.embed( - parameters.query_string, normalize=False, show_progress_bar=False - ) + if parameters.experimental_encode_on_vespa: + vespa_request_body[ + "input.query(query_embedding)" + ] = "embed(msmarco-distilbert-dot-v5, @query_string)" + else: + vespa_request_body["input.query(query_embedding)"] = embedder.embed( + parameters.query_string, normalize=False, show_progress_bar=False + ) + return vespa_request_body diff --git a/tests/local_vespa/test_app/services.xml b/tests/local_vespa/test_app/services.xml index c2a5dff..38aac5c 100644 --- a/tests/local_vespa/test_app/services.xml +++ b/tests/local_vespa/test_app/services.xml @@ -3,6 +3,12 @@ + + + + + token_embeddings + diff --git a/tests/test_search_adaptors.py b/tests/test_search_adaptors.py index 5524881..f2411d6 100644 --- a/tests/test_search_adaptors.py +++ b/tests/test_search_adaptors.py @@ -145,6 +145,34 @@ def test_vespa_search_adaptor__hybrid(test_vespa): assert family_name in got_family_names +@pytest.mark.vespa +def test_vespa_search_adaptor__hybrid_encoding_on_vespa(test_vespa): + family_name = "Climate Change Adaptation and Low Emissions Growth Strategy by 2035" + requests = { + "encode_in_sdk": SearchParameters(query_string=family_name), + "encode_on_vespa": SearchParameters( + query_string=family_name, experimental_encode_on_vespa=True + ), + } + responses = dict() + + for request_name, request in requests.items(): + responses[request_name] = vespa_search(test_vespa, request) + + # Was the family searched for in the results. + # Note that this is a fairly loose test + got_family_names = [] + for fam in responses[request_name].families: + for doc in fam.hits: + got_family_names.append(doc.family_name) + assert family_name in got_family_names + + assert ( + responses["encode_in_sdk"].total_hits == responses["encode_on_vespa"].total_hits + ) + assert responses["encode_in_sdk"].families == responses["encode_on_vespa"].families + + @pytest.mark.vespa def test_vespa_search_adaptor__all(test_vespa): request = SearchParameters(query_string="", all_results=True)