Skip to content

Commit

Permalink
Use the Lucene Distance Calculation Function in Script Scoring for do…
Browse files Browse the repository at this point in the history
…ing exact search (#1699)

* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search

Signed-off-by: Ryan Bogan <[email protected]>

* Add Changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Fix failing test

Signed-off-by: Ryan Bogan <[email protected]>

* fix test

Signed-off-by: Ryan Bogan <[email protected]>

* Fix test bug and remove unnecessary validation

Signed-off-by: Ryan Bogan <[email protected]>

* Remove cosineSimilOptimized

Signed-off-by: Ryan Bogan <[email protected]>

* Revert "Remove cosineSimilOptimized"

This reverts commit f872d83.

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
(cherry picked from commit 7a88f40)
  • Loading branch information
ryanbogan authored and github-actions[bot] committed May 24, 2024
1 parent 40bf388 commit 7933bb5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x)
### Features
* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1699](https://github.com/opensearch-project/k-NN/pull/1699)
### Enhancements
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
Expand Down
32 changes: 6 additions & 26 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -48,13 +49,7 @@ private static void requireEqualDimension(final float[] queryVector, final float
* @return L2 score
*/
public static float l2Squared(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float squaredDistance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
squaredDistance += diff * diff;
}
return squaredDistance;
return VectorUtil.squareDistance(queryVector, inputVector);
}

private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
Expand Down Expand Up @@ -148,20 +143,12 @@ public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDo
*/
public static float cosinesimil(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float dotProduct = 0.0f;
float normQueryVector = 0.0f;
float normInputVector = 0.0f;
for (int i = 0; i < queryVector.length; i++) {
dotProduct += queryVector[i] * inputVector[i];
normQueryVector += queryVector[i] * queryVector[i];
normInputVector += inputVector[i] * inputVector[i];
}
float normalizedProduct = normQueryVector * normInputVector;
if (normalizedProduct == 0) {
try {
return VectorUtil.cosine(queryVector, inputVector);
} catch (IllegalArgumentException | AssertionError e) {
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
return 0.0f;
}
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
}

/**
Expand Down Expand Up @@ -217,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) {
* @return L1 score
*/
public static float l1Norm(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
Expand Down Expand Up @@ -255,7 +241,6 @@ public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues do
* @return L-inf score
*/
public static float lInfNorm(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
Expand Down Expand Up @@ -293,12 +278,7 @@ public static float lInfNorm(List<Number> queryVector, KNNVectorScriptDocValues
* @return dot product score
*/
public static float innerProduct(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
distance += queryVector[i] * inputVector[i];
}
return distance;
return VectorUtil.dotProduct(queryVector, inputVector);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void testL2() {

public void testCosineSimilarity() {
float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f };
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0));
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0));
float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f };
KNNMethodContext knnMethodContext = KNNMethodContext.getDefault();

Expand All @@ -59,7 +59,7 @@ public void testCosineSimilarity() {
);
KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType);

assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);
assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);

// invalid zero vector
final List<Float> queryZeroVector = List.of(0.0f, 0.0f, 0.0f);
Expand Down

0 comments on commit 7933bb5

Please sign in to comment.