Skip to content

Commit

Permalink
Fixing the dimension for the vector when using Lucene field in ModelF…
Browse files Browse the repository at this point in the history
…ieldMapper

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Aug 17, 2024
1 parent f42e86e commit 964ac3e
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
if (useLuceneBasedVectorField) {
int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY
? modelMetadata.getDimension()
: modelMetadata.getDimension() / 8;
? modelMetadata.getDimension() / Byte.SIZE
: modelMetadata.getDimension();
final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT
? VectorEncoding.FLOAT32
: VectorEncoding.BYTE;
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/knn/KNNTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -146,4 +147,25 @@ public int getDimension() {
}
};
}

/**
 * Computes the dimension to use at ingestion time for the given {@link VectorDataType}.
 * <p>
 * For {@link VectorDataType#BINARY} the supplied dimension is multiplied by {@link Byte#SIZE}
 * (the number of bits in a byte); for every other data type it is returned unchanged.
 *
 * @param dimension base dimension to adjust
 * @param vectorDataType {@link VectorDataType} the dimension is being adjusted for
 * @return {@code dimension * Byte.SIZE} for binary vectors, otherwise {@code dimension}
 */
protected int adjustDimensionForIndexing(final int dimension, final VectorDataType vectorDataType) {
    // Only binary vectors need scaling; all other data types use the dimension as-is.
    if (vectorDataType != VectorDataType.BINARY) {
        return dimension;
    }
    return dimension * Byte.SIZE;
}

/**
 * Computes the dimension to use at search time for the given {@link VectorDataType}.
 * <p>
 * For {@link VectorDataType#BINARY} the supplied dimension is divided by {@link Byte#SIZE}
 * using integer division; for every other data type it is returned unchanged.
 *
 * @param dimension dimension to adjust
 * @param vectorDataType {@link VectorDataType} the dimension is being adjusted for
 * @return {@code dimension / Byte.SIZE} for binary vectors, otherwise {@code dimension}
 */
protected int adjustDimensionForSearch(final int dimension, final VectorDataType vectorDataType) {
    // Only binary vectors need scaling; all other data types use the dimension as-is.
    if (vectorDataType != VectorDataType.BINARY) {
        return dimension;
    }
    return dimension / Byte.SIZE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;
import java.time.ZoneOffset;
Expand All @@ -66,6 +67,7 @@
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
Expand Down Expand Up @@ -768,108 +770,220 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException {

@SneakyThrows
public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() {
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
int dimension = dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION;
final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(dataType)
.versionCreated(CURRENT)
.dimension(dimension)
.build();
final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext);

ParseContext.Document document = new ParseContext.Document();
ContentPath contentPath = new ContentPath();
ParseContext parseContext = mock(ParseContext.class);
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true);
MethodFieldMapper methodFieldMapper = Mockito.spy(
MethodFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
knnMethodContext,
knnMethodConfigContext,
knnMethodContext,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false
)
);

if (dataType == VectorDataType.BINARY) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper)
.getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType);
} else if (dataType == VectorDataType.BYTE) {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType);
} else {
doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
}

methodFieldMapper.parseCreateField(parseContext, dimension, dataType);

List<IndexableField> fields = document.getFields();
assertEquals(1, fields.size());
IndexableField field1 = fields.get(0);
if (dataType == VectorDataType.FLOAT) {
assertTrue(field1 instanceof KnnFloatVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32);
} else {
assertTrue(field1 instanceof KnnByteVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE);
// MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
try (MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class)) {
for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
int dimension = adjustDimensionForIndexing(TEST_DIMENSION, dataType);
final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(dataType)
.versionCreated(CURRENT)
.dimension(dimension)
.build();
final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext);

ParseContext.Document document = new ParseContext.Document();
ContentPath contentPath = new ContentPath();
ParseContext parseContext = mock(ParseContext.class);
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true);
MethodFieldMapper methodFieldMapper = Mockito.spy(
MethodFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
knnMethodContext,
knnMethodConfigContext,
knnMethodContext,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, dimension);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, dimension, dataType);
}
methodFieldMapper.parseCreateField(parseContext, dimension, dataType);

List<IndexableField> fields = document.getFields();
assertEquals(1, fields.size());
IndexableField field1 = fields.get(0);
if (dataType == VectorDataType.FLOAT) {
assertTrue(field1 instanceof KnnFloatVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32);
} else {
assertTrue(field1 instanceof KnnByteVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE);
}

assertEquals(field1.fieldType().vectorDimension(), adjustDimensionForSearch(dimension, dataType));
assertEquals(Integer.parseInt(field1.fieldType().getAttributes().get(DIMENSION_FIELD_NAME)), dimension);
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false);

document = new ParseContext.Document();
contentPath = new ContentPath();
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);
methodFieldMapper = Mockito.spy(
MethodFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
knnMethodContext,
knnMethodConfigContext,
knnMethodContext,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, dimension);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, dimension, dataType);
}

methodFieldMapper.parseCreateField(parseContext, dimension, dataType);
fields = document.getFields();
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
assertEquals(Integer.parseInt(field1.fieldType().getAttributes().get(DIMENSION_FIELD_NAME)), dimension);
}
}
// making sure to close the static mock so that other tests running on this thread are not impacted by
// this mocking
// utilMockedStatic.close();
}

assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION);
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false);

document = new ParseContext.Document();
contentPath = new ContentPath();
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);
methodFieldMapper = Mockito.spy(
MethodFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
knnMethodContext,
knnMethodConfigContext,
knnMethodContext,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper)
.getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? TEST_DIMENSION * 8 : TEST_DIMENSION, dataType);
@SneakyThrows
public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() {
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_IVF, Collections.emptyMap());
try (
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)
) {
for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
int dimension = adjustDimensionForIndexing(TEST_DIMENSION, dataType);
when(modelDao.getMetadata(MODEL_ID)).thenReturn(modelMetadata);
modelUtilMockedStatic.when(() -> ModelUtil.isModelCreated(modelMetadata)).thenReturn(true);
when(modelMetadata.getDimension()).thenReturn(dimension);
when(modelMetadata.getVectorDataType()).thenReturn(dataType);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getMethodComponentContext()).thenReturn(methodComponentContext);

ParseContext.Document document = new ParseContext.Document();
ContentPath contentPath = new ContentPath();
ParseContext parseContext = mock(ParseContext.class);
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true);
ModelFieldMapper modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, dimension);

} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper).getBytesFromContext(parseContext, dimension, dataType);
}

modelFieldMapper.parseCreateField(parseContext);

List<IndexableField> fields = document.getFields();
assertEquals(1, fields.size());
IndexableField field1 = fields.get(0);
if (dataType == VectorDataType.FLOAT) {
assertTrue(field1 instanceof KnnFloatVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.FLOAT32);
} else {
assertTrue(field1 instanceof KnnByteVectorField);
assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE);
}

assertEquals(field1.fieldType().vectorDimension(), adjustDimensionForSearch(dimension, dataType));
assertEquals(
field1.fieldType().vectorSimilarityFunction(),
SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction()
);

utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false);

document = new ParseContext.Document();
contentPath = new ContentPath();
when(parseContext.doc()).thenReturn(document);
when(parseContext.path()).thenReturn(contentPath);
modelFieldMapper = Mockito.spy(
ModelFieldMapper.createFieldMapper(
TEST_FIELD_NAME,
TEST_FIELD_NAME,
Collections.emptyMap(),
dataType,
MODEL_ID,
FieldMapper.MultiFields.empty(),
FieldMapper.CopyTo.empty(),
new Explicit<>(true, true),
false,
false,
modelDao,
CURRENT
)
);

if (dataType == VectorDataType.FLOAT) {
doReturn(Optional.of(TEST_VECTOR)).when(modelFieldMapper).getFloatsFromContext(parseContext, dimension);
} else {
doReturn(Optional.of(TEST_BYTE_VECTOR)).when(modelFieldMapper).getBytesFromContext(parseContext, dimension, dataType);
}

modelFieldMapper.parseCreateField(parseContext);
fields = document.getFields();
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
}

methodFieldMapper.parseCreateField(parseContext, dimension, dataType);
fields = document.getFields();
assertEquals(1, fields.size());
field1 = fields.get(0);
assertTrue(field1 instanceof VectorField);
}
// making sure to close the static mock so that other tests running on this thread are not impacted by
// this mocking
utilMockedStatic.close();
// utilMockedStatic.close();
// modelUtilMockedStatic.close();
}

@SneakyThrows
Expand Down

0 comments on commit 964ac3e

Please sign in to comment.