From de6084ff2040ecc475405d203c35cb7903968f36 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 17 Jul 2024 16:55:02 -0700 Subject: [PATCH] Add painless script support for hamming with binary vector data type (#1839) Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + .../knn/plugin/script/KNNScoringUtil.java | 315 ++++++++++++------ .../knn/plugin/script/knn_allowlist.txt | 1 + .../knn/common/KNNValidationUtilTests.java | 2 +- .../plugin/script/KNNScoringUtilTests.java | 49 +++ 5 files changed, 263 insertions(+), 105 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c0f1841..9091f17d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Adds dynamic query parameter nprobes [#1792](https://github.com/opensearch-project/k-NN/pull/1792) * Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781) * Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826) +* Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839) ### Enhancements * Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 8493ea5bd..f61ae4349 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -7,6 +7,7 @@ import java.math.BigInteger; import java.util.List; +import java.util.Locale; import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -40,6 +41,52 @@ private static void requireEqualDimension(final float[] queryVector, 
final float } } + /** + * Checks that the query vector and the input vector have the same dimension + * + * @param queryVector query vector + * @param inputVector input vector + * @throws IllegalArgumentException if the query vector and the input vector have different dimensions + */ + private static void requireEqualDimension(final byte[] queryVector, final byte[] inputVector) { + Objects.requireNonNull(queryVector); + Objects.requireNonNull(inputVector); + if (queryVector.length != inputVector.length) { + String errorMessage = String.format( + "query vector dimension mismatch. Expected: %d, Given: %d", + inputVector.length, + queryVector.length + ); + throw new IllegalArgumentException(errorMessage); + } + } + + private static void requireNonBinaryType(final String spaceName, final VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Incompatible field_type for %s space. The data type should be either float or byte but got %s", + spaceName, + vectorDataType.getValue() + ) + ); + } + } + + private static void requireBinaryType(final String spaceName, final VectorDataType vectorDataType) { + if (VectorDataType.BINARY != vectorDataType) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Incompatible field_type for %s space. 
The data type should be binary but got %s", + spaceName, + vectorDataType.getValue() + ) + ); + } + } + /** * This method calculates L2 squared distance between query vector * and input vector @@ -52,13 +99,13 @@ public static float l2Squared(float[] queryVector, float[] inputVector) { return VectorUtil.squareDistance(queryVector, inputVector); } - private static float[] toFloat(List inputVector, VectorDataType vectorDataType) { + private static float[] toFloat(final List inputVector, final VectorDataType vectorDataType) { Objects.requireNonNull(inputVector); float[] value = new float[inputVector.size()]; int index = 0; for (final Number val : inputVector) { float floatValue = val.floatValue(); - if (VectorDataType.BYTE == vectorDataType) { + if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) { validateByteVectorValue(floatValue, vectorDataType); } value[index++] = floatValue; @@ -66,24 +113,35 @@ private static float[] toFloat(List inputVector, VectorDataType vectorDa return value; } + private static byte[] toByte(final List inputVector, final VectorDataType vectorDataType) { + Objects.requireNonNull(inputVector); + byte[] value = new byte[inputVector.size()]; + int index = 0; + for (final Number val : inputVector) { + float floatValue = val.floatValue(); + if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) { + validateByteVectorValue(floatValue, vectorDataType); + } + value[index++] = val.byteValue(); + } + return value; + } + /** - * Allowlisted l2Squared method for users to calculate L2 squared distance between query vector - * and document vectors - * Example - * "script": { - * "source": "1/(1 + l2Squared(params.query_vector, doc[params.field]))", - * "params": { - * "query_vector": [1, 2, 3.4], - * "field": "my_dense_vector" - * } - * } + * This method calculates cosine similarity * * @param queryVector query vector - * @param docValues script doc values - * @return L2 score + * @param 
inputVector input vector + * @return cosine score */ - public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { - return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float cosinesimil(float[] queryVector, float[] inputVector) { + requireEqualDimension(queryVector, inputVector); + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { + logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); + return 0.0f; + } } /** @@ -111,68 +169,6 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto return (float) (dotProduct / (Math.sqrt(normalizedProduct))); } - /** - * Allowlisted cosineSimilarity method that can be used in a script to avoid repeated - * calculation of normalization for the query vector. - * Example: - * "script": { - * "source": "cosineSimilarity(params.query_vector, docs[field], 1.0) ", - * "params": { - * "query_vector": [1, 2, 3.4], - * "field": "my_dense_vector" - * } - * } - * - * @param queryVector query vector - * @param docValues script doc values - * @param queryVectorMagnitude the magnitude of the query vector. 
- * @return cosine score - */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); - } - - /** - * This method calculates cosine similarity - * - * @param queryVector query vector - * @param inputVector input vector - * @return cosine score - */ - public static float cosinesimil(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); - try { - return VectorUtil.cosine(queryVector, inputVector); - } catch (IllegalArgumentException | AssertionError e) { - logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); - return 0.0f; - } - } - - /** - * Allowlisted cosineSimilarity method for users to calculate cosine similarity between query vectors and - * document vectors - * Example: - * "script": { - * "source": "cosineSimilarity(params.query_vector, docs[field]) ", - * "params": { - * "query_vector": [1, 2, 3.4], - * "field": "my_dense_vector" - * } - * } - * - * @param queryVector query vector - * @param docValues script doc values - * @return cosine score - */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimil(inputVector, docValues.getValue()); - } - /** * This method calculates hamming distance on 2 BigIntegers * @@ -204,6 +200,7 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { * @return hamming distance */ public static float calculateHammingBit(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); return 
VectorUtil.xorBitCount(queryVector, inputVector); } @@ -216,6 +213,7 @@ public static float calculateHammingBit(byte[] queryVector, byte[] inputVector) * @return L1 score */ public static float l1Norm(float[] queryVector, float[] inputVector) { + requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -224,26 +222,6 @@ public static float l1Norm(float[] queryVector, float[] inputVector) { return distance; } - /** - * Allowlisted l1distance method for users to calculate L1 distance between query vector - * and document vectors - * Example - * "script": { - * "source": "1/(1 + l1Norm(params.query_vector, doc[params.field]))", - * "params": { - * "query_vector": [1, 2, 3.4], - * "field": "my_dense_vector" - * } - * } - * - * @param queryVector query vector - * @param docValues script doc values - * @return L1 score - */ - public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { - return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); - } - /** * This method calculates L-inf distance between query vector * and input vector @@ -253,6 +231,7 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @return L-inf score */ public static float lInfNorm(float[] queryVector, float[] inputVector) { + requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -261,6 +240,46 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { return distance; } + /** + * This method calculates dot product distance between query vector + * and input vector + * + * @param queryVector query vector + * @param inputVector input vector + * @return dot product score + */ + public static float innerProduct(float[] queryVector, float[] inputVector) { + requireEqualDimension(queryVector, 
inputVector); + return VectorUtil.dotProduct(queryVector, inputVector); + } + + /** + ********************************************************************************************* + * Functions to be used in painless script which is defined in knn_allowlist.txt + ********************************************************************************************* + */ + + /** + * Allowlisted l2Squared method for users to calculate L2 squared distance between query vector + * and document vectors + * Example + * "script": { + * "source": "1/(1 + l2Squared(params.query_vector, doc[params.field]))", + * "params": { + * "query_vector": [1, 2, 3.4], + * "field": "my_dense_vector" + * } + * } + * + * @param queryVector query vector + * @param docValues script doc values + * @return L2 score + */ + public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { + requireNonBinaryType("l2Squared", docValues.getVectorDataType()); + return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + } + /** * Allowlisted lInfNorm method for users to calculate L-inf distance between query vector * and document vectors @@ -278,19 +297,29 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { * @return L-inf score */ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { + requireNonBinaryType("lInfNorm", docValues.getVectorDataType()); return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** - * This method calculates dot product distance between query vector - * and input vector + * Allowlisted l1distance method for users to calculate L1 distance between query vector + * and document vectors + * Example + * "script": { + * "source": "1/(1 + l1Norm(params.query_vector, doc[params.field]))", + * "params": { + * "query_vector": [1, 2, 3.4], + * "field": "my_dense_vector" + * } + * } * * @param queryVector query vector - * @param inputVector input 
vector - * @return dot product score + * @param docValues script doc values + * @return L1 score */ - public static float innerProduct(float[] queryVector, float[] inputVector) { - return VectorUtil.dotProduct(queryVector, inputVector); + public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { + requireNonBinaryType("l1Norm", docValues.getVectorDataType()); + return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } /** @@ -310,6 +339,84 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { * @return inner product score */ public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { + requireNonBinaryType("innerProduct", docValues.getVectorDataType()); return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); } + + /** + * Allowlisted cosineSimilarity method for users to calculate cosine similarity between query vectors and + * document vectors + * Example: + * "script": { + * "source": "cosineSimilarity(params.query_vector, docs[field]) ", + * "params": { + * "query_vector": [1, 2, 3.4], + * "field": "my_dense_vector" + * } + * } + * + * @param queryVector query vector + * @param docValues script doc values + * @return cosine score + */ + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { + requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, docValues.getValue()); + } + + /** + * Allowlisted cosineSimilarity method that can be used in a script to avoid repeated + * calculation of normalization for the query vector. 
+ * Example: + * "script": { + * "source": "cosineSimilarity(params.query_vector, docs[field], 1.0) ", + * "params": { + * "query_vector": [1, 2, 3.4], + * "field": "my_dense_vector" + * } + * } + * + * @param queryVector query vector + * @param docValues script doc values + * @param queryVectorMagnitude the magnitude of the query vector. + * @return cosine score + */ + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { + requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); + } + + /** + * Allowlisted hamming method for users to calculate hamming distance between query vector + * and document vectors of binary data type. + * Example: + * "script": { + * "source": "hamming(params.query_vector, docs[field]) ", + * "params": { + * "query_vector": [1, 2], + * "field": "my_dense_vector" + * } + * } + * + * @param queryVector query vector + * @param docValues script doc values + * @return hamming score + */ + public static float hamming(List queryVector, KNNVectorScriptDocValues docValues) { + requireBinaryType("hamming", docValues.getVectorDataType()); + byte[] queryVectorInByte = toByte(queryVector, docValues.getVectorDataType()); + + // TODO Optimization needs to be done for doc value to return byte[] instead of float[] + float[] docVectorInFloat = docValues.getValue(); + byte[] docVectorInByte = new byte[docVectorInFloat.length]; + for (int i = 0; i < docVectorInByte.length; i++) { + docVectorInByte[i] = (byte) docVectorInFloat[i]; + } + + return calculateHammingBit(queryVectorInByte, docVectorInByte); + } } diff --git a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt 
index 6b6e6434e..388cdda8a 100644 --- a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt +++ b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt @@ -13,4 +13,5 @@ static_import { float innerProduct(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil float cosineSimilarity(List, org.opensearch.knn.index.KNNVectorScriptDocValues, Number) from_class org.opensearch.knn.plugin.script.KNNScoringUtil + float hamming(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil } diff --git a/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java b/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java index 56e462fc1..4b4337880 100644 --- a/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java +++ b/src/test/java/org/opensearch/knn/common/KNNValidationUtilTests.java @@ -12,7 +12,7 @@ public class KNNValidationUtilTests extends KNNTestCase { public void testValidateVectorDimension_whenBinary_thenVectorSizeShouldBeEightTimesLarger() { - int vectorLength = randomInt(100); + int vectorLength = randomInt(100) + 1; Exception ex = expectThrows( IllegalArgumentException.class, () -> KNNValidationUtil.validateVectorDimension(vectorLength, vectorLength, VectorDataType.BINARY) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index a8d37b6c5..2cc20c8f9 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import java.util.Arrays; import java.util.Locale; import 
org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; @@ -25,6 +26,10 @@ import java.io.IOException; import java.math.BigInteger; import java.util.List; +import java.util.function.BiFunction; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class KNNScoringUtilTests extends KNNTestCase { @@ -273,6 +278,50 @@ public void testCalculateHammingBit_whenByte_thenSuccess() { assertEquals(10, KNNScoringUtil.calculateHammingBit(v1, v2), 0.001f); } + private void validateThrowExceptionOnGivenDataType( + final BiFunction, KNNVectorScriptDocValues, Float> func, + final VectorDataType dataType, + final String errorMsg + ) { + List queryVector = Arrays.asList(1, 2); + KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); + when(docValues.getVectorDataType()).thenReturn(dataType); + Exception e = expectThrows(IllegalArgumentException.class, () -> func.apply(queryVector, docValues)); + assertTrue(e.getMessage().contains(errorMsg)); + } + + public void testLInfNorm_whenKNNVectorScriptDocValuesOfBinary_thenThrowException() { + validateThrowExceptionOnGivenDataType(KNNScoringUtil::lInfNorm, VectorDataType.BINARY, "should be either float or byte"); + } + + public void testL1Norm_whenKNNVectorScriptDocValuesOfBinary_thenThrowException() { + validateThrowExceptionOnGivenDataType(KNNScoringUtil::l1Norm, VectorDataType.BINARY, "should be either float or byte"); + } + + public void testInnerProduct_whenKNNVectorScriptDocValuesOfBinary_thenThrowException() { + validateThrowExceptionOnGivenDataType(KNNScoringUtil::innerProduct, VectorDataType.BINARY, "should be either float or byte"); + } + + public void testCosineSimilarity_whenKNNVectorScriptDocValuesOfBinary_thenThrowException() { + validateThrowExceptionOnGivenDataType(KNNScoringUtil::cosineSimilarity, VectorDataType.BINARY, "should be either float or byte"); + } + + public void 
testHamming_whenKNNVectorScriptDocValuesOfNonBinary_thenThrowException() { + validateThrowExceptionOnGivenDataType(KNNScoringUtil::hamming, VectorDataType.FLOAT, "should be binary"); + } + + public void testHamming_whenKNNVectorScriptDocValuesOfBinary_thenSuccess() { + byte[] b1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 + byte[] b2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 + float[] f1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 + float[] f2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 + List queryVector = Arrays.asList(f1[0], f1[1], f1[2]); + KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); + when(docValues.getVectorDataType()).thenReturn(VectorDataType.BINARY); + when(docValues.getValue()).thenReturn(f2); + assertEquals(KNNScoringUtil.calculateHammingBit(b1, b2), KNNScoringUtil.hamming(queryVector, docValues), 0.01f); + } + class TestKNNScriptDocValues { private KNNVectorScriptDocValues scriptDocValues; private Directory directory;