diff --git a/CHANGELOG.md b/CHANGELOG.md index 72adb6fbf..dc304aebf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x) ### Features +* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1699](https://github.com/opensearch-project/k-NN/pull/1699) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index b6c20455e..84e986faa 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -143,20 +143,12 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo */ public static float cosinesimil(float[] queryVector, float[] inputVector) { requireEqualDimension(queryVector, inputVector); - float dotProduct = 0.0f; - float normQueryVector = 0.0f; - float normInputVector = 0.0f; - for (int i = 0; i < queryVector.length; i++) { - dotProduct += queryVector[i] * inputVector[i]; - normQueryVector += queryVector[i] * queryVector[i]; - normInputVector += inputVector[i] * inputVector[i]; - } - float normalizedProduct = normQueryVector * normInputVector; - if (normalizedProduct == 0) { + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); return 0.0f; } - return (float) (dotProduct / (Math.sqrt(normalizedProduct))); } /** @@ -212,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { * @return L1 score */ public static float l1Norm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -250,7 +241,6 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @return L-inf score */ public static float lInfNorm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6b40f375c..3cfbe56f1 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -47,7 +47,7 @@ public void testL2() { public void testCosineSimilarity() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; - List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); @@ -59,7 +59,7 @@ public void testCosineSimilarity() { ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); - assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); + assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); // invalid zero vector final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f);