From a4697f4e4e940d7faedb4fd582cf411115fc792c Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 14 Aug 2024 09:37:50 -0700 Subject: [PATCH 1/2] Integrate KNNVectorValues with vector ANN Search flow (#1952) Signed-off-by: Navneet Verma --- CHANGELOG.md | 7 +- .../knn/common/FieldInfoExtractor.java | 37 ++++++++ .../opensearch/knn/index/query/KNNWeight.java | 22 +++-- .../filtered/FilteredIdsKNNByteIterator.java | 16 ++-- .../filtered/FilteredIdsKNNIterator.java | 17 ++-- .../NestedFilteredIdsKNNByteIterator.java | 8 +- .../NestedFilteredIdsKNNIterator.java | 8 +- .../vectorvalues/KNNVectorValuesFactory.java | 34 ++++++- .../org/opensearch/knn/indices/ModelUtil.java | 21 +++++ .../knn/common/FieldInfoExtractorTests.java | 44 +++++++++ .../knn/index/query/KNNWeightTests.java | 34 ++++--- .../FilteredIdsKNNByteIteratorTests.java | 8 +- .../filtered/FilteredIdsKNNIteratorTests.java | 11 +-- ...NestedFilteredIdsKNNByteIteratorTests.java | 8 +- .../NestedFilteredIdsKNNIteratorTests.java | 15 ++- .../KNNVectorValuesFactoryTests.java | 91 +++++++++++++++++++ .../knn/indices/ModelUtilTests.java | 36 ++++++++ 17 files changed, 338 insertions(+), 79 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java create mode 100644 src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java create mode 100644 src/test/java/org/opensearch/knn/indices/ModelUtilTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index dfca7bece..a5c641b8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,19 +15,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.16...2.x) ### Features +* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Enhancements ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) -* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) -* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) -* Disallow a vector field to have an invalid character for a physical file name. [#1936] (https://github.com/opensearch-project/k-NN/pull/1936) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) ### Infrastructure ### Documentation ### Maintenance * Fix a flaky unit test:testMultiFieldsKnnIndex, which was failing due to inconsistent merge behaviors [#1924](https://github.com/opensearch-project/k-NN/pull/1924) ### Refactoring * Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897) +* Integrate KNNVectorValues with vector ANN Search flow [#1952](https://github.com/opensearch-project/k-NN/pull/1952) * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) diff --git a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java new file mode 100644 index 000000000..591f16735 --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import lombok.experimental.UtilityClass; +import org.apache.commons.lang.StringUtils; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; + +/** + * A utility class to extract information from FieldInfo. + */ +@UtilityClass +public class FieldInfoExtractor { + + /** + * Extract vector data type from fieldInfo + * @param fieldInfo {@link FieldInfo} + * @return {@link VectorDataType} + */ + public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) { + String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD); + if (StringUtils.isEmpty(vectorDataTypeString)) { + final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID)); + if (modelMetadata != null) { + VectorDataType vectorDataType = modelMetadata.getVectorDataType(); + vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue(); + } + } + return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index f88652525..df40f6850 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -8,8 +8,6 @@ import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.DocValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; @@ -43,6 +41,10 @@ import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator; import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -412,25 +414,31 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont private KNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException { final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = getSpaceType(fieldInfo); if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { + final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader()); return knnQuery.getParentsFilter() == null - ? new FilteredIdsKNNByteIterator(filterIdsBitSet, knnQuery.getByteQueryVector(), values, spaceType) + ? new FilteredIdsKNNByteIterator( + filterIdsBitSet, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType + ) : new NestedFilteredIdsKNNByteIterator( filterIdsBitSet, knnQuery.getByteQueryVector(), - values, + (KNNBinaryVectorValues) vectorValues, spaceType, knnQuery.getParentsFilter().getBitSet(leafReaderContext) ); } else { + final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader()); return knnQuery.getParentsFilter() == null - ? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType) + ? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType) : new NestedFilteredIdsKNNIterator( filterIdsBitSet, knnQuery.getQueryVector(), - values, + (KNNFloatVectorValues) vectorValues, spaceType, knnQuery.getParentsFilter().getBitSet(leafReaderContext) ); diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java index 815e621f6..ccfe626a0 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java @@ -5,14 +5,12 @@ package org.opensearch.knn.index.query.filtered; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; -import java.io.ByteArrayInputStream; import java.io.IOException; /** @@ -26,7 +24,7 @@ public class FilteredIdsKNNByteIterator implements KNNIterator { protected final BitSet filterIdsBitSet; protected final BitSetIterator bitSetIterator; protected final byte[] queryVector; - protected final BinaryDocValues binaryDocValues; + protected final KNNBinaryVectorValues binaryVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; @@ -34,13 +32,13 @@ public class FilteredIdsKNNByteIterator implements KNNIterator { public FilteredIdsKNNByteIterator( final BitSet filterIdsBitSet, final byte[] queryVector, - final BinaryDocValues binaryDocValues, + final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType ) { this.filterIdsBitSet = filterIdsBitSet; this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; - this.binaryDocValues = binaryDocValues; + this.binaryVectorValues = binaryVectorValues; this.spaceType = spaceType; this.docId = bitSetIterator.nextDoc(); } @@ -57,7 +55,7 @@ public int nextDoc() throws IOException { if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } - int doc = binaryDocValues.advance(docId); + int doc = binaryVectorValues.advance(docId); currentScore = computeScore(); docId = bitSetIterator.nextDoc(); return doc; @@ -69,9 +67,7 @@ public float score() { } protected float computeScore() throws IOException { - final BytesRef value = binaryDocValues.binaryValue(); - final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); - final byte[] vector = byteStream.readAllBytes(); + final byte[] vector = binaryVectorValues.getVector(); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java index 7e554fb7d..a0d7694c9 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java @@ -5,14 +5,11 @@ package org.opensearch.knn.index.query.filtered; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.io.IOException; @@ -27,7 +24,7 @@ public class FilteredIdsKNNIterator implements KNNIterator { protected final BitSet filterIdsBitSet; protected final BitSetIterator bitSetIterator; protected final float[] queryVector; - protected final BinaryDocValues binaryDocValues; + protected final KNNFloatVectorValues knnFloatVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; @@ -35,13 +32,13 @@ public class FilteredIdsKNNIterator implements KNNIterator { public FilteredIdsKNNIterator( final BitSet filterIdsBitSet, final float[] queryVector, - final BinaryDocValues binaryDocValues, + final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType ) { this.filterIdsBitSet = filterIdsBitSet; this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; - this.binaryDocValues = binaryDocValues; + this.knnFloatVectorValues = knnFloatVectorValues; this.spaceType = spaceType; this.docId = bitSetIterator.nextDoc(); } @@ -58,7 +55,7 @@ public int nextDoc() throws IOException { if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } - int doc = binaryDocValues.advance(docId); + int doc = knnFloatVectorValues.advance(docId); currentScore = computeScore(); docId = bitSetIterator.nextDoc(); return doc; @@ -70,9 +67,7 @@ public float score() { } protected float computeScore() throws IOException { - final BytesRef value = binaryDocValues.binaryValue(); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(value); - final float[] vector = vectorSerializer.byteToFloatArray(value); + final float[] vector = knnFloatVectorValues.getVector(); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java index 80fba1e41..b69a90518 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java @@ -5,10 +5,10 @@ package org.opensearch.knn.index.query.filtered; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import java.io.IOException; @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator public NestedFilteredIdsKNNByteIterator( final BitSet filterIdsArray, final byte[] queryVector, - final BinaryDocValues values, + final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) { - super(filterIdsArray, queryVector, values, spaceType); + super(filterIdsArray, queryVector, binaryVectorValues, spaceType); this.parentBitSet = parentBitSet; } @@ -47,7 +47,7 @@ public int nextDoc() throws IOException { int bestChild = -1; while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { - binaryDocValues.advance(docId); + binaryVectorValues.advance(docId); float score = computeScore(); if (score > currentScore) { bestChild = docId; diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java index 9776ebbe9..259b004f8 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java @@ -5,10 +5,10 @@ package org.opensearch.knn.index.query.filtered; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.io.IOException; @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator { public NestedFilteredIdsKNNIterator( final BitSet filterIdsArray, final float[] queryVector, - final BinaryDocValues values, + final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) { - super(filterIdsArray, queryVector, values, spaceType); + super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType); this.parentBitSet = parentBitSet; } @@ -47,7 +47,7 @@ public int nextDoc() throws IOException { int bestChild = -1; while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { - binaryDocValues.advance(docId); + knnFloatVectorValues.advance(docId); float score = computeScore(); if (score > currentScore) { bestChild = docId; diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 5b6558f32..41408e217 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,10 +5,16 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; +import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.VectorDataType; +import java.io.IOException; import java.util.Map; /** @@ -21,7 +27,7 @@ public final class KNNVectorValuesFactory { * * @param vectorDataType {@link VectorDataType} * @param docIdSetIterator {@link DocIdSetIterator} - * @return {@link KNNVectorValues} of type float[] + * @return {@link KNNVectorValues} */ public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) { return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator)); @@ -32,7 +38,7 @@ public static KNNVectorValues getVectorValues(final VectorDataType vector * * @param vectorDataType {@link VectorDataType} * @param docIdWithFieldSet {@link DocsWithFieldSet} - * @return {@link KNNVectorValues} of type float[] + * @return {@link KNNVectorValues} */ public static KNNVectorValues getVectorValues( final VectorDataType vectorDataType, @@ -42,6 +48,30 @@ public static KNNVectorValues getVectorValues( return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues(docIdWithFieldSet, vectors)); } + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader} + * + * @param fieldInfo {@link FieldInfo} + * @param leafReader {@link LeafReader} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues(final FieldInfo fieldInfo, final LeafReader leafReader) throws IOException { + final DocIdSetIterator docIdSetIterator; + if (fieldInfo.hasVectorValues()) { + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + docIdSetIterator = leafReader.getByteVectorValues(fieldInfo.getName()); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + docIdSetIterator = leafReader.getFloatVectorValues(fieldInfo.getName()); + } else { + throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); + } + } else { + docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName()); + } + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + @SuppressWarnings("unchecked") private static KNNVectorValues getVectorValues( final VectorDataType vectorDataType, diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 4c6230a46..0f5a049fc 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -11,6 +11,10 @@ package org.opensearch.knn.indices; +import org.apache.commons.lang.StringUtils; + +import java.util.Locale; + /** * A utility class for models. */ @@ -33,4 +37,21 @@ public static boolean isModelCreated(ModelMetadata modelMetadata) { return modelMetadata.getState().equals(ModelState.CREATED); } + /** + * Gets Model Metadata from a given model id. + * @param modelId {@link String} + * @return {@link ModelMetadata} + */ + public static ModelMetadata getModelMetadata(final String modelId) { + if (StringUtils.isEmpty(modelId)) { + return null; + } + final Model model = ModelCache.getInstance().get(modelId); + final ModelMetadata modelMetadata = model.getModelMetadata(); + if (ModelUtil.isModelCreated(modelMetadata) == false) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } + } diff --git a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java new file mode 100644 index 000000000..e86a153d3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.apache.lucene.index.FieldInfo; +import org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; + +public class FieldInfoExtractorTests extends KNNTestCase { + + private static final String MODEL_ID = "model_id"; + + public void testExtractVectorDataType_whenDifferentConditions_thenSuccess() { + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + MockedStatic modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class); + + // default case + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null); + Mockito.when(fieldInfo.getAttribute(KNNConstants.MODEL_ID)).thenReturn(MODEL_ID); + modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(null); + Assert.assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + + // VectorDataType present in fieldInfo + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BINARY.getValue()); + Assert.assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + + // VectorDataType present in ModelMetadata + ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class); + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(null); + modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata); + Mockito.when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE); + Assert.assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + + modelUtilMockedStatic.close(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index c5abc964d..15402a148 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -33,6 +33,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.opensearch.common.io.PathUtils; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; @@ -46,6 +47,9 @@ import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -702,7 +706,8 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); - final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + final KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + final KNNBinaryVectorValues binaryVectorValues = mock(KNNBinaryVectorValues.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); @@ -712,14 +717,15 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.getValue()); } when(fieldInfo.getName()).thenReturn(FIELD_NAME); - when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); - when(binaryDocValues.advance(filterDocId)).thenReturn(filterDocId); - BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector)); - + MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class); if (isBinary) { - when(binaryDocValues.binaryValue()).thenReturn(new BytesRef(byteVector)); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(binaryVectorValues); + when(binaryVectorValues.advance(filterDocId)).thenReturn(filterDocId); + Mockito.when(binaryVectorValues.getVector()).thenReturn(byteVector); } else { - when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.advance(filterDocId)).thenReturn(filterDocId); + Mockito.when(floatVectorValues.getVector()).thenReturn(vector); } final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); @@ -739,6 +745,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + valuesFactoryMockedStatic.close(); } @SneakyThrows @@ -909,16 +916,18 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryInd ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); - final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING.getValue()); when(fieldInfo.getName()).thenReturn(FIELD_NAME); - when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); - when(binaryDocValues.advance(0)).thenReturn(0); - BytesRef vectorByteRef = new BytesRef(vector); - when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + + KNNBinaryVectorValues knnBinaryVectorValues = mock(KNNBinaryVectorValues.class); + MockedStatic vectorValuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class); + vectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)) + .thenReturn(knnBinaryVectorValues); + when(knnBinaryVectorValues.advance(0)).thenReturn(0); + when(knnBinaryVectorValues.getVector()).thenReturn(vector); final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); @@ -933,6 +942,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryInd } assertEquals(docIdSetIterator.cost(), actualDocIds.size()); assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + vectorValuesFactoryMockedStatic.close(); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java index 7583f50bc..c52798c05 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java @@ -7,11 +7,10 @@ import junit.framework.TestCase; import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import java.util.Arrays; import java.util.List; @@ -31,9 +30,8 @@ public void testNextDoc_whenCalled_IterateAllDocs() { .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); - BinaryDocValues values = mock(BinaryDocValues.class); - final List byteRefs = dataVectors.stream().map(vector -> new BytesRef(vector)).collect(Collectors.toList()); - when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java index cf8582a05..731eed2cc 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java @@ -6,13 +6,11 @@ package org.opensearch.knn.index.query.filtered; import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.util.Arrays; import java.util.List; @@ -36,11 +34,8 @@ public void testNextDoc_whenCalled_IterateAllDocs() { .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); - BinaryDocValues values = mock(BinaryDocValues.class); - final List byteRefs = dataVectors.stream() - .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) - .collect(Collectors.toList()); - when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + KNNFloatVectorValues values = mock(KNNFloatVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java index c4b7859d0..1940ffe12 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java @@ -7,12 +7,11 @@ import junit.framework.TestCase; import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; import java.util.Arrays; import java.util.List; @@ -36,9 +35,8 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); - BinaryDocValues values = mock(BinaryDocValues.class); - final List byteRefs = dataVectors.stream().map(vector -> new BytesRef(vector)).collect(Collectors.toList()); - when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java index 508b0d3d6..cca789a4d 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java @@ -7,13 +7,11 @@ import junit.framework.TestCase; import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.util.Arrays; import java.util.List; @@ -41,11 +39,12 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) .collect(Collectors.toList()); - BinaryDocValues values = mock(BinaryDocValues.class); - final List byteRefs = dataVectors.stream() - .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) - .collect(Collectors.toList()); - when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + KNNFloatVectorValues values = mock(KNNFloatVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + // final List byteRefs = dataVectors.stream() + // .map(vector -> new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector))) + // .collect(Collectors.toList()); + // when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java index 9827cb03b..a717aa2c2 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactoryTests.java @@ -5,12 +5,19 @@ package org.opensearch.knn.index.vectorvalues; +import lombok.SneakyThrows; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.VectorEncoding; import org.junit.Assert; +import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import java.util.List; import java.util.Map; public class KNNVectorValuesFactoryTests extends KNNTestCase { @@ -58,4 +65,88 @@ public void testGetVectorValuesUsingDocWithFieldSet_whenValidInput_thenSuccess() Assert.assertNotNull(binaryVectorValues); } + @SneakyThrows + public void testGetVectorValuesFromFieldInfo_whenVectorDimIsNotZero_thenSuccess() { + final List byteArrayList = List.of(new byte[] { 1, 2, 3 }); + final List floatArrayList = List.of(new float[] { 1.3f, 2.2f, 3.2f }); + final List binaryArrayList = List.of(new byte[] { 3, 2, 3 }); + final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + final SegmentReader reader = Mockito.mock(SegmentReader.class); + Mockito.when(fieldInfo.hasVectorValues()).thenReturn(true); + Mockito.when(fieldInfo.getName()).thenReturn("test_field"); + + // Checking for ByteVectorValues + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BYTE.getValue()); + Mockito.when(reader.getByteVectorValues("test_field")).thenReturn(new TestVectorValues.PreDefinedByteVectorValues(byteArrayList)); + final KNNVectorValues byteVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + byteVectorValues.nextDoc(); + Assert.assertArrayEquals(byteArrayList.get(0), byteVectorValues.getVector()); + Assert.assertNotNull(byteVectorValues); + + // Checking for FloatVectorValues + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.FLOAT.getValue()); + Mockito.when(reader.getFloatVectorValues("test_field")) + .thenReturn(new TestVectorValues.PreDefinedFloatVectorValues(floatArrayList)); + final KNNVectorValues floatVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + floatVectorValues.nextDoc(); + Assert.assertArrayEquals(floatArrayList.get(0), floatVectorValues.getVector(), 0.0f); + Assert.assertNotNull(floatVectorValues); + + // Checking for BinaryVectorValues + Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BINARY.getValue()); + Mockito.when(reader.getByteVectorValues("test_field")) + .thenReturn(new TestVectorValues.PreDefinedBinaryVectorValues(binaryArrayList)); + final KNNVectorValues binaryVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + binaryVectorValues.nextDoc(); + Assert.assertArrayEquals(binaryArrayList.get(0), binaryVectorValues.getVector()); + Assert.assertNotNull(binaryVectorValues); + + } + + @SneakyThrows + public void testGetVectorValuesFromFieldInfo_whenVectorDimIsZero_thenSuccess() { + final List byteArrayList = List.of(new byte[] { 1, 2, 3 }); + final List floatArrayList = List.of(new float[] { 1.3f, 2.2f, 3.2f }); + final List binaryArrayList = List.of(new byte[] { 3, 2, 3 }); + final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + final SegmentReader reader = Mockito.mock(SegmentReader.class); + Mockito.when(fieldInfo.hasVectorValues()).thenReturn(false); + Mockito.when(fieldInfo.getName()).thenReturn("test_field"); + + // Checking for ByteVectorValues + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BYTE.getValue()); + Mockito.when(reader.getBinaryDocValues("test_field")) + .thenReturn(new TestVectorValues.PredefinedByteVectorBinaryDocValues(byteArrayList)); + + final KNNVectorValues byteVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + byteVectorValues.nextDoc(); + Assert.assertArrayEquals(byteArrayList.get(0), byteVectorValues.getVector()); + Assert.assertNotNull(byteVectorValues); + + // Checking for Floats with BinaryDocValues + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.FLOAT.getValue()); + Mockito.when(reader.getBinaryDocValues("test_field")) + .thenReturn(new TestVectorValues.PredefinedFloatVectorBinaryDocValues(floatArrayList)); + + final KNNVectorValues floatVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + floatVectorValues.nextDoc(); + Assert.assertArrayEquals(floatArrayList.get(0), floatVectorValues.getVector(), 0.0f); + Assert.assertNotNull(floatVectorValues); + + // Checking for BinaryVectorValues + Mockito.when(fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD)).thenReturn(VectorDataType.BINARY.getValue()); + Mockito.when(reader.getBinaryDocValues("test_field")) + .thenReturn(new TestVectorValues.PredefinedByteVectorBinaryDocValues(binaryArrayList)); + + final KNNVectorValues binaryVectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + binaryVectorValues.nextDoc(); + Assert.assertArrayEquals(binaryArrayList.get(0), binaryVectorValues.getVector()); + Assert.assertNotNull(binaryVectorValues); + + Mockito.verify(fieldInfo, Mockito.times(0)).getVectorEncoding(); + } + } diff --git a/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java b/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java new file mode 100644 index 000000000..edefd10ee --- /dev/null +++ b/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.indices; + +import org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; + +public class ModelUtilTests extends KNNTestCase { + private static final String MODEL_ID = "test-model"; + + public void testGetModelMetadata_whenVariousInputs_thenSuccess() { + Assert.assertNull(ModelUtil.getModelMetadata(null)); + Assert.assertNull(ModelUtil.getModelMetadata("")); + + ModelCache modelCache = Mockito.mock(ModelCache.class); + Model model = Mockito.mock(Model.class); + ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class); + MockedStatic modelCacheMockedStatic = Mockito.mockStatic(ModelCache.class); + + modelCacheMockedStatic.when(ModelCache::getInstance).thenReturn(modelCache); + + Mockito.when(modelCache.get(MODEL_ID)).thenReturn(model); + Mockito.when(model.getModelMetadata()).thenReturn(null); + Assert.assertThrows(IllegalArgumentException.class, () -> ModelUtil.getModelMetadata(MODEL_ID)); + + Mockito.when(model.getModelMetadata()).thenReturn(modelMetadata); + Mockito.when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + Assert.assertNotNull(ModelUtil.getModelMetadata(MODEL_ID)); + modelCacheMockedStatic.close(); + } +} From 3557d79ab940cfd2d8439a86d3189e0817d63e33 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Wed, 14 Aug 2024 11:19:55 -0700 Subject: [PATCH 2/2] BackPort Java Doc Fix with Code Improvements (#1959) --- .../factory/QuantizerFactory.java | 35 ++++++++++--------- .../factory/QuantizerRegistry.java | 9 +++-- .../ScalarQuantizationParams.java | 21 ++++++++--- .../MultiBitScalarQuantizationState.java | 12 +++---- .../factory/QuantizerFactoryTests.java | 12 +++---- .../factory/QuantizerRegistryTests.java | 26 +++++++------- 6 files changed, 63 insertions(+), 52 deletions(-) diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java index b99f6ebdc..6705c7688 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -21,34 +21,35 @@ public final class QuantizerFactory { private static final AtomicBoolean isRegistered = new AtomicBoolean(false); - /** - * Ensures that default quantizers are registered. - */ - private static void ensureRegistered() { - if (!isRegistered.get()) { - synchronized (QuantizerFactory.class) { - if (!isRegistered.get()) { - QuantizerRegistrar.registerDefaultQuantizers(); - isRegistered.set(true); - } - } - } - } - /** * Retrieves a quantizer instance based on the provided quantization parameters. * * @param params the quantization parameters used to determine the appropriate quantizer * @param

the type of quantization parameters, extending {@link QuantizationParams} - * @param the type of the quantized output + * @param the type of the input vector to be quantized + * @param the type of the output after quantization * @return an instance of {@link Quantizer} corresponding to the provided parameters */ - public static

Quantizer getQuantizer(final P params) { + public static

Quantizer getQuantizer(final P params) { if (params == null) { throw new IllegalArgumentException("Quantization parameters must not be null."); } // Lazy Registration instead of static block as class level; ensureRegistered(); - return QuantizerRegistry.getQuantizer(params); + return (Quantizer) QuantizerRegistry.getQuantizer(params); + } + + /** + * Ensures that default quantizers are registered. + */ + private static void ensureRegistered() { + if (!isRegistered.get()) { + synchronized (QuantizerFactory.class) { + if (!isRegistered.get()) { + QuantizerRegistrar.registerDefaultQuantizers(); + isRegistered.set(true); + } + } + } } } diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java index ac266f547..2da830f3b 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -42,18 +42,17 @@ static void register(final String paramIdentifier, final Quantizer quantiz * * @param params the quantization parameters used to determine the appropriate quantizer * @param

the type of quantization parameters - * @param the type of the quantized output + * @param the type of the input vector to be quantized + * @param the type of the output after quantization * @return an instance of {@link Quantizer} corresponding to the provided parameters * @throws IllegalArgumentException if no quantizer is registered for the given parameters */ - static

Quantizer getQuantizer(final P params) { + static

Quantizer getQuantizer(final P params) { String identifier = params.getTypeIdentifier(); Quantizer quantizer = registry.get(identifier); if (quantizer == null) { throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier); } - @SuppressWarnings("unchecked") - Quantizer typedQuantizer = (Quantizer) quantizer; - return typedQuantizer; + return (Quantizer) quantizer; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java index 4e7a53892..881c2132d 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java @@ -47,10 +47,6 @@ public String getTypeIdentifier() { return generateIdentifier(sqType.getId()); } - private static String generateIdentifier(int id) { - return "ScalarQuantizationParams_" + id; - } - /** * Writes the object to the output stream. * This method is part of the Writeable interface and is used to serialize the object. @@ -74,4 +70,21 @@ public ScalarQuantizationParams(StreamInput in, int version) throws IOException int typeId = in.readVInt(); this.sqType = ScalarQuantizationType.fromId(typeId); } + + /** + * Generates a unique identifier for Scalar Quantization Parameters. + * + *

+ * This method constructs an identifier string by prefixing the given integer ID + * with "ScalarQuantizationParams_". The resulting string can be used to uniquely + * identify specific quantization parameter instances, especially when registering + * or retrieving them in a registry or similar structure. + *

+ * + * @param id the integer ID to be used in generating the unique identifier. + * @return a string representing the unique identifier for the quantization parameters. + */ + private static String generateIdentifier(int id) { + return "ScalarQuantizationParams_" + id; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 09092fde8..2778a6cf4 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -30,13 +30,13 @@ public final class MultiBitScalarQuantizationState implements QuantizationState * * For example: * - For 2-bit quantization: - * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level - * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level + * thresholds[0] {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level + * thresholds[1] {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level * - For 4-bit quantization: - * thresholds[0] -> {0.1f, 0.2f, 0.3f} - * thresholds[1] -> {0.4f, 0.5f, 0.6f} - * thresholds[2] -> {0.7f, 0.8f, 0.9f} - * thresholds[3] -> {1.0f, 1.1f, 1.2f} + * thresholds[0] {0.1f, 0.2f, 0.3f} + * thresholds[1] {0.4f, 0.5f, 0.6f} + * thresholds[2] {0.7f, 0.8f, 0.9f} + * thresholds[3] {1.0f, 1.1f, 1.2f} * * Each column represents the threshold for a specific dimension in the vector space. */ diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index 3474b7ec9..b95123e21 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -31,12 +31,12 @@ public void test_Lazy_Registration() { ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); assertFalse(isRegisteredFieldAccessible()); - Quantizer quantizer = QuantizerFactory.getQuantizer(params); - Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); - Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); - assertTrue(quantizerFourBit instanceof MultiBitScalarQuantizer); - assertTrue(quantizerTwoBit instanceof MultiBitScalarQuantizer); - assertTrue(quantizer instanceof OneBitScalarQuantizer); + Quantizer oneBitQuantizer = QuantizerFactory.getQuantizer(params); + Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); + Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); + assertEquals(quantizerFourBit.getClass(), MultiBitScalarQuantizer.class); + assertEquals(quantizerTwoBit.getClass(), MultiBitScalarQuantizer.class); + assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class); assertTrue(isRegisteredFieldAccessible()); } diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index dec34e632..62d31ab61 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -34,39 +34,37 @@ public static void setup() { public void testRegisterAndGetQuantizer() { // Test for OneBitScalarQuantizer ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); - assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer); + Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class); // Test for MultiBitScalarQuantizer (2-bit) ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); - assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer); - assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate()); + Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertEquals(twoBitQuantizer.getClass(), MultiBitScalarQuantizer.class); // Test for MultiBitScalarQuantizer (4-bit) ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); - assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer); - assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate()); + Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertEquals(fourBitQuantizer.getClass(), MultiBitScalarQuantizer.class); } public void testQuantizerRegistryIsSingleton() { // Ensure the same instance is returned for the same type identifier ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - Quantizer firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); - Quantizer secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + Quantizer firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + Quantizer secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); assertSame(firstOneBitQuantizer, secondOneBitQuantizer); // Ensure the same instance is returned for the same type identifier (2-bit) ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - Quantizer firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); - Quantizer secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + Quantizer firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + Quantizer secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer); // Ensure the same instance is returned for the same type identifier (4-bit) ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - Quantizer firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); - Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + Quantizer firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertSame(firstFourBitQuantizer, secondFourBitQuantizer); }