Skip to content

Commit

Permalink
BackPort Java Doc Fix with Code Improvements
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 14, 2024
1 parent 8eb3e19 commit 88b1cc4
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,35 @@
public final class QuantizerFactory {
private static final AtomicBoolean isRegistered = new AtomicBoolean(false);

/**
* Ensures that default quantizers are registered.
*/
private static void ensureRegistered() {
if (!isRegistered.get()) {
synchronized (QuantizerFactory.class) {
if (!isRegistered.get()) {
QuantizerRegistrar.registerDefaultQuantizers();
isRegistered.set(true);
}
}
}
}

/**
* Retrieves a quantizer instance based on the provided quantization parameters.
*
* @param params the quantization parameters used to determine the appropriate quantizer
* @param <P> the type of quantization parameters, extending {@link QuantizationParams}
* @param <Q> the type of the quantized output
* @param <T> the type of the input vector to be quantized
* @param <R> the type of the output after quantization
* @return an instance of {@link Quantizer} corresponding to the provided parameters
*/
public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
public static <P extends QuantizationParams, T, R> Quantizer<T, R> getQuantizer(final P params) {
if (params == null) {
throw new IllegalArgumentException("Quantization parameters must not be null.");
}
// Lazy Registration instead of static block as class level;
ensureRegistered();
return QuantizerRegistry.getQuantizer(params);
return (Quantizer<T, R>) QuantizerRegistry.getQuantizer(params);
}

/**
* Ensures that default quantizers are registered.
*/
private static void ensureRegistered() {
if (!isRegistered.get()) {
synchronized (QuantizerFactory.class) {
if (!isRegistered.get()) {
QuantizerRegistrar.registerDefaultQuantizers();
isRegistered.set(true);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ static void register(final String paramIdentifier, final Quantizer<?, ?> quantiz
*
* @param params the quantization parameters used to determine the appropriate quantizer
* @param <P> the type of quantization parameters
* @param <Q> the type of the quantized output
* @param <T> the type of the input vector to be quantized
* @param <R> the type of the output after quantization
* @return an instance of {@link Quantizer} corresponding to the provided parameters
* @throws IllegalArgumentException if no quantizer is registered for the given parameters
*/
static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(final P params) {
static <P extends QuantizationParams, T, R> Quantizer<T, R> getQuantizer(final P params) {
String identifier = params.getTypeIdentifier();
Quantizer<?, ?> quantizer = registry.get(identifier);
if (quantizer == null) {
throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier);
}
@SuppressWarnings("unchecked")
Quantizer<P, Q> typedQuantizer = (Quantizer<P, Q>) quantizer;
return typedQuantizer;
return (Quantizer<T, R>) quantizer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return "ScalarQuantizationParams_" + id;
}

/**
* Writes the object to the output stream.
* This method is part of the Writeable interface and is used to serialize the object.
Expand All @@ -74,4 +70,21 @@ public ScalarQuantizationParams(StreamInput in, int version) throws IOException
int typeId = in.readVInt();
this.sqType = ScalarQuantizationType.fromId(typeId);
}

/**
* Generates a unique identifier for Scalar Quantization Parameters.
*
* <p>
* This method constructs an identifier string by prefixing the given integer ID
* with "ScalarQuantizationParams_". The resulting string can be used to uniquely
* identify specific quantization parameter instances, especially when registering
* or retrieving them in a registry or similar structure.
* </p>
*
* @param id the integer ID to be used in generating the unique identifier.
* @return a string representing the unique identifier for the quantization parameters.
*/
private static String generateIdentifier(int id) {
return "ScalarQuantizationParams_" + id;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ public final class MultiBitScalarQuantizationState implements QuantizationState
*
* For example:
* - For 2-bit quantization:
* thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level
* thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level
* thresholds[0] {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level
* thresholds[1] {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level
* - For 4-bit quantization:
* thresholds[0] -> {0.1f, 0.2f, 0.3f}
* thresholds[1] -> {0.4f, 0.5f, 0.6f}
* thresholds[2] -> {0.7f, 0.8f, 0.9f}
* thresholds[3] -> {1.0f, 1.1f, 1.2f}
* thresholds[0] {0.1f, 0.2f, 0.3f}
* thresholds[1] {0.4f, 0.5f, 0.6f}
* thresholds[2] {0.7f, 0.8f, 0.9f}
* thresholds[3] {1.0f, 1.1f, 1.2f}
*
* Each column represents the threshold for a specific dimension in the vector space.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ public void test_Lazy_Registration() {
ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
assertFalse(isRegisteredFieldAccessible());
Quantizer<?, ?> quantizer = QuantizerFactory.getQuantizer(params);
Quantizer<?, ?> quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit);
Quantizer<?, ?> quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit);
assertTrue(quantizerFourBit instanceof MultiBitScalarQuantizer);
assertTrue(quantizerTwoBit instanceof MultiBitScalarQuantizer);
assertTrue(quantizer instanceof OneBitScalarQuantizer);
Quantizer<Float[], Byte[]> oneBitQuantizer = QuantizerFactory.getQuantizer(params);
Quantizer<Float[], Byte[]> quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit);
Quantizer<Float[], Byte[]> quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit);
assertEquals(quantizerFourBit.getClass(), MultiBitScalarQuantizer.class);
assertEquals(quantizerTwoBit.getClass(), MultiBitScalarQuantizer.class);
assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class);
assertTrue(isRegisteredFieldAccessible());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,37 @@ public static void setup() {
public void testRegisterAndGetQuantizer() {
// Test for OneBitScalarQuantizer
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
Quantizer<?, ?> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer);
Quantizer<Float[], Byte[]> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class);

// Test for MultiBitScalarQuantizer (2-bit)
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
Quantizer<?, ?> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer);
assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate());
Quantizer<Float[], Byte[]> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
assertEquals(twoBitQuantizer.getClass(), MultiBitScalarQuantizer.class);

// Test for MultiBitScalarQuantizer (4-bit)
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
Quantizer<?, ?> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer);
assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate());
Quantizer<Float[], Byte[]> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
assertEquals(fourBitQuantizer.getClass(), MultiBitScalarQuantizer.class);
}

public void testQuantizerRegistryIsSingleton() {
// Ensure the same instance is returned for the same type identifier
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
Quantizer<?, ?> firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
Quantizer<?, ?> secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
Quantizer<Float[], Byte[]> firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
Quantizer<Float[], Byte[]> secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
assertSame(firstOneBitQuantizer, secondOneBitQuantizer);

// Ensure the same instance is returned for the same type identifier (2-bit)
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
Quantizer<?, ?> firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
Quantizer<?, ?> secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
Quantizer<Float[], Byte[]> firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
Quantizer<Float[], Byte[]> secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer);

// Ensure the same instance is returned for the same type identifier (4-bit)
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
Quantizer<?, ?> firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
Quantizer<?, ?> secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
Quantizer<Float[], Byte[]> firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
Quantizer<Float[], Byte[]> secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
assertSame(firstFourBitQuantizer, secondFourBitQuantizer);
}

Expand Down

0 comments on commit 88b1cc4

Please sign in to comment.