From f872d8389683186c9ff64f6a65fd77f170f4a47d Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 16 May 2024 12:47:50 -0700 Subject: [PATCH] Remove cosineSimilOptimized Signed-off-by: Ryan Bogan --- .../knn/plugin/script/KNNScoringSpace.java | 4 +- .../knn/plugin/script/KNNScoringUtil.java | 48 ----------------- .../knn/plugin/script/knn_allowlist.txt | 1 - .../plugin/script/KNNScoringUtilTests.java | 53 ------------------- .../knn/plugin/script/PainlessScriptIT.java | 10 ++-- 5 files changed, 4 insertions(+), 112 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 8105539ba..813a5f5d7 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -20,7 +20,6 @@ import java.util.Map; import java.util.function.BiFunction; -import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType; @@ -102,8 +101,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); SpaceType.COSINESIMIL.validateVector(processedQuery); - float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); - this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); + this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimil(q, v); } public ScoreScript getScoreScript( diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 84e986faa..b92b4a8bf 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -86,54 +86,6 @@ public static float l2Squared(List queryVector, KNNVectorScriptDocValues return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } - /** - * This method can be used script to avoid repeated calculation of normalization - * for query vector for each filtered documents - * - * @param queryVector query vector - * @param inputVector input vector - * @param normQueryVector normalized query vector value. - * @return cosine score - */ - public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) { - requireEqualDimension(queryVector, inputVector); - float dotProduct = 0.0f; - float normInputVector = 0.0f; - for (int i = 0; i < queryVector.length; i++) { - dotProduct += queryVector[i] * inputVector[i]; - normInputVector += inputVector[i] * inputVector[i]; - } - float normalizedProduct = normQueryVector * normInputVector; - if (normalizedProduct == 0) { - logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); - return 0.0f; - } - return (float) (dotProduct / (Math.sqrt(normalizedProduct))); - } - - /** - * Allowlisted cosineSimilarity method that can be used in a script to avoid repeated - * calculation of normalization for the query vector. - * Example: - * "script": { - * "source": "cosineSimilarity(params.query_vector, docs[field], 1.0) ", - * "params": { - * "query_vector": [1, 2, 3.4], - * "field": "my_dense_vector" - * } - * } - * - * @param queryVector query vector - * @param docValues script doc values - * @param queryVectorMagnitude the magnitude of the query vector. - * @return cosine score - */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); - } - /** * This method calculates cosine similarity * diff --git a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt index 6b6e6434e..d8fae4739 100644 --- a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt +++ b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt @@ -12,5 +12,4 @@ static_import { float l1Norm(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil float innerProduct(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil - float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues, Number) from_class org.opensearch.knn.plugin.script.KNNScoringUtil } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf..1a8112d49 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -63,18 +63,6 @@ public void testCosineSimilScoringFunction() { assertEquals(expectedScore, actualScore, 0.0001); } - public void testCosineSimilOptimizedScoringFunction() { - float[] queryVector = { 1.0f, 1.0f, 1.0f }; - float[] inputVector = { 4.0f, 4.0f, 4.0f }; - float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); - float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); - float dotProduct = 12.0f; - float expectedScore = (float) (dotProduct / (Math.sqrt(queryVectorMagnitude * inputVectorMagnitude))); - - Float actualScore = KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, queryVectorMagnitude); - assertEquals(expectedScore, actualScore, 0.0001); - } - public void testGetInvalidVectorMagnitudeSquared() { float[] queryVector = null; // vector cannot be null @@ -92,24 +80,12 @@ public void testCosineSimilQueryVectorZeroMagnitude() { assertEquals(0, KNNScoringUtil.cosinesimil(queryVector, inputVector), 0.00001); } - public void testCosineSimilOptimizedQueryVectorZeroMagnitude() { - float[] inputVector = { 4.0f, 4.0f }; - float[] queryVector = { 0, 0 }; - assertTrue(0 == KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 0.0f)); - } - public void testWrongDimensionCosineSimilScoringFunction() { float[] queryVector = { 1.0f, 1.0f }; float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimil(queryVector, inputVector)); } - public void testWrongDimensionCosineSimilOPtimizedScoringFunction() { - float[] queryVector = { 1.0f, 1.0f }; - float[] inputVector = { 4.0f, 4.0f, 4.0f }; - expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 1.0f)); - } - public void testBitHammingDistance_BitSet() { BigInteger bigInteger1 = new BigInteger("4", 16); BigInteger bigInteger2 = new BigInteger("32278", 16); @@ -230,17 +206,6 @@ public void testZeroVectorFailsCosineSimilarity() throws IOException { dataset.close(); } - public void testCosineSimilarityOptimizedScoringFunction() throws IOException { - List queryVector = getTestQueryVector(); - TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); - scriptDocValues.setNextDocId(0); - Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f); - assertEquals(1.0f, actualScore, 0.0001); - dataset.close(); - } - public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); @@ -249,24 +214,6 @@ public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOExcepti dataset.close(); } - public void testZeroVectorFailsCosineSimilarityOptimized() throws IOException { - List queryVector = getTestZeroVector(); - TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); - scriptDocValues.setNextDocId(0); - - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f) - ); - assertEquals( - String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), - exception.getMessage() - ); - dataset.close(); - } - class TestKNNScriptDocValues { private KNNVectorScriptDocValues scriptDocValues; private Directory directory; diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 0315c47c5..58f73c4a3 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -342,7 +342,7 @@ public void testCosineSimilarityScriptScoreWithNumericField() throws Exception { // test fails without size check before executing method public void testCosineSimilarityNormalizedScriptScoreFails() throws Exception { - String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", FIELD_NAME); + String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); expectThrows(ResponseException.class, () -> client().performRequest(request)); @@ -350,7 +350,7 @@ public void testCosineSimilarityNormalizedScriptScoreFails() throws Exception { } public void testCosineSimilarityNormalizedScriptScore() throws Exception { - String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", FIELD_NAME); + String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -366,11 +366,7 @@ public void testCosineSimilarityNormalizedScriptScore() throws Exception { } public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", - FIELD_NAME, - FIELD_NAME - ); + String source = String.format("doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME, FIELD_NAME); Request request = buildPainlessScoreScriptRequest(source, 3, getCosineTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request);