Skip to content

Commit

Permalink
Add bwc test for model id
Browse files Browse the repository at this point in the history
This is applicable for index created before 2.17

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Dec 10, 2024
1 parent 5ede283 commit c8406c8
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 11 deletions.
19 changes: 19 additions & 0 deletions qa/restart-upgrade/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ testClusters {
excludeTestsMatching "org.opensearch.knn.bwc.IndexingIT.testKNNIndexLuceneForceMerge"
}
}
if (!(knn_bwc_version.startsWith("2.13.") ||
knn_bwc_version.startsWith("2.14.") ||
knn_bwc_version.startsWith("2.15.") ||
knn_bwc_version.startsWith("2.16."))) {
filter {
excludeTestsMatching "org.opensearch.knn.bwc.ModelIT.testNonKNNIndex_withModelId"
excludeTestsMatching "org.opensearch.knn.bwc.PainlessScriptScoringIT.testNonKNNIndex_withMethodParams_withFaissEngine"
}
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
Expand Down Expand Up @@ -170,6 +179,16 @@ testClusters {
}
}

if (!(knn_bwc_version.startsWith("2.13.") ||
knn_bwc_version.startsWith("2.14.") ||
knn_bwc_version.startsWith("2.15.") ||
knn_bwc_version.startsWith("2.16."))) {
filter {
excludeTestsMatching "org.opensearch.knn.bwc.ModelIT.testNonKNNIndex_withModelId"
excludeTestsMatching "org.opensearch.knn.bwc.PainlessScriptScoringIT.testNonKNNIndex_withMethodParams_withFaissEngine"
}
}

nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}")
nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}")
systemProperty 'tests.security.manager', 'false'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,27 @@ public class ModelIT extends AbstractRestartUpgradeTestCase {
private static final String TEST_MODEL_INDEX_DEFAULT = KNN_BWC_PREFIX + "test-model-index-default";
private static final String TRAINING_INDEX = KNN_BWC_PREFIX + "train-index";
private static final String TRAINING_INDEX_DEFAULT = KNN_BWC_PREFIX + "train-index-default";
private static final String TRAINING_INDEX_FOR_NON_KNN_INDEX = KNN_BWC_PREFIX + "train-index-for-non-knn-index";
private static final String TRAINING_FIELD = "train-field";
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 5;
private static int DOC_ID = 0;
private static int DOC_ID_TEST_MODEL_INDEX = 0;
private static int DOC_ID_TEST_MODEL_INDEX_DEFAULT = 0;
private static final int DELAY_MILLI_SEC = 1000;
private static final int EXP_NUM_OF_MODELS = 2;
private static final int MIN_NUM_OF_MODELS = 2;
private static final int K = 5;
private static final int NUM_DOCS = 10;
private static final int NUM_DOCS_TEST_MODEL_INDEX = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_DEFAULT = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_FOR_NON_KNN_INDEX = 100;
private static final int NUM_OF_ATTEMPTS = 30;
private static int QUERY_COUNT = 0;
private static int QUERY_COUNT_TEST_MODEL_INDEX = 0;
private static int QUERY_COUNT_TEST_MODEL_INDEX_DEFAULT = 0;
private static final String TEST_MODEL_ID = "test-model-id";
private static final String TEST_MODEL_ID_DEFAULT = "test-model-id-default";
private static final String TEST_MODEL_ID_FOR_NON_KNN_INDEX = "test-model-id-for-non-knn-index";
private static final String MODEL_DESCRIPTION = "Description for train model test";

// KNN model test
Expand Down Expand Up @@ -135,6 +138,32 @@ public void testKNNModelDefault() throws Exception {
}
}

public void testNonKNNIndex_withModelId() throws Exception {
if (isRunningAgainstOldCluster()) {

// Create a training index and randomly ingest data into it
createBasicKnnIndex(TRAINING_INDEX_FOR_NON_KNN_INDEX, TRAINING_FIELD, DIMENSIONS);
bulkIngestRandomVectors(TRAINING_INDEX_FOR_NON_KNN_INDEX, TRAINING_FIELD, NUM_DOCS, DIMENSIONS);

trainKNNModel(TEST_MODEL_ID_FOR_NON_KNN_INDEX, TRAINING_INDEX_FOR_NON_KNN_INDEX, TRAINING_FIELD, DIMENSIONS, MODEL_DESCRIPTION);
validateModelCreated(TEST_MODEL_ID_FOR_NON_KNN_INDEX);

createKnnIndex(
testIndex,
createKNNDefaultScriptScoreSettings(),
modelIndexMapping(TEST_FIELD, TEST_MODEL_ID_FOR_NON_KNN_INDEX)
);
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
} else {
Thread.sleep(1000);
DOC_ID = NUM_DOCS;
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
deleteKNNIndex(testIndex);
deleteKNNIndex(TRAINING_INDEX_FOR_NON_KNN_INDEX);
deleteKNNModel(TEST_MODEL_ID_FOR_NON_KNN_INDEX);
}
}

// Delete Models and ".opensearch-knn-models" index to clear cluster metadata
@AfterClass
public static void wipeAllModels() throws IOException {
Expand Down Expand Up @@ -168,7 +197,7 @@ public void searchKNNModel(String testModelID) throws Exception {
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(EXP_NUM_OF_MODELS, searchResponse.getHits().getHits().length);
assertTrue(MIN_NUM_OF_MODELS <= searchResponse.getHits().getHits().length);

for (SearchHit hit : searchResponse.getHits().getHits()) {
assertTrue(hit.getId().startsWith(testModelID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.knn.bwc;

import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.KNNEngine;

public class PainlessScriptScoringIT extends AbstractRestartUpgradeTestCase {
private static final String TEST_FIELD = "test-field";
private static final int DIMENSIONS = 5;
Expand Down Expand Up @@ -53,4 +56,26 @@ public void testKNNL1PainlessScriptScore() throws Exception {
}
}

public void testNonKNNIndex_withMethodParams_withFAISSEngine() throws Exception {
if (isRunningAgainstOldCluster()) {
createKnnIndex(
testIndex,
createKNNDefaultScriptScoreSettings(),
createKnnIndexMapping(TEST_FIELD, DIMENSIONS, "hnsw", KNNEngine.FAISS.getName(), SpaceType.DEFAULT.getValue(), false)
);
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
} else {
DOC_ID = NUM_DOCS;
QUERY_COUNT = NUM_DOCS;
String source = createL1PainlessScriptSource(TEST_FIELD, DIMENSIONS, QUERY_COUNT);
validateKNNPainlessScriptScoreSearch(testIndex, TEST_FIELD, source, QUERY_COUNT, K);
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
QUERY_COUNT = QUERY_COUNT + NUM_DOCS;
source = createL1PainlessScriptSource(TEST_FIELD, DIMENSIONS, QUERY_COUNT);
validateKNNPainlessScriptScoreSearch(testIndex, TEST_FIELD, source, QUERY_COUNT, K);
forceMergeKnnIndex(testIndex, 1);
validateKNNPainlessScriptScoreSearch(testIndex, TEST_FIELD, source, QUERY_COUNT, K);
deleteKNNIndex(testIndex);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.knn.index.engine.KNNEngine;

import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -96,6 +97,25 @@ public void testKNNInnerProductScriptScore() throws Exception {
}
}

public void testNonKNNIndex_withMethodParams_withFaissEngine() throws Exception {
if (isRunningAgainstOldCluster()) {
createKnnIndex(
testIndex,
createKNNDefaultScriptScoreSettings(),
createKnnIndexMapping(TEST_FIELD, DIMENSIONS, "hnsw", KNNEngine.FAISS.getName(), SpaceType.DEFAULT.getValue(), false)
);
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
} else {
QUERY_COUNT = NUM_DOCS;
DOC_ID = NUM_DOCS;
validateKNNScriptScoreSearch(testIndex, TEST_FIELD, DIMENSIONS, QUERY_COUNT, K, SpaceType.L2);
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS);
QUERY_COUNT = QUERY_COUNT + NUM_DOCS;
validateKNNScriptScoreSearch(testIndex, TEST_FIELD, DIMENSIONS, QUERY_COUNT, K, SpaceType.L2);
deleteKNNIndex(testIndex);
}
}

// Validate Script score search for space_type : "inner_product"
private void validateKNNInnerProductScriptScoreSearch(String testIndex, String testField, int dimension, int numDocs, int k)
throws Exception {
Expand All @@ -121,5 +141,4 @@ private void validateKNNInnerProductScriptScoreSearch(String testIndex, String t
assertEquals(expDocID, actualDocID);
}
}

}
17 changes: 9 additions & 8 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
Expand Down Expand Up @@ -381,31 +382,31 @@ protected void putMappingRequest(String index, String mapping) throws IOExceptio
}

/**
* Utility to create a Knn Index Mapping
* Utility to create a Knn Index Mapping for given model id
*/
protected String createKnnIndexMapping(String fieldName, Integer dimensions) throws IOException {
public String createKnnIndexMapping(String fieldName, String modelId) throws IOException {
return XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(PROPERTIES)
.startObject(fieldName)
.field("type", "knn_vector")
.field("dimension", dimensions.toString())
.field(VECTOR_TYPE, KNN_VECTOR)
.field(MODEL_ID, modelId)
.endObject()
.endObject()
.endObject()
.toString();
}

/**
* Utility to create a Knn Index Mapping with model id
* Utility to create a Knn Index Mapping
*/
protected String createKnnIndexMapping(String fieldName, String modelId) throws IOException {
protected String createKnnIndexMapping(String fieldName, Integer dimensions) throws IOException {
return XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("model_id", modelId)
.field("dimension", dimensions.toString())
.endObject()
.endObject()
.endObject()
Expand Down

0 comments on commit c8406c8

Please sign in to comment.