Skip to content

Commit

Permalink
Add Tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 17, 2024
1 parent cfdfa8c commit 7b192cf
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Adds dynamic query parameter nprobes [#1792](https://github.com/opensearch-project/k-NN/pull/1792)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
* Add support for Lucene inbuilt Scalar Quantizer [#1848](https://github.com/opensearch-project/k-NN/pull/1848)
### Enhancements
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
### Bug Fixes
Expand Down
11 changes: 7 additions & 4 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Predicate;
Expand Down Expand Up @@ -229,15 +230,15 @@ public ValidationException validate(Object value) {
if (!(value instanceof Double)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("Value not of type Double for Double " + "parameter \"%s\".", getName())
String.format(Locale.ROOT, "Value not of type Double for Double " + "parameter \"%s\".", getName())
);
return validationException;
}

if (!validator.test((Double) value)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("Parameter validation failed for Double " + "parameter \"%s\".", getName())
String.format(Locale.ROOT, "Parameter validation failed for Double " + "parameter \"%s\".", getName())
);
}
return validationException;
Expand All @@ -249,7 +250,7 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector
if (!(value instanceof Double)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("value is not an instance of Double for Double parameter [%s].", getName())
String.format(Locale.ROOT, "value is not an instance of Double for Double parameter [%s].", getName())
);
return validationException;
}
Expand All @@ -260,7 +261,9 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector

if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Double parameter [%s].", getName()));
validationException.addValidationError(
String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName())
);
}

return validationException;
Expand Down
214 changes: 214 additions & 0 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,21 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MINIMUM_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

public class LuceneEngineIT extends KNNRestTestCase {
Expand Down Expand Up @@ -466,6 +476,210 @@ public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_u
validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red", null);
}

@SneakyThrows
public void testSQ_whenInvalidBits_thenThrowException() {
int bits = -1;
expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.L2,
VectorDataType.FLOAT,
bits,
MINIMUM_CONFIDENCE_INTERVAL,
false
)
);
}

@SneakyThrows
public void testSQ_whenInvalidConfidenceInterval_thenThrowException() {
double confidenceInterval = 2.5;
expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.L2,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
confidenceInterval,
false
)
);
}

@SneakyThrows
public void testSQ_withByteVectorDataType_thenThrowException() {
Exception ex = expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.L2,
VectorDataType.BYTE,
LUCENE_SQ_DEFAULT_BITS,
MINIMUM_CONFIDENCE_INTERVAL,
false
)
);
assertTrue(ex.getMessage(), ex.getMessage().contains("data type does not support"));
}

@SneakyThrows
public void testAddDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.L2,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
);
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

refreshIndex(INDEX_NAME);
assertEquals(1, getDocCount(INDEX_NAME));
}

@SneakyThrows
public void testUpdateDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

Float[] updatedVector = { 8.0f, 8.0f, 8.0f };
updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector);

refreshIndex(INDEX_NAME);
assertEquals(1, getDocCount(INDEX_NAME));
}

@SneakyThrows
public void testDeleteDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

deleteKnnDoc(INDEX_NAME, DOC_ID);

refreshIndex(INDEX_NAME);
assertEquals(0, getDocCount(INDEX_NAME));
}

@SneakyThrows
public void testIndexingAndQueryingWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
);

int numDocs = 10;
for (int i = 0; i < numDocs; i++) {
float[] indexVector = new float[DIMENSION];
Arrays.fill(indexVector, (float) i);
addKnnDocWithAttributes(INDEX_NAME, Integer.toString(i), FIELD_NAME, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
}

// Assert that all docs are ingested
refreshAllNonSystemIndices();
assertEquals(numDocs, getDocCount(INDEX_NAME));

float[] queryVector = new float[DIMENSION];
Arrays.fill(queryVector, (float) numDocs);
int k = 10;

Response searchResponse = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), FIELD_NAME);
assertEquals(k, results.size());
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}
}

public void testQueryWithFilterUsingSQEncoder() throws Exception {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
MAXIMUM_CONFIDENCE_INTERVAL,
false
);

addKnnDocWithAttributes(
DOC_ID,
new float[] { 6.0f, 7.9f, 3.1f },
ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 6.0f, 4.1f };
List<String> expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3);
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
}

private void createKnnIndexMappingWithLuceneEngineAndSQEncoder(
int dimension,
SpaceType spaceType,
VectorDataType vectorDataType,
int bits,
double confidenceInterval,
boolean compress
) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, dimension)
.field(VECTOR_DATA_TYPE_FIELD, vectorDataType)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName())
.startObject(KNNConstants.PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, M)
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(LUCENE_SQ_BITS, bits)
.field(LUCENE_SQ_CONFIDENCE_INTERVAL, confidenceInterval)
.field(LUCENE_SQ_COMPRESS, compress)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}

private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
Expand Down
34 changes: 34 additions & 0 deletions src/test/java/org/opensearch/knn/index/ParameterTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,40 @@ public void testStringParameter_validateWithData() {
assertNotNull(parameter.validateWithData("test", testVectorSpaceInfo));
}

public void testDoubleParameter_validate() {
final Parameter.DoubleParameter parameter = new Parameter.DoubleParameter("test_parameter", 1.0, v -> v >= 0);

// valid value
assertNull(parameter.validate(0.9));

// Invalid type
assertNotNull(parameter.validate(true));

// Invalid type
assertNotNull(parameter.validate(-1));

}

public void testDoubleParameter_validateWithData() {
final Parameter.DoubleParameter parameter = new Parameter.DoubleParameter(
"test",
1.0,
v -> v > 0,
(v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension()
);

VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0);

// Invalid type
assertNotNull(parameter.validateWithData("String", testVectorSpaceInfo));

// Invalid value
assertNotNull(parameter.validateWithData(-1, testVectorSpaceInfo));

// valid value
assertNull(parameter.validateWithData(1.2, testVectorSpaceInfo));
}

public void testMethodComponentContextParameter_validate() {
String methodComponentName1 = "method-1";
String parameterKey1 = "parameter_key_1";
Expand Down

0 comments on commit 7b192cf

Please sign in to comment.