Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate KNNVectorValues with vector ANN Search flow #1952

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.16...2.x)
### Features
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
### Enhancements
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945)
* Disallow a vector field to have an invalid character for a physical file name. [#1936] (https://github.com/opensearch-project/k-NN/pull/1936)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
### Infrastructure
### Documentation
### Maintenance
* Fix a flaky unit test:testMultiFieldsKnnIndex, which was failing due to inconsistent merge behaviors [#1924](https://github.com/opensearch-project/k-NN/pull/1924)
### Refactoring
* Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897)
* Integrate KNNVectorValues with vector ANN Search flow [#1952](https://github.com/opensearch-project/k-NN/pull/1952)
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
Expand Down
37 changes: 37 additions & 0 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import lombok.experimental.UtilityClass;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

/**
* A utility class to extract information from FieldInfo.
*/
@UtilityClass
public class FieldInfoExtractor {

/**
* Extract vector data type from fieldInfo
* @param fieldInfo {@link FieldInfo}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
if (StringUtils.isEmpty(vectorDataTypeString)) {
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata != null) {
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
}
}
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
}
}
22 changes: 15 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
Expand Down Expand Up @@ -43,6 +41,10 @@
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -412,25 +414,31 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
private KNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = getSpaceType(fieldInfo);
if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader());
return knnQuery.getParentsFilter() == null
? new FilteredIdsKNNByteIterator(filterIdsBitSet, knnQuery.getByteQueryVector(), values, spaceType)
? new FilteredIdsKNNByteIterator(
filterIdsBitSet,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
)
: new NestedFilteredIdsKNNByteIterator(
filterIdsBitSet,
knnQuery.getByteQueryVector(),
values,
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, leafReaderContext.reader());
return knnQuery.getParentsFilter() == null
? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType)
? new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType)
: new NestedFilteredIdsKNNIterator(
filterIdsBitSet,
knnQuery.getQueryVector(),
values,
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
Expand All @@ -26,21 +24,21 @@ public class FilteredIdsKNNByteIterator implements KNNIterator {
protected final BitSet filterIdsBitSet;
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final BinaryDocValues binaryDocValues;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNByteIterator(
final BitSet filterIdsBitSet,
final byte[] queryVector,
final BinaryDocValues binaryDocValues,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryDocValues = binaryDocValues;
this.binaryVectorValues = binaryVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
}
Expand All @@ -57,7 +55,7 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryDocValues.advance(docId);
int doc = binaryVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
Expand All @@ -69,9 +67,7 @@ public float score() {
}

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final byte[] vector = byteStream.readAllBytes();
final byte[] vector = binaryVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

import java.io.IOException;

Expand All @@ -27,21 +24,21 @@ public class FilteredIdsKNNIterator implements KNNIterator {
protected final BitSet filterIdsBitSet;
protected final BitSetIterator bitSetIterator;
protected final float[] queryVector;
protected final BinaryDocValues binaryDocValues;
protected final KNNFloatVectorValues knnFloatVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNIterator(
final BitSet filterIdsBitSet,
final float[] queryVector,
final BinaryDocValues binaryDocValues,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryDocValues = binaryDocValues;
this.knnFloatVectorValues = knnFloatVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
}
Expand All @@ -58,7 +55,7 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryDocValues.advance(docId);
int doc = knnFloatVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
Expand All @@ -70,9 +67,7 @@ public float score() {
}

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(value);
final float[] vector = vectorSerializer.byteToFloatArray(value);
final float[] vector = knnFloatVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

Expand All @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator
public NestedFilteredIdsKNNByteIterator(
final BitSet filterIdsArray,
final byte[] queryVector,
final BinaryDocValues values,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

Expand All @@ -47,7 +47,7 @@ public int nextDoc() throws IOException {
int bestChild = -1;

while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryDocValues.advance(docId);
binaryVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

import java.io.IOException;

Expand All @@ -22,11 +22,11 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator {
public NestedFilteredIdsKNNIterator(
final BitSet filterIdsArray,
final float[] queryVector,
final BinaryDocValues values,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, values, spaceType);
super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

Expand All @@ -47,7 +47,7 @@ public int nextDoc() throws IOException {
int bestChild = -1;

while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryDocValues.advance(docId);
knnFloatVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;

import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -21,7 +27,7 @@ public final class KNNVectorValuesFactory {
*
* @param vectorDataType {@link VectorDataType}
* @param docIdSetIterator {@link DocIdSetIterator}
* @return {@link KNNVectorValues} of type float[]
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) {
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator));
Expand All @@ -32,7 +38,7 @@ public static <T> KNNVectorValues<T> getVectorValues(final VectorDataType vector
*
* @param vectorDataType {@link VectorDataType}
* @param docIdWithFieldSet {@link DocsWithFieldSet}
* @return {@link KNNVectorValues} of type float[]
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
Expand All @@ -42,6 +48,30 @@ public static <T> KNNVectorValues<T> getVectorValues(
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<T>(docIdWithFieldSet, vectors));
}

/**
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}
*
* @param fieldInfo {@link FieldInfo}
* @param leafReader {@link LeafReader}
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo, final LeafReader leafReader) throws IOException {
final DocIdSetIterator docIdSetIterator;
if (fieldInfo.hasVectorValues()) {
if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) {
docIdSetIterator = leafReader.getByteVectorValues(fieldInfo.getName());
} else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
docIdSetIterator = leafReader.getFloatVectorValues(fieldInfo.getName());
} else {
throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues");
}
} else {
docIdSetIterator = DocValues.getBinary(leafReader, fieldInfo.getName());
}
final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator);
return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator);
}

@SuppressWarnings("unchecked")
private static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
Expand Down
Loading
Loading