Skip to content

Commit

Permalink
Remove cosineSimilOptimized
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed May 16, 2024
1 parent 540782c commit f872d83
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Map;
import java.util.function.BiFunction;

import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType;
Expand Down Expand Up @@ -102,8 +101,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery);
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimil(q, v);
}

public ScoreScript getScoreScript(
Expand Down
48 changes: 0 additions & 48 deletions src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,54 +86,6 @@ public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues
return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue());
}

/**
* This method can be used script to avoid repeated calculation of normalization
* for query vector for each filtered documents
*
* @param queryVector query vector
* @param inputVector input vector
* @param normQueryVector normalized query vector value.
* @return cosine score
*/
public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
requireEqualDimension(queryVector, inputVector);
float dotProduct = 0.0f;
float normInputVector = 0.0f;
for (int i = 0; i < queryVector.length; i++) {
dotProduct += queryVector[i] * inputVector[i];
normInputVector += inputVector[i] * inputVector[i];
}
float normalizedProduct = normQueryVector * normInputVector;
if (normalizedProduct == 0) {
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
return 0.0f;
}
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
}

/**
* Allowlisted cosineSimilarity method that can be used in a script to avoid repeated
* calculation of normalization for the query vector.
* Example:
* "script": {
* "source": "cosineSimilarity(params.query_vector, docs[field], 1.0) ",
* "params": {
* "query_vector": [1, 2, 3.4],
* "field": "my_dense_vector"
* }
* }
*
* @param queryVector query vector
* @param docValues script doc values
* @param queryVectorMagnitude the magnitude of the query vector.
* @return cosine score
*/
public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) {
float[] inputVector = toFloat(queryVector, docValues.getVectorDataType());
SpaceType.COSINESIMIL.validateVector(inputVector);
return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue());
}

/**
* This method calculates cosine similarity
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ static_import {
float l1Norm(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil
float innerProduct(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues, Number) from_class org.opensearch.knn.plugin.script.KNNScoringUtil
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,6 @@ public void testCosineSimilScoringFunction() {
assertEquals(expectedScore, actualScore, 0.0001);
}

public void testCosineSimilOptimizedScoringFunction() {
float[] queryVector = { 1.0f, 1.0f, 1.0f };
float[] inputVector = { 4.0f, 4.0f, 4.0f };
float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector);
float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector);
float dotProduct = 12.0f;
float expectedScore = (float) (dotProduct / (Math.sqrt(queryVectorMagnitude * inputVectorMagnitude)));

Float actualScore = KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, queryVectorMagnitude);
assertEquals(expectedScore, actualScore, 0.0001);
}

public void testGetInvalidVectorMagnitudeSquared() {
float[] queryVector = null;
// vector cannot be null
Expand All @@ -92,24 +80,12 @@ public void testCosineSimilQueryVectorZeroMagnitude() {
assertEquals(0, KNNScoringUtil.cosinesimil(queryVector, inputVector), 0.00001);
}

public void testCosineSimilOptimizedQueryVectorZeroMagnitude() {
float[] inputVector = { 4.0f, 4.0f };
float[] queryVector = { 0, 0 };
assertTrue(0 == KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 0.0f));
}

public void testWrongDimensionCosineSimilScoringFunction() {
float[] queryVector = { 1.0f, 1.0f };
float[] inputVector = { 4.0f, 4.0f, 4.0f };
expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimil(queryVector, inputVector));
}

public void testWrongDimensionCosineSimilOPtimizedScoringFunction() {
float[] queryVector = { 1.0f, 1.0f };
float[] inputVector = { 4.0f, 4.0f, 4.0f };
expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 1.0f));
}

public void testBitHammingDistance_BitSet() {
BigInteger bigInteger1 = new BigInteger("4", 16);
BigInteger bigInteger2 = new BigInteger("32278", 16);
Expand Down Expand Up @@ -230,17 +206,6 @@ public void testZeroVectorFailsCosineSimilarity() throws IOException {
dataset.close();
}

public void testCosineSimilarityOptimizedScoringFunction() throws IOException {
List<Number> queryVector = getTestQueryVector();
TestKNNScriptDocValues dataset = new TestKNNScriptDocValues();
dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name");
KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name");
scriptDocValues.setNextDocId(0);
Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f);
assertEquals(1.0f, actualScore, 0.0001);
dataset.close();
}

public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOException {
List<Number> queryVector = getTestQueryVector();
TestKNNScriptDocValues dataset = new TestKNNScriptDocValues();
Expand All @@ -249,24 +214,6 @@ public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOExcepti
dataset.close();
}

public void testZeroVectorFailsCosineSimilarityOptimized() throws IOException {
List<Number> queryVector = getTestZeroVector();
TestKNNScriptDocValues dataset = new TestKNNScriptDocValues();
dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name");
KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name");
scriptDocValues.setNextDocId(0);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)
);
assertEquals(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()),
exception.getMessage()
);
dataset.close();
}

class TestKNNScriptDocValues {
private KNNVectorScriptDocValues scriptDocValues;
private Directory directory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,15 @@ public void testCosineSimilarityScriptScoreWithNumericField() throws Exception {

// test fails without size check before executing method
public void testCosineSimilarityNormalizedScriptScoreFails() throws Exception {
String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", FIELD_NAME);
String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME);
Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData());
addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);
expectThrows(ResponseException.class, () -> client().performRequest(request));
deleteKNNIndex(INDEX_NAME);
}

public void testCosineSimilarityNormalizedScriptScore() throws Exception {
String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", FIELD_NAME);
String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME);
Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData());
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
Expand All @@ -366,11 +366,7 @@ public void testCosineSimilarityNormalizedScriptScore() throws Exception {
}

public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws Exception {
String source = String.format(
"doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)",
FIELD_NAME,
FIELD_NAME
);
String source = String.format("doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME, FIELD_NAME);
Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData());
addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);
Response response = client().performRequest(request);
Expand Down

0 comments on commit f872d83

Please sign in to comment.