-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CompressionLevel Calculation for PQ (#2200)
Currently, for product quantization, we set the calculated compression level to NOT_CONFIGURED. The main issue with this is that if a user sets up a disk-based index with PQ, no re-scoring will happen by default. This change adds the calculation so that the proper re-scoring will happen. The formula is fairly straightforward => actual compression = (d * 32) / (m * code_size). Then, we round to the neareste compression level (because we only support discrete compression levels). One small issue with this is that if PQ is configured to have compression > 64x, the value will be 64x. Functionally, the only issue will be that we may not be as aggressive on oversampling for on disk mode. Signed-off-by: John Mazanec <[email protected]> (cherry picked from commit 228aead)
- Loading branch information
1 parent
75f9a18
commit ac066ad
Showing
11 changed files
with
197 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.engine.faiss; | ||
|
||
import org.opensearch.common.ValidationException; | ||
import org.opensearch.knn.index.engine.Encoder; | ||
import org.opensearch.knn.index.engine.KNNMethodConfigContext; | ||
import org.opensearch.knn.index.engine.MethodComponentContext; | ||
import org.opensearch.knn.index.mapper.CompressionLevel; | ||
|
||
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; | ||
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; | ||
|
||
/** | ||
* Abstract class for Faiss PQ encoders. This class provides the common logic for product quantization based encoders | ||
*/ | ||
public abstract class AbstractFaissPQEncoder implements Encoder { | ||
|
||
@Override | ||
public CompressionLevel calculateCompressionLevel( | ||
MethodComponentContext methodComponentContext, | ||
KNNMethodConfigContext knnMethodConfigContext | ||
) { | ||
// Roughly speaking, PQ can be configured to produce a lot of different compression levels. The "m" parameter | ||
// specifies how many sub-vectors to break the vector up into, and then the "code_size" represents the number | ||
// of bits to encode each subvector. Thus, a d-dimensional vector of float32s goes from | ||
// d*32 -> (m)*code_size bits. So if we want (d*32)/(m*code_size) will be the compression level. | ||
// | ||
// Example: | ||
// d=768, m=384, code_size=8 | ||
// (768*32)/(384*8) = 8x (i.e. 24,576 vs. 3,072). | ||
// | ||
// Because of this variability, we will need to properly round to one of the supported values. | ||
if (methodComponentContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_M) == false | ||
|| methodComponentContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE) == false) { | ||
return CompressionLevel.NOT_CONFIGURED; | ||
} | ||
|
||
// Map the number of bits passed in, back to the compression level | ||
Object value = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_M); | ||
ValidationException validationException = getMethodComponent().getParameters() | ||
.get(ENCODER_PARAMETER_PQ_M) | ||
.validate(value, knnMethodConfigContext); | ||
if (validationException != null) { | ||
throw validationException; | ||
} | ||
Integer m = (Integer) value; | ||
value = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); | ||
validationException = getMethodComponent().getParameters() | ||
.get(ENCODER_PARAMETER_PQ_CODE_SIZE) | ||
.validate(value, knnMethodConfigContext); | ||
if (validationException != null) { | ||
throw validationException; | ||
} | ||
Integer codeSize = (Integer) value; | ||
int dimension = knnMethodConfigContext.getDimension(); | ||
|
||
float actualCompression = ((float) dimension * 32) / (m * codeSize); | ||
|
||
if (actualCompression < 2.0f) { | ||
return CompressionLevel.x1; | ||
} | ||
|
||
if (actualCompression < 4.0f) { | ||
return CompressionLevel.x2; | ||
} | ||
|
||
if (actualCompression < 8.0f) { | ||
return CompressionLevel.x4; | ||
} | ||
|
||
if (actualCompression < 16.0f) { | ||
return CompressionLevel.x8; | ||
} | ||
|
||
if (actualCompression < 32.0f) { | ||
return CompressionLevel.x16; | ||
} | ||
|
||
if (actualCompression < 64.0f) { | ||
return CompressionLevel.x32; | ||
} | ||
|
||
// TODO: The problem is that the theoretical compression level of PQ can be in the thousands. Thus, Im not sure | ||
// it makes sense to have an enum all the way up to that value. So, for now, we will just return the max | ||
// compression | ||
return CompressionLevel.MAX_COMPRESSION_LEVEL; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
src/test/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoderTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.engine.faiss; | ||
|
||
import lombok.SneakyThrows; | ||
import org.opensearch.knn.KNNTestCase; | ||
import org.opensearch.knn.index.engine.Encoder; | ||
import org.opensearch.knn.index.engine.KNNMethodConfigContext; | ||
import org.opensearch.knn.index.engine.MethodComponent; | ||
import org.opensearch.knn.index.engine.MethodComponentContext; | ||
import org.opensearch.knn.index.mapper.CompressionLevel; | ||
|
||
import java.util.Map; | ||
|
||
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; | ||
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; | ||
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; | ||
|
||
public class AbstractFaissPQEncoderTests extends KNNTestCase { | ||
|
||
@SneakyThrows | ||
public void testCalculateCompressionLevel() { | ||
AbstractFaissPQEncoder encoder = new AbstractFaissPQEncoder() { | ||
@Override | ||
public MethodComponent getMethodComponent() { | ||
return FaissIVFPQEncoder.METHOD_COMPONENT; | ||
} | ||
}; | ||
|
||
// Compression formula is: | ||
// actual_compression = (d*32)/(m*code_size) and then round down to nearest: 1x, 2x, 4x, 8x, 16x, 32x | ||
|
||
// d=768 | ||
// m=2 | ||
// code_size=8 | ||
// actual_compression = (768*32)/(2*8) = 1,536x | ||
// expected_compression = Max compression level | ||
assertCompressionLevel(2, 8, 768, CompressionLevel.MAX_COMPRESSION_LEVEL, encoder); | ||
|
||
// d=32 | ||
// m=4 | ||
// code_size=16 | ||
// actual_compression = (32*32)/(4*16) = 16x | ||
// expected_compression = Max compression level | ||
assertCompressionLevel(4, 16, 32, CompressionLevel.x16, encoder); | ||
|
||
// d=1536 | ||
// m=768 | ||
// code_size=8 | ||
// actual_compression = (1536*32)/(768*8) = 8x | ||
// expected_compression = Max compression level | ||
assertCompressionLevel(768, 8, 1536, CompressionLevel.x8, encoder); | ||
|
||
// d=128 | ||
// m=128 | ||
// code_size=8 | ||
// actual_compression = (128*32)/(128*8) = 4x | ||
// expected_compression = Max compression level | ||
assertCompressionLevel(128, 8, 128, CompressionLevel.x4, encoder); | ||
} | ||
|
||
private void assertCompressionLevel(int m, int codeSize, int d, CompressionLevel expectedCompression, Encoder encoder) { | ||
assertEquals( | ||
expectedCompression, | ||
encoder.calculateCompressionLevel(generateMethodComponentContext(m, codeSize), generateKNNMethodConfigContext(d)) | ||
); | ||
} | ||
|
||
private MethodComponentContext generateMethodComponentContext(int m, int codeSize) { | ||
return new MethodComponentContext(ENCODER_PQ, Map.of(ENCODER_PARAMETER_PQ_M, m, ENCODER_PARAMETER_PQ_CODE_SIZE, codeSize)); | ||
} | ||
|
||
private KNNMethodConfigContext generateKNNMethodConfigContext(int dimension) { | ||
return KNNMethodConfigContext.builder().dimension(dimension).build(); | ||
} | ||
} |
16 changes: 0 additions & 16 deletions
16
src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java
This file was deleted.
Oops, something went wrong.
16 changes: 0 additions & 16 deletions
16
src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters