Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert 'Support script score when doc value is disabled' #1662

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support filter and nested field in faiss engine radial search [#1652](https://github.com/opensearch-project/k-NN/pull/1652)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613)
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -40,29 +39,10 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
}
}

Expand Down
108 changes: 11 additions & 97 deletions src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@
package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final DocIdSetIterator vectorValues;
private final BinaryDocValues binaryDocValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
}
docExists = false;
}

public float[] getValue() {
Expand All @@ -43,14 +44,12 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return doGetValue();
return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -60,89 +59,4 @@ public int size() {
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
* Creates a KNNVectorScriptDocValues object based on the provided parameters.
*
* @param values The DocIdSetIterator representing the vector values.
* @param fieldName The name of the field.
* @param vectorDataType The data type of the vector.
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return values.vectorValue();
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}

/**
* Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type.
*
* @param fieldName The name of the field.
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,7 @@

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand Down Expand Up @@ -41,39 +33,26 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
createKNNVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
}

private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
private void createKNNVectorDocument(Directory directory) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
knnDocument.add(
new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
)
);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -105,18 +84,4 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

public void testUnsupportedValues() throws IOException {
expectThrows(
IllegalArgumentException.class,
() -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
);
}

public void testEmptyValues() throws IOException {
KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
assertEquals(0, values.size());
scriptDocValues.setNextDocId(0);
assertEquals(0, values.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
createKNNFloatVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return KNNVectorScriptDocValues.create(
return new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME,
VectorDataType.FLOAT
Expand All @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() {
createKNNByteVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return KNNVectorScriptDocValues.create(
return new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME,
VectorDataType.BYTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx
if (scriptDocValues == null) {
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = KNNVectorScriptDocValues.create(
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(fieldName),
fieldName,
VectorDataType.FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,16 +612,7 @@ private List<String> createMappers(int dimensions) throws Exception {
dimensions,
KNNConstants.METHOD_HNSW,
KNNEngine.LUCENE.getName(),
SpaceType.DEFAULT.getValue(),
true
),
createKnnIndexMapping(
FIELD_NAME,
dimensions,
KNNConstants.METHOD_HNSW,
KNNEngine.LUCENE.getName(),
SpaceType.DEFAULT.getValue(),
false
SpaceType.DEFAULT.getValue()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,7 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception {
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
properties.add(
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")
.knnMethodContext(knnMethodContext)
.docValues(randomBoolean())
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext)
);

String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Expand Down
Loading