diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java new file mode 100644 index 0000000000..a894d8ed63 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.CompressionLevel; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; + +/** + * Abstract class for Faiss PQ encoders. This class provides the common logic for product quantization based encoders. + */ +public abstract class AbstractFaissPQEncoder implements Encoder { + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + // Roughly speaking, PQ can be configured to produce a lot of different compression levels. The "m" parameter + // specifies how many sub-vectors to break the vector up into, and then the "code_size" represents the number + // of bits to encode each subvector. Thus, a d-dimensional vector of float32s goes from + // d*32 -> (m)*code_size bits. So (d*32)/(m*code_size) will be the compression level. + // + // Example: + // d=768, m=384, code_size=8 + // (768*32)/(384*8) = 8x (i.e. 24,576 vs. 3,072). + // + // Because of this variability, we will need to properly round to one of the supported values. 
+ if (methodComponentContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_M) == false + || methodComponentContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE) == false) { + return CompressionLevel.NOT_CONFIGURED; + } + + // Validate the configured m and code_size, then map them back to the compression level + Object value = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_M); + ValidationException validationException = getMethodComponent().getParameters() + .get(ENCODER_PARAMETER_PQ_M) + .validate(value, knnMethodConfigContext); + if (validationException != null) { + throw validationException; + } + Integer m = (Integer) value; + value = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); + validationException = getMethodComponent().getParameters() + .get(ENCODER_PARAMETER_PQ_CODE_SIZE) + .validate(value, knnMethodConfigContext); + if (validationException != null) { + throw validationException; + } + Integer codeSize = (Integer) value; + int dimension = knnMethodConfigContext.getDimension(); + + float actualCompression = ((float) dimension * 32) / (m * codeSize); + + if (actualCompression < 2.0f) { + return CompressionLevel.x1; + } + + if (actualCompression < 4.0f) { + return CompressionLevel.x2; + } + + if (actualCompression < 8.0f) { + return CompressionLevel.x4; + } + + if (actualCompression < 16.0f) { + return CompressionLevel.x8; + } + + if (actualCompression < 32.0f) { + return CompressionLevel.x16; + } + + if (actualCompression < 64.0f) { + return CompressionLevel.x32; + } + + // TODO: The problem is that the theoretical compression level of PQ can be in the thousands. Thus, I'm not sure + // it makes sense to have an enum all the way up to that value. 
So, for now, we will just return the max + // compression + return CompressionLevel.MAX_COMPRESSION_LEVEL; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java index 6750d84ed9..c22a9dec76 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java @@ -8,12 +8,8 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.Encoder; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; -import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Objects; import java.util.Set; @@ -30,7 +26,7 @@ * Faiss HNSW PQ encoder. Right now, the implementations are slightly different during validation between this an * {@link FaissIVFPQEncoder}. Hence, they are separate classes. 
*/ -public class FaissHNSWPQEncoder implements Encoder { +public class FaissHNSWPQEncoder extends AbstractFaissPQEncoder { private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); @@ -72,13 +68,4 @@ public class FaissHNSWPQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } - - @Override - public CompressionLevel calculateCompressionLevel( - MethodComponentContext methodComponentContext, - KNNMethodConfigContext knnMethodConfigContext - ) { - // TODO: For now, not supported out of the box - return CompressionLevel.NOT_CONFIGURED; - } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java index 8d54548bd4..d6cfd9a8c2 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java @@ -8,12 +8,8 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.Encoder; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; -import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Set; @@ -30,7 +26,7 @@ * Faiss IVF PQ encoder. Right now, the implementations are slightly different during validation between this an * {@link FaissHNSWPQEncoder}. Hence, they are separate classes. 
*/ -public class FaissIVFPQEncoder implements Encoder { +public class FaissIVFPQEncoder extends AbstractFaissPQEncoder { private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); @@ -93,13 +89,4 @@ public class FaissIVFPQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } - - @Override - public CompressionLevel calculateCompressionLevel( - MethodComponentContext methodComponentContext, - KNNMethodConfigContext knnMethodConfigContext - ) { - // TODO: For now, not supported out of the box - return CompressionLevel.NOT_CONFIGURED; - } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index ab583a2e08..47db31f6d0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -29,6 +29,8 @@ public enum CompressionLevel { x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)), x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)); + public static final CompressionLevel MAX_COMPRESSION_LEVEL = CompressionLevel.x32; + // Internally, an empty string is easier to deal with them null. However, from the mapping, // we do not want users to pass in the empty string and instead want null. 
So we make the conversion here public static final String[] NAMES_ARRAY = new String[] { diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoderTests.java new file mode 100644 index 0000000000..2492b6ad79 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoderTests.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import lombok.SneakyThrows; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.CompressionLevel; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; + +public class AbstractFaissPQEncoderTests extends KNNTestCase { + + @SneakyThrows + public void testCalculateCompressionLevel() { + AbstractFaissPQEncoder encoder = randomBoolean() ? 
new FaissHNSWPQEncoder() : new FaissIVFPQEncoder(); + + // Compression formula is: + // actual_compression = (d*32)/(m*code_size) and then round down to nearest: 1x, 2x, 4x, 8x, 16x, 32x + + // d=768 + // m=2 + // code_size=8 + // actual_compression = (768*32)/(2*8) = 1,536x + // expected_compression = Max compression level + assertCompressionLevel(2, 8, 768, CompressionLevel.MAX_COMPRESSION_LEVEL, encoder); + + // d=32 + // m=4 + // code_size=16 + // actual_compression = (32*32)/(4*16) = 16x + // expected_compression = 16x + assertCompressionLevel(4, 16, 32, CompressionLevel.x16, encoder); + + // d=1536 + // m=768 + // code_size=8 + // actual_compression = (1536*32)/(768*8) = 8x + // expected_compression = 8x + assertCompressionLevel(768, 8, 1536, CompressionLevel.x8, encoder); + + // d=128 + // m=128 + // code_size=8 + // actual_compression = (128*32)/(128*8) = 4x + // expected_compression = 4x + assertCompressionLevel(128, 8, 128, CompressionLevel.x4, encoder); + } + + private void assertCompressionLevel(int m, int codeSize, int d, CompressionLevel expectedCompression, Encoder encoder) { + assertEquals( + expectedCompression, + encoder.calculateCompressionLevel(generateMethodComponentContext(m, codeSize), generateKNNMethodConfigContext(d)) + ); + } + + private MethodComponentContext generateMethodComponentContext(int m, int codeSize) { + return new MethodComponentContext(ENCODER_PQ, Map.of(ENCODER_PARAMETER_PQ_M, m, ENCODER_PARAMETER_PQ_CODE_SIZE, codeSize)); + } + + private KNNMethodConfigContext generateKNNMethodConfigContext(int dimension) { + return KNNMethodConfigContext.builder().dimension(dimension).build(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java deleted file mode 100644 index 3f7dd9dcd2..0000000000 --- 
a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.engine.faiss; - -import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.mapper.CompressionLevel; - -public class FaissHNSWPQEncoderTests extends KNNTestCase { - public void testCalculateCompressionLevel() { - FaissHNSWPQEncoder encoder = new FaissHNSWPQEncoder(); - assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(null, null)); - } -} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java deleted file mode 100644 index 35b7a64abb..0000000000 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.engine.faiss; - -import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.mapper.CompressionLevel; - -public class FaissIVFPQEncoderTests extends KNNTestCase { - public void testCalculateCompressionLevel() { - FaissIVFPQEncoder encoder = new FaissIVFPQEncoder(); - assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(null, null)); - } -}