Skip to content

Commit

Permalink
Enable script score to work with model based indices (#1649)
Browse files Browse the repository at this point in the history
* Enable script score to work with model based indices

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Refactor into KNNVectorFieldMapperUtil and split test into two tests

Signed-off-by: Ryan Bogan <[email protected]>

* Make parameters final for public methods

Signed-off-by: Ryan Bogan <[email protected]>

* Add integration test

Signed-off-by: Ryan Bogan <[email protected]>

* Transfer variable to constant

Signed-off-by: Ryan Bogan <[email protected]>

* Remove star import

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored Apr 29, 2024
1 parent 0cb4f36 commit e608d2d
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
* Enable script score to work with model based indices [#1649](https://github.com/opensearch-project/k-NN/pull/1649)
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
* Add arm64 check when SIMD is disabled [#1618](https://github.com/opensearch-project/k-NN/pull/1618)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.Arrays;
import java.util.Locale;
Expand All @@ -34,9 +37,22 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;

/**
* Utility class for KNNVectorFieldMapper
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {

private static ModelDao modelDao;

/**
* Initializes static instance variables
* @param modelDao ModelDao object
*/
public static void initialize(final ModelDao modelDao) {
KNNVectorFieldMapperUtil.modelDao = modelDao;
}

/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range
* or is not within the FP16 range of [-65504 to 65504].
Expand Down Expand Up @@ -171,4 +187,46 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy

return vectorDataType.getVectorFromBytesRef(storedVector);
}

/**
* Get the expected dimensions from a specified knn vector field type.
*
* If the field is model-based, get dimensions from model metadata.
* @param knnVectorFieldType knn vector field type
* @return expected dimensions
*/
public static int getExpectedDimensions(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
int expectedDimensions = knnVectorFieldType.getDimension();
if (isModelBasedIndex(expectedDimensions)) {
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
expectedDimensions = modelMetadata.getDimension();
}
return expectedDimensions;
}

private static boolean isModelBasedIndex(int expectedDimensions) {
return expectedDimensions == -1;
}

/**
* Returns the model metadata for a specified knn vector field
*
* @param knnVectorField knn vector field
* @return the model metadata from knnVectorField
*/
private static ModelMetadata getModelMetadataForField(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(
String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName())
);
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand Down Expand Up @@ -204,6 +205,7 @@ public Collection<Object> createComponents(
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -28,6 +29,7 @@
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -60,7 +62,7 @@ public L2(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
Expand Down Expand Up @@ -96,7 +98,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
Expand Down Expand Up @@ -191,7 +193,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
Expand Down Expand Up @@ -226,7 +228,7 @@ public LInf(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
Expand Down Expand Up @@ -263,7 +265,7 @@ public InnerProd(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;

/**
* Utility class for KNNScoringSpace
*/
public class KNNScoringSpaceUtil {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@

import org.apache.lucene.document.StoredField;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.io.ByteArrayInputStream;
import java.util.Arrays;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNNVectorFieldMapperUtilTests extends KNNTestCase {

private static final String TEST_FIELD_NAME = "test_field_name";
Expand Down Expand Up @@ -51,4 +59,59 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
assertTrue(vector instanceof float[]);
assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f);
}

public void testGetExpectedDimensionsSuccess() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldType.getDimension()).thenReturn(3);

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getDimension()).thenReturn(4);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

assertEquals(3, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldType));
assertEquals(4, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));
}

public void testGetExpectedDimensionsFailure() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage());

when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
MethodComponentContext methodComponentContext = mock(MethodComponentContext.class);
String fieldName = "test-field";
when(methodComponentContext.getName()).thenReturn(fieldName);
when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext);
when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext);

e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,23 @@
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.TYPE;
import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;

public class KNNScriptScoringIT extends KNNRestTestCase {

private static final String TEST_MODEL = "test-model";

public void testKNNL2ScriptScore() throws Exception {
testKNNScriptScore(SpaceType.L2);
}
Expand Down Expand Up @@ -550,6 +564,46 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception {
assertEquals(1, secondQueryCacheMap.get("hit_count"));
}

public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
int dimensions = randomIntBetween(2, 10);
String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions);
createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping);
bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions);

XContentBuilder methodBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NLIST, 4)
.field(METHOD_PARAMETER_NPROBES, 2)
.endObject()
.endObject();
Map<String, Object> method = xContentBuilderToMap(methodBuilder);

trainModel(TEST_MODEL, TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions, method, "test model for script score");
assertTrainingSucceeds(TEST_MODEL, 30, 1000);

String testMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME)
.field(TYPE, TYPE_KNN_VECTOR)
.field(MODEL_ID, TEST_MODEL)
.endObject()
.endObject()
.endObject()
.toString();

for (SpaceType spaceType : SpaceType.values()) {
if (spaceType != SpaceType.HAMMING_BIT) {
final float[] queryVector = randomVector(dimensions);
final BiFunction<float[], float[], Float> scoreFunction = getScoreFunction(spaceType, queryVector);
createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector);
}
}
}

private List<String> createMappers(int dimensions) throws Exception {
return List.of(
createKnnIndexMapping(FIELD_NAME, dimensions),
Expand Down

0 comments on commit e608d2d

Please sign in to comment.