diff --git a/lib/searchkick/query.rb b/lib/searchkick/query.rb index 22d782c6..ef9caccc 100644 --- a/lib/searchkick/query.rb +++ b/lib/searchkick/query.rb @@ -532,23 +532,20 @@ def prepare raise ArgumentError, "Hybrid search not supported yet" end - if options[:where] - raise ArgumentError, "KNN search with where not supported yet" - end - if options[:knn].size != 1 raise ArgumentError, "Invalid knn option" end k = per_page + offset + filter = payload.delete(:query) if Searchkick.opensearch? - payload[:query].delete(:match_all) - payload[:query][:knn] = {} + payload[:query] = {knn: {}} options[:knn].each do |field, vector| payload[:query][:knn][field.to_sym] = { vector: vector, - k: k + k: k, + filter: filter } end else @@ -556,7 +553,8 @@ def prepare payload[:knn] = { field: field, k: k, - query_vector: vector + query_vector: vector, + filter: filter } end end diff --git a/test/knn_test.rb b/test/knn_test.rb index 55afc9e5..6956d4bb 100644 --- a/test/knn_test.rb +++ b/test/knn_test.rb @@ -6,13 +6,17 @@ def setup super end - def test_works + def test_basic store [{name: "A", embedding: [1, 2, 3]}, {name: "B", embedding: [-1, -2, -3]}] assert_order "*", ["A", "B"], knn: {embedding: [1, 2, 3]} - expected = Searchkick.opensearch? ? [1, 0] : [2, 1] scores = Product.search(knn: {embedding: [1, 2, 3]}).hits.map { |v| v["_score"] } - assert_in_delta expected[0], scores[0] - assert_in_delta expected[1], scores[1] + assert_in_delta 1, scores[0] + assert_in_delta 0, scores[1] + end + + def test_where + store [{name: "A", store_id: 1, embedding: [1, 2, 3]}, {name: "B", store_id: 2, embedding: [-1, -2, -3]}] + assert_order "*", ["A"], knn: {embedding: [1, 2, 3]}, where: {store_id: 1} end end