forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
QuantizationFramework Changes
Showing
32 changed files
with
1,165 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
src/main/java/org/opensearch/knn/quantization/QuantizationManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization; | ||
|
||
import org.opensearch.knn.quantization.factory.QuantizerFactory; | ||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; | ||
import org.opensearch.knn.quantization.models.requests.TrainingRequest; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
import org.opensearch.knn.quantization.sampler.Sampler; | ||
import org.opensearch.knn.quantization.sampler.SamplingFactory; | ||
|
||
public class QuantizationManager { | ||
private static QuantizationManager instance; | ||
|
||
private QuantizationManager() {} | ||
|
||
public static QuantizationManager getInstance() { | ||
if (instance == null) { | ||
instance = new QuantizationManager(); | ||
} | ||
return instance; | ||
} | ||
public <T, R> QuantizationState train(TrainingRequest<T> trainingRequest) { | ||
Quantizer<T, R> quantizer = (Quantizer<T, R>) getQuantizer(trainingRequest.getParams()); | ||
int sampleSize = quantizer.getSamplingSize(); | ||
Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); | ||
TrainingRequest<T> sampledRequest = new SamplingTrainingRequest<>(trainingRequest, sampler, sampleSize); | ||
return quantizer.train(sampledRequest); | ||
} | ||
public Quantizer<?, ?> getQuantizer(QuantizationParams params) { | ||
return QuantizerFactory.getQuantizer(params); | ||
} | ||
} | ||
|
17 changes: 17 additions & 0 deletions
17
src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
/**
 * Category recorded on a set of quantization parameters.
 *
 * <p>NOTE(review): the names suggest "quantize the vector space" versus "quantize individual
 * values" - confirm the intended definitions against the framework design.
 */
public enum QuantizationType {
    SPACE_QUANTIZATION,

    /** Category that {@code SQParams} passes to its superclass constructor. */
    VALUE_QUANTIZATION
}
21 changes: 21 additions & 0 deletions
21
src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
/**
 * Variants selectable through {@code SQParams}.
 *
 * <p>The constant name doubles as the type identifier used for quantizer lookup.
 * NOTE(review): the names presumably denote the target precision of each quantized value;
 * confirm before relying on that reading.
 */
public enum SQTypes {
    FP16,
    INT8,
    INT6,
    INT4,

    /** The variant for which {@code QuantizerFactory} registers {@code OneBitScalarQuantizer}. */
    ONE_BIT,

    TWO_BIT
}
17 changes: 17 additions & 0 deletions
17
src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
/**
 * Kinds of value quantization.
 *
 * <p>NOTE(review): {@code SQ} presumably stands for scalar quantization - confirm.
 */
public enum ValueQuantizationType {
    SQ
}
|
34 changes: 34 additions & 0 deletions
34
src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.factory; | ||
|
||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
|
||
public class QuantizerFactory { | ||
static { | ||
// Register all quantizers here | ||
QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new); | ||
} | ||
|
||
public static Quantizer<?, ?> getQuantizer(QuantizationParams params) { | ||
if (params instanceof SQParams) { | ||
SQParams sqParams = (SQParams) params; | ||
return QuantizerRegistry.getQuantizer(params, sqParams.getSqType().name()); | ||
} | ||
// Add more cases for other quantization parameters here | ||
throw new IllegalArgumentException("Unsupported quantization parameters: " + params.getClass().getName()); | ||
} | ||
} |
39 changes: 39 additions & 0 deletions
39
src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.factory; | ||
|
||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.function.Supplier; | ||
|
||
public class QuantizerRegistry { | ||
private static final Map<Class<? extends QuantizationParams>, Map<String, Supplier<? extends Quantizer<?, ?>>>> registry = new HashMap<>(); | ||
|
||
public static <T extends QuantizationParams> void register(Class<T> paramClass, String typeIdentifier, Supplier<? extends Quantizer<?, ?>> quantizerSupplier) { | ||
registry.computeIfAbsent(paramClass, k -> new HashMap<>()).put(typeIdentifier, quantizerSupplier); | ||
} | ||
|
||
public static Quantizer<?, ?> getQuantizer(QuantizationParams params, String typeIdentifier) { | ||
Map<String, Supplier<? extends Quantizer<?, ?>>> typeMap = registry.get(params.getClass()); | ||
if (typeMap == null) { | ||
throw new IllegalArgumentException("No quantizer registered for parameters: " + params.getClass().getName()); | ||
} | ||
Supplier<? extends Quantizer<?, ?>> supplier = typeMap.get(typeIdentifier); | ||
if (supplier == null) { | ||
throw new IllegalArgumentException("No quantizer registered for type identifier: " + typeIdentifier); | ||
} | ||
return supplier.get(); | ||
} | ||
} |
19 changes: 19 additions & 0 deletions
19
...opensearch/knn/quantization/models/quantizationOutput/OneBitScalarQuantizationOutput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationOutput; | ||
|
||
public class OneBitScalarQuantizationOutput extends QuantizationOutput<byte[]> { | ||
|
||
public OneBitScalarQuantizationOutput(byte[] quantizedVector) { | ||
super(quantizedVector); | ||
} | ||
} |
24 changes: 24 additions & 0 deletions
24
...in/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationOutput; | ||
|
||
/**
 * Immutable holder for the result of quantizing a single vector.
 *
 * @param <T> representation of the quantized vector
 */
public abstract class QuantizationOutput<T> {
    private final T quantizedVector;

    /**
     * Stores the quantized representation as given.
     *
     * @param quantizedVector the quantized vector; kept by reference
     */
    public QuantizationOutput(T quantizedVector) {
        this.quantizedVector = quantizedVector;
    }

    /**
     * Returns the quantized representation supplied at construction.
     *
     * @return the same reference that was passed to the constructor
     */
    public T getQuantizedVector() {
        return this.quantizedVector;
    }
}
27 changes: 27 additions & 0 deletions
27
...in/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationParams; | ||
|
||
import java.io.Serializable; | ||
import org.opensearch.knn.quantization.enums.QuantizationType; | ||
|
||
public abstract class QuantizationParams implements Serializable { | ||
private QuantizationType quantizationType; | ||
|
||
public QuantizationParams(QuantizationType quantizationType) { | ||
this.quantizationType = quantizationType; | ||
} | ||
|
||
public QuantizationType getQuantizationType() { | ||
return quantizationType; | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationParams; | ||
|
||
import org.opensearch.knn.quantization.enums.QuantizationType; | ||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
|
||
public class SQParams extends QuantizationParams { | ||
private SQTypes sqType; | ||
|
||
public SQParams(SQTypes sqType) { | ||
super(QuantizationType.VALUE_QUANTIZATION); | ||
this.sqType = sqType; | ||
} | ||
public SQTypes getSqType() { | ||
return sqType; | ||
} | ||
} |
54 changes: 54 additions & 0 deletions
54
...g/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationState; | ||
|
||
|
||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
|
||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.ObjectOutputStream; | ||
import java.io.ByteArrayInputStream; | ||
import java.io.ObjectInputStream; | ||
|
||
public class OneBitScalarQuantizationState extends QuantizationState { | ||
private float[] mean; | ||
|
||
public OneBitScalarQuantizationState(SQParams quantizationParams, float[] floatArray) { | ||
super(quantizationParams); | ||
this.mean = floatArray; | ||
} | ||
|
||
public float[] getMean() { | ||
return mean; | ||
} | ||
|
||
@Override | ||
public byte[] toByteArray() throws IOException { | ||
byte[] parentBytes = super.toByteArray(); | ||
ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||
ObjectOutputStream out = new ObjectOutputStream(bos); | ||
out.write(parentBytes); | ||
out.writeObject(mean); | ||
out.flush(); | ||
return bos.toByteArray(); | ||
} | ||
|
||
public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { | ||
ByteArrayInputStream bis = new ByteArrayInputStream(bytes); | ||
ObjectInputStream in = new ObjectInputStream(bis); | ||
QuantizationState parentState = (QuantizationState) in.readObject(); | ||
float[] floatArray = (float[]) in.readObject(); | ||
return new OneBitScalarQuantizationState((SQParams) parentState.getQuantizationParams(), floatArray); | ||
} | ||
} | ||
|
47 changes: 47 additions & 0 deletions
47
...main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.quantizationState; | ||
|
||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
|
||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.ObjectOutputStream; | ||
import java.io.Serializable; | ||
import java.io.ByteArrayInputStream; | ||
import java.io.ObjectInputStream; | ||
|
||
public abstract class QuantizationState implements Serializable { | ||
private QuantizationParams quantizationParams; | ||
|
||
public QuantizationState(QuantizationParams quantizationParams) { | ||
this.quantizationParams = quantizationParams; | ||
} | ||
|
||
public QuantizationParams getQuantizationParams() { | ||
return quantizationParams; | ||
} | ||
|
||
public byte[] toByteArray() throws IOException { | ||
ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||
ObjectOutputStream out = new ObjectOutputStream(bos); | ||
out.writeObject(this); | ||
out.flush(); | ||
return bos.toByteArray(); | ||
} | ||
|
||
public static QuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { | ||
ByteArrayInputStream bis = new ByteArrayInputStream(bytes); | ||
ObjectInputStream in = new ObjectInputStream(bis); | ||
return (QuantizationState) in.readObject(); | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
src/main/java/org/opensearch/knn/quantization/models/requests/SamplingTrainingRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.requests; | ||
|
||
import org.opensearch.knn.quantization.sampler.Sampler; | ||
|
||
import java.util.List; | ||
|
||
public class SamplingTrainingRequest<T> extends TrainingRequest<T> { | ||
private TrainingRequest<T> originalRequest; | ||
private int[] sampledIndices; | ||
|
||
public SamplingTrainingRequest(TrainingRequest<T> originalRequest, Sampler sampler, int sampleSize) { | ||
super(originalRequest.getParams(), originalRequest.getTotalNumberOfVectors()); | ||
this.originalRequest = originalRequest; | ||
this.sampledIndices = sampler.sample(originalRequest.getTotalNumberOfVectors(), sampleSize); | ||
} | ||
|
||
@Override | ||
public T getVector() { | ||
return originalRequest.getVector(); | ||
} | ||
|
||
@Override | ||
public T getVectorByDocId(int docId) { | ||
return originalRequest.getVectorByDocId(docId); | ||
} | ||
|
||
public int[] getSampledIndices() { | ||
return sampledIndices; | ||
} | ||
} |
37 changes: 37 additions & 0 deletions
37
src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.models.requests; | ||
|
||
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; | ||
|
||
public abstract class TrainingRequest<T> { | ||
private QuantizationParams params; | ||
private int totalNumberOfVectors; | ||
|
||
public TrainingRequest(QuantizationParams params, int totalNumberOfVectors) { | ||
this.params = params; | ||
this.totalNumberOfVectors = totalNumberOfVectors; | ||
} | ||
|
||
public QuantizationParams getParams() { | ||
return params; | ||
} | ||
|
||
public int getTotalNumberOfVectors() { | ||
return totalNumberOfVectors; | ||
} | ||
|
||
public abstract T getVector(); | ||
|
||
public abstract T getVectorByDocId(int docId); | ||
} | ||
|
125 changes: 125 additions & 0 deletions
125
src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.quantizer; | ||
|
||
import org.opensearch.knn.quantization.models.quantizationOutput.OneBitScalarQuantizationOutput; | ||
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; | ||
import org.opensearch.knn.quantization.models.requests.TrainingRequest; | ||
|
||
public class OneBitScalarQuantizer implements Quantizer<float[], byte[]> { | ||
private static final int SAMPLING_SIZE = 25000; | ||
|
||
@Override | ||
public int getSamplingSize() { | ||
return SAMPLING_SIZE; | ||
} | ||
|
||
@Override | ||
public QuantizationState train(TrainingRequest<float[]> trainingRequest) { | ||
if (!(trainingRequest instanceof SamplingTrainingRequest)) { | ||
throw new IllegalArgumentException("Training request must be of type SamplingTrainingRequest."); | ||
} | ||
|
||
SamplingTrainingRequest<float[]> samplingRequest = (SamplingTrainingRequest<float[]>) trainingRequest; | ||
int[] sampledIndices = samplingRequest.getSampledIndices(); | ||
|
||
if (sampledIndices == null || sampledIndices.length == 0) { | ||
throw new IllegalArgumentException("Sampled indices must not be null or empty."); | ||
} | ||
|
||
int totalSamples = sampledIndices.length; | ||
float[] sum = null; | ||
|
||
// Calculate the sum for each dimension based on sampled indices | ||
for (int i = 0; i < totalSamples; i++) { | ||
float[] vector = samplingRequest.getVectorByDocId(sampledIndices[i]); | ||
if (vector == null) { | ||
throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[i] + " is null."); | ||
} | ||
if (sum == null) { | ||
sum = new float[vector.length]; | ||
} else if (sum.length != vector.length) { | ||
throw new IllegalArgumentException("All vectors must have the same dimension."); | ||
} | ||
for (int j = 0; j < vector.length; j++) { | ||
sum[j] += vector[j]; | ||
} | ||
} | ||
if (sum == null) { | ||
throw new IllegalStateException("Sum array should not be null after processing vectors."); | ||
} | ||
// Calculate the mean for each dimension | ||
float[] mean = new float[sum.length]; | ||
for (int j = 0; j < sum.length; j++) { | ||
mean[j] = sum[j] / totalSamples; | ||
} | ||
SQParams params = (SQParams) trainingRequest.getParams(); | ||
if (params == null) { | ||
throw new IllegalArgumentException("Quantization parameters must not be null."); | ||
} | ||
return new OneBitScalarQuantizationState(params, mean); | ||
} | ||
|
||
@Override | ||
public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState state) { | ||
if (vector == null) { | ||
throw new IllegalArgumentException("Vector to quantize must not be null."); | ||
} | ||
if (!(state instanceof OneBitScalarQuantizationState)) { | ||
throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState."); | ||
} | ||
OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state; | ||
float[] thresholds = binaryState.getMean(); | ||
if (thresholds == null || thresholds.length != vector.length) { | ||
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); | ||
} | ||
byte[] quantizedVector = new byte[vector.length]; | ||
for (int i = 0; i < vector.length; i++) { | ||
quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); | ||
} | ||
return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector)); | ||
} | ||
|
||
@Override | ||
public QuantizationOutput<byte[]> quantize(float[] vector) { | ||
if (vector == null) { | ||
throw new IllegalArgumentException("Vector to quantize must not be null."); | ||
} | ||
byte[] quantizedVector = new byte[vector.length]; | ||
for (int i = 0; i < vector.length; i++) { | ||
quantizedVector[i] = (byte) (vector[i] > 0 ? 1 : 0); | ||
} | ||
return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector)); | ||
} | ||
|
||
private byte[] packBitsFromBitArray(byte[] bitArray) { | ||
int bitLength = bitArray.length; | ||
int byteLength = (bitLength + 7) / 8; | ||
byte[] packedArray = new byte[byteLength]; | ||
|
||
for (int i = 0; i < bitLength; i++) { | ||
if (bitArray[i] != 0 && bitArray[i] != 1) { | ||
throw new IllegalArgumentException("Array elements must be 0 or 1"); | ||
} | ||
int byteIndex = i / 8; | ||
int bitIndex = 7 - (i % 8); | ||
packedArray[byteIndex] |= (bitArray[i] << bitIndex); | ||
} | ||
|
||
return packedArray; | ||
} | ||
} | ||
|
31 changes: 31 additions & 0 deletions
31
src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.quantizer; | ||
|
||
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.requests.TrainingRequest; | ||
public interface Quantizer<T, R> { | ||
int getSamplingSize(); | ||
|
||
default QuantizationState train(TrainingRequest<T> trainingRequest) { | ||
throw new UnsupportedOperationException("Train method is not supported by this quantizer."); | ||
} | ||
|
||
default QuantizationOutput<R> quantize(T vector, QuantizationState state) { | ||
throw new UnsupportedOperationException("Quantize method with state is not supported by this quantizer."); | ||
} | ||
|
||
default QuantizationOutput<R> quantize(T vector) { | ||
throw new UnsupportedOperationException("Quantize method without state is not supported by this quantizer."); | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.sampler; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Random; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.IntStream; | ||
|
||
public class ReservoirSampler implements Sampler { | ||
private final Random random = new Random(); | ||
|
||
@Override | ||
public int[] sample(int totalNumberOfVectors, int sampleSize) { | ||
if (totalNumberOfVectors <= sampleSize) { | ||
return IntStream.range(0, totalNumberOfVectors).toArray(); | ||
} | ||
return reservoirSampleIndices(totalNumberOfVectors, sampleSize); | ||
} | ||
private int[] reservoirSampleIndices(int numVectors, int sampleSize) { | ||
int[] indices = IntStream.range(0, sampleSize).toArray(); | ||
for (int i = sampleSize; i < numVectors; i++) { | ||
int j = random.nextInt(i + 1); | ||
if (j < sampleSize) { | ||
indices[j] = i; | ||
} | ||
} | ||
Arrays.sort(indices); | ||
return indices; | ||
} | ||
} |
18 changes: 18 additions & 0 deletions
18
src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.sampler; | ||
|
||
import java.util.List; | ||
|
||
/**
 * Strategy for choosing which vectors, identified by index, take part in a sample.
 */
public interface Sampler {

    /**
     * Selects the indices of the vectors to sample.
     *
     * @param totalNumberOfVectors number of vectors available to sample from
     * @param sampleSize           number of vectors requested
     * @return the indices of the selected vectors
     */
    int[] sample(int totalNumberOfVectors, int sampleSize);
}
28 changes: 28 additions & 0 deletions
28
src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.sampler; | ||
|
||
public class SamplingFactory { | ||
public enum SamplerType { | ||
RESERVOIR, | ||
} | ||
|
||
public static Sampler getSampler(SamplerType samplerType) { | ||
switch (samplerType) { | ||
case RESERVOIR: | ||
return new ReservoirSampler(); | ||
// Add more cases for different samplers | ||
default: | ||
throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
src/test/java/org/opensearch/knn/quantization/QuantizationManagerTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.requests.TrainingRequest; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
|
||
|
||
public class QuantizationManagerTests extends KNNTestCase { | ||
public void testSingletonInstance() { | ||
QuantizationManager instance1 = QuantizationManager.getInstance(); | ||
QuantizationManager instance2 = QuantizationManager.getInstance(); | ||
assertSame(instance1, instance2); | ||
} | ||
|
||
public void testTrain() { | ||
QuantizationManager quantizationManager = QuantizationManager.getInstance(); | ||
float[][] vectors = { | ||
{1.0f, 2.0f, 3.0f}, | ||
{4.0f, 5.0f, 6.0f}, | ||
{7.0f, 8.0f, 9.0f} | ||
}; | ||
|
||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) { | ||
@Override | ||
public float[] getVector() { | ||
return null; // Not used in this test | ||
} | ||
@Override | ||
public float[] getVectorByDocId(int docId) { | ||
return vectors[docId]; | ||
} | ||
}; | ||
QuantizationState state = quantizationManager.train(originalRequest); | ||
|
||
assertTrue(state instanceof OneBitScalarQuantizationState); | ||
float[] mean = ((OneBitScalarQuantizationState) state).getMean(); | ||
assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); | ||
} | ||
|
||
public void testTrainWithFewVectors() { | ||
QuantizationManager quantizationManager = QuantizationManager.getInstance(); | ||
float[][] vectors = { | ||
{1.0f, 2.0f, 3.0f}, | ||
{4.0f, 5.0f, 6.0f} | ||
}; | ||
|
||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) { | ||
@Override | ||
public float[] getVector() { | ||
return null; // Not used in this test | ||
} | ||
|
||
@Override | ||
public float[] getVectorByDocId(int docId) { | ||
return vectors[docId]; | ||
} | ||
}; | ||
|
||
QuantizationState state = quantizationManager.train(originalRequest); | ||
|
||
assertTrue(state instanceof OneBitScalarQuantizationState); | ||
float[] mean = ((OneBitScalarQuantizationState) state).getMean(); | ||
assertArrayEquals(new float[]{2.5f, 3.5f, 4.5f}, mean, 0.001f); | ||
} | ||
|
||
|
||
public void testGetQuantizer() { | ||
QuantizationManager quantizationManager = QuantizationManager.getInstance(); | ||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
|
||
Quantizer<?, ?> quantizer = quantizationManager.getQuantizer(params); | ||
|
||
assertTrue(quantizer instanceof Quantizer); | ||
} | ||
} |
30 changes: 30 additions & 0 deletions
30
src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
|
||
public class QuantizationTypeTests extends KNNTestCase { | ||
|
||
public void testQuantizationTypeValues() { | ||
QuantizationType[] expectedValues = { | ||
QuantizationType.SPACE_QUANTIZATION, | ||
QuantizationType.VALUE_QUANTIZATION | ||
}; | ||
assertArrayEquals(expectedValues, QuantizationType.values()); | ||
} | ||
|
||
public void testQuantizationTypeValueOf() { | ||
assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION")); | ||
assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION")); | ||
} | ||
} |
37 changes: 37 additions & 0 deletions
37
src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
|
||
public class SQTypesTests extends KNNTestCase { | ||
public void testSQTypesValues() { | ||
SQTypes[] expectedValues = { | ||
SQTypes.FP16, | ||
SQTypes.INT8, | ||
SQTypes.INT6, | ||
SQTypes.INT4, | ||
SQTypes.ONE_BIT, | ||
SQTypes.TWO_BIT | ||
}; | ||
assertArrayEquals(expectedValues, SQTypes.values()); | ||
} | ||
|
||
public void testSQTypesValueOf() { | ||
assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16")); | ||
assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8")); | ||
assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6")); | ||
assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4")); | ||
assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT")); | ||
assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT")); | ||
} | ||
} |
27 changes: 27 additions & 0 deletions
27
src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.enums; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
|
||
public class ValueQuantizationTypeTests extends KNNTestCase { | ||
public void testValueQuantizationTypeValues() { | ||
ValueQuantizationType[] expectedValues = { | ||
ValueQuantizationType.SQ | ||
}; | ||
assertArrayEquals(expectedValues, ValueQuantizationType.values()); | ||
} | ||
|
||
public void testValueQuantizationTypeValueOf() { | ||
assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ")); | ||
} | ||
} |
26 changes: 26 additions & 0 deletions
26
src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.factory; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
|
||
public class QuantizerFactoryTests extends KNNTestCase { | ||
public void testGetQuantizer_withSQParams() { | ||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
Quantizer<?, ?> quantizer = QuantizerFactory.getQuantizer(params); | ||
assertTrue(quantizer instanceof OneBitScalarQuantizer); | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.factory; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; | ||
import org.opensearch.knn.quantization.quantizer.Quantizer; | ||
import org.junit.BeforeClass; | ||
|
||
public class QuantizerRegistryTests extends KNNTestCase { | ||
@BeforeClass | ||
public static void setup() { | ||
// Register the quantizer for testing | ||
QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new); | ||
} | ||
|
||
public void testRegisterAndGetQuantizer() { | ||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
Quantizer<?, ?> quantizer = QuantizerRegistry.getQuantizer(params, SQTypes.ONE_BIT.name()); | ||
assertTrue(quantizer instanceof OneBitScalarQuantizer); | ||
} | ||
|
||
public void testGetQuantizer_withUnsupportedTypeIdentifier() { | ||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
expectThrows( IllegalArgumentException.class, ()-> QuantizerRegistry.getQuantizer(params, "UNSUPPORTED_TYPE")); | ||
} | ||
} |
81 changes: 81 additions & 0 deletions
81
src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.quantizer; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
import org.opensearch.knn.quantization.enums.SQTypes; | ||
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; | ||
import org.opensearch.knn.quantization.models.quantizationParams.SQParams; | ||
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; | ||
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; | ||
import org.opensearch.knn.quantization.models.requests.TrainingRequest; | ||
import org.opensearch.knn.quantization.sampler.ReservoirSampler; | ||
|
||
public class OneBitScalarQuantizerTests extends KNNTestCase { | ||
|
||
public void testTrain() { | ||
float[][] vectors = { | ||
{1.0f, 2.0f, 3.0f}, | ||
{4.0f, 5.0f, 6.0f}, | ||
{7.0f, 8.0f, 9.0f} | ||
}; | ||
|
||
SQParams params = new SQParams(SQTypes.ONE_BIT); | ||
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) { | ||
@Override | ||
public float[] getVector() { | ||
return null; // Not used in this test | ||
} | ||
@Override | ||
public float[] getVectorByDocId(int docId) { | ||
return vectors[docId]; | ||
} | ||
}; | ||
TrainingRequest<float[]> trainingRequest = new SamplingTrainingRequest<>( | ||
originalRequest, | ||
new ReservoirSampler(), | ||
vectors.length | ||
); | ||
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); | ||
QuantizationState state = quantizer.train(trainingRequest); | ||
|
||
assertTrue(state instanceof OneBitScalarQuantizationState); | ||
float[] mean = ((OneBitScalarQuantizationState) state).getMean(); | ||
assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); | ||
} | ||
|
||
public void testQuantize_withState() { | ||
float[] vector = {3.0f, 6.0f, 9.0f}; | ||
float[] thresholds = {4.0f, 5.0f, 6.0f}; | ||
OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); | ||
|
||
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); | ||
QuantizationOutput<byte[]> output = quantizer.quantize(vector, state); | ||
|
||
assertArrayEquals(new byte[]{96}, output.getQuantizedVector()); | ||
} | ||
|
||
public void testQuantize_withoutState() { | ||
float[] vector = {-1.0f, 0.5f, 1.5f}; | ||
|
||
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); | ||
QuantizationOutput<byte[]> output = quantizer.quantize(vector); | ||
|
||
assertArrayEquals(new byte[]{96}, output.getQuantizedVector()); | ||
} | ||
|
||
public void testQuantize_withNullVector() { | ||
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); | ||
expectThrows( IllegalArgumentException.class, ()-> quantizer.quantize(null)); | ||
} | ||
} |
74 changes: 74 additions & 0 deletions
74
src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.sampler; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
|
||
|
||
public class ReservoirSamplerTests extends KNNTestCase { | ||
|
||
public void testSample() { | ||
Sampler sampler = new ReservoirSampler(); | ||
int totalNumberOfVectors = 100; | ||
int sampleSize = 10; | ||
|
||
int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); | ||
assertEquals(sampleSize, samples.length); | ||
for (int index : samples) { | ||
assertTrue(index >= 0 && index < totalNumberOfVectors); | ||
} | ||
} | ||
|
||
public void testSample_withFullSampling() { | ||
Sampler sampler = new ReservoirSampler(); | ||
int totalNumberOfVectors = 10; | ||
int sampleSize = 10; | ||
|
||
int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); | ||
assertEquals(sampleSize, samples.length); | ||
for (int index : samples) { | ||
assertTrue(index >= 0 && index < totalNumberOfVectors); | ||
} | ||
} | ||
|
||
public void testSample_withLessVectors() { | ||
Sampler sampler = new ReservoirSampler(); | ||
int totalNumberOfVectors = 5; | ||
int sampleSize = 10; | ||
|
||
int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); | ||
assertEquals(totalNumberOfVectors, samples.length); | ||
for (int index : samples) { | ||
assertTrue(index >= 0 && index < totalNumberOfVectors); | ||
} | ||
} | ||
|
||
public void testSample_withZeroVectors() { | ||
Sampler sampler = new ReservoirSampler(); | ||
int totalNumberOfVectors = 0; | ||
int sampleSize = 10; | ||
|
||
int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); | ||
assertEquals(0, samples.length); | ||
} | ||
|
||
public void testSample_withOneVector() { | ||
Sampler sampler = new ReservoirSampler(); | ||
int totalNumberOfVectors = 1; | ||
int sampleSize = 10; | ||
|
||
int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); | ||
assertEquals(1, samples.length); | ||
assertTrue(samples[0] == 0); | ||
} | ||
} | ||
|
25 changes: 25 additions & 0 deletions
25
src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.knn.quantization.sampler; | ||
|
||
import org.opensearch.knn.KNNTestCase; | ||
|
||
public class SamplingFactoryTests extends KNNTestCase { | ||
public void testGetSampler_withReservoir() { | ||
Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); | ||
assertTrue(sampler instanceof ReservoirSampler); | ||
} | ||
|
||
public void testGetSampler_withUnsupportedType() { | ||
expectThrows( NullPointerException.class, ()-> SamplingFactory.getSampler(null)); // This should throw an exception | ||
} | ||
} |