Skip to content

Commit

Permalink
Remove compress as an input parameter and set default as true
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 22, 2024
1 parent 281f46f commit 31eefb2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand Down Expand Up @@ -61,11 +60,14 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
&& params != null
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams();
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
knnScalarQuantizedVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"",
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
Expand All @@ -74,17 +76,14 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits(),
LUCENE_SQ_COMPRESS,
knnScalarQuantizedVectorsFormatParams.isCompressFlag()
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}

}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams();
knnVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand All @@ -32,6 +31,15 @@ public class KNNScalarQuantizedVectorsFormatParams extends KNNVectorsFormatParam
private int bits;
private boolean compressFlag;

public KNNScalarQuantizedVectorsFormatParams(Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
super(params, defaultMaxConnections, defaultBeamWidth);
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.initConfidenceInterval(sqEncoderParams);
this.initBits(sqEncoderParams);
this.initCompressFlag();
}

@Override
public boolean validate(Map<String, Object> params) {
if (params.get(METHOD_ENCODER_PARAMETER) == null) {
Expand All @@ -50,19 +58,6 @@ public boolean validate(Map<String, Object> params) {
return true;
}

@Override
public void initialize(Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
super.initialize(params, defaultMaxConnections, defaultBeamWidth);
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.initConfidenceInterval(sqEncoderParams);
this.initBits(sqEncoderParams);
this.initCompressFlag(sqEncoderParams);
// this.confidenceInterval = getConfidenceInterval(sqEncoderParams);
// this.bits = getBits(sqEncoderParams);
// this.compressFlag = getCompressFlag(sqEncoderParams);
}

private void initConfidenceInterval(final Map<String, Object> params) {

if (params != null && params.containsKey(LUCENE_SQ_CONFIDENCE_INTERVAL)) {
Expand All @@ -87,11 +82,7 @@ private void initBits(final Map<String, Object> params) {
this.bits = LUCENE_SQ_DEFAULT_BITS;
}

private void initCompressFlag(final Map<String, Object> params) {
if (params != null && params.containsKey(LUCENE_SQ_COMPRESS)) {
this.compressFlag = (boolean) params.get(LUCENE_SQ_COMPRESS);
return;
}
this.compressFlag = false;
private void initCompressFlag() {
this.compressFlag = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ public class KNNVectorsFormatParams {
private int maxConnections;
private int beamWidth;

public boolean validate(final Map<String, Object> params) {
return false;
}

public void initialize(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
initMaxConnections(params, defaultMaxConnections);
initBeamWidth(params, defaultBeamWidth);
}

public boolean validate(final Map<String, Object> params) {
return true;
}

private void initMaxConnections(final Map<String, Object> params, int defaultMaxConnections) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
this.maxConnections = (int) params.get(KNNConstants.METHOD_PARAMETER_M);
Expand Down
3 changes: 0 additions & 3 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.DYNAMIC_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS_SUPPORTED;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
Expand Down Expand Up @@ -54,7 +52,6 @@ public class Lucene extends JVMLibrary {
LUCENE_SQ_BITS,
new Parameter.IntegerParameter(LUCENE_SQ_BITS, LUCENE_SQ_DEFAULT_BITS, LUCENE_SQ_BITS_SUPPORTED::contains)
)
.addParameter(LUCENE_SQ_COMPRESS, new Parameter.BooleanParameter(LUCENE_SQ_COMPRESS, false, Objects::nonNull))
.build()
);

Expand Down
29 changes: 9 additions & 20 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
Expand Down Expand Up @@ -486,8 +485,7 @@ public void testSQ_whenInvalidBits_thenThrowException() {
SpaceType.L2,
VectorDataType.FLOAT,
bits,
MINIMUM_CONFIDENCE_INTERVAL,
false
MINIMUM_CONFIDENCE_INTERVAL
)
);
}
Expand All @@ -502,8 +500,7 @@ public void testSQ_whenInvalidConfidenceInterval_thenThrowException() {
SpaceType.L2,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
confidenceInterval,
false
confidenceInterval
)
);
}
Expand All @@ -517,8 +514,7 @@ public void testSQ_withByteVectorDataType_thenThrowException() {
SpaceType.L2,
VectorDataType.BYTE,
LUCENE_SQ_DEFAULT_BITS,
MINIMUM_CONFIDENCE_INTERVAL,
false
MINIMUM_CONFIDENCE_INTERVAL
)
);
assertTrue(ex.getMessage(), ex.getMessage().contains("data type does not support"));
Expand All @@ -531,8 +527,7 @@ public void testAddDocWithSQEncoder() {
SpaceType.L2,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
Expand All @@ -548,8 +543,7 @@ public void testUpdateDocWithSQEncoder() {
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
Expand All @@ -568,8 +562,7 @@ public void testDeleteDocWithSQEncoder() {
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
Expand All @@ -587,8 +580,7 @@ public void testIndexingAndQueryingWithSQEncoder() {
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
MAXIMUM_CONFIDENCE_INTERVAL
);

int numDocs = 10;
Expand Down Expand Up @@ -620,8 +612,7 @@ public void testQueryWithFilterUsingSQEncoder() throws Exception {
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
MAXIMUM_CONFIDENCE_INTERVAL
);

addKnnDocWithAttributes(
Expand All @@ -645,8 +636,7 @@ private void createKnnIndexMappingWithLuceneEngineAndSQEncoder(
SpaceType spaceType,
VectorDataType vectorDataType,
int bits,
double confidenceInterval,
boolean compress
double confidenceInterval
) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand All @@ -667,7 +657,6 @@ private void createKnnIndexMappingWithLuceneEngineAndSQEncoder(
.startObject(PARAMETERS)
.field(LUCENE_SQ_BITS, bits)
.field(LUCENE_SQ_CONFIDENCE_INTERVAL, confidenceInterval)
.field(LUCENE_SQ_COMPRESS, compress)
.endObject()
.endObject()
.endObject()
Expand Down

0 comments on commit 31eefb2

Please sign in to comment.