Skip to content

Commit

Permalink
Synthetic vector source via stored fields
Browse files Browse the repository at this point in the history
Synthetically generates the vector source in the source field from the
KnnVectorsFormat or BVD. It does this by adding StoredFieldsFormat to
our existing custom codec.

Work is still WIP but confirmed reindex and search work if field is
excluded.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Oct 25, 2024
1 parent f3f8e25 commit ac5e3f8
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.opensearch.knn.index.codec.KNN990Codec.SyntheticSourceStoredFieldsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;

import java.io.IOException;
import java.util.function.Function;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12
Expand Down Expand Up @@ -58,4 +69,24 @@ public CompoundFormat compoundFormat() {
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}

@Override
public StoredFieldsFormat storedFieldsFormat() {
Function<SegmentReadState, Function<FieldInfo, KNNVectorValues<?>>> vectorValuesSupplierByField = (segmentReadState) -> {
try {
KnnVectorsReader knnVectorsReader = knnVectorsFormat().fieldsReader(segmentReadState);
DocValuesProducer docValuesProducer = docValuesFormat().fieldsProducer(segmentReadState);
return fieldInfo -> {
try {
return KNNVectorValuesFactory.getVectorValues(fieldInfo, docValuesProducer, knnVectorsReader);
} catch (IOException e) {
throw new RuntimeException(e);
}
};
} catch (IOException e) {
throw new RuntimeException(e);
}
};
return new SyntheticSourceStoredFieldsFormat(delegate.storedFieldsFormat(), vectorValuesSupplierByField);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.StoredFieldVisitor;

import java.io.IOException;
import java.util.function.Function;

@AllArgsConstructor
public class SyntheticSourceStoredFieldVisitor extends StoredFieldVisitor {

private final StoredFieldVisitor delegate;
private final Function<byte[], byte[]> sourceModifier;

@Override
public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException {
if (fieldInfo.name.equals("_source")) {
delegate.binaryField(fieldInfo, sourceModifier.apply(value));
return;
}
delegate.binaryField(fieldInfo, value);
}

@Override
public Status needsField(FieldInfo fieldInfo) throws IOException {
return delegate.needsField(fieldInfo);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.StoredFieldsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

@AllArgsConstructor
public class SyntheticSourceStoredFieldsFormat extends StoredFieldsFormat {

private final StoredFieldsFormat delegate;
private final Function<SegmentReadState, Function<FieldInfo, KNNVectorValues<?>>> vectorValuesSupplierByFieldSupplier;

@Override
public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext)
throws IOException {
// If any field has this value set, than get its supplier
Function<FieldInfo, KNNVectorValues<?>> vectorValuesSupplierByField = vectorValuesSupplierByFieldSupplier.apply(
new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext)
);
Map<String, Supplier<KNNVectorValues<?>>> vectorValuesSuppliers = new HashMap<>();
for (FieldInfo fieldInfo : fieldInfos) {
if (Boolean.parseBoolean(fieldInfo.attributes().get(KNNVectorFieldMapper.KNN_FIELD))) {
vectorValuesSuppliers.put(fieldInfo.name, () -> vectorValuesSupplierByField.apply(fieldInfo));
}
}

// Build the processor and create the reader
SyntheticVectorInjectionConsumer syntheticVectorInjectionConsumer = new SyntheticVectorInjectionConsumer(vectorValuesSuppliers);
return new SyntheticSourceStoredFieldsReader(
delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext),
syntheticVectorInjectionConsumer
);
}

@Override
public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentInfo, IOContext ioContext) throws IOException {
return delegate.fieldsWriter(directory, segmentInfo, ioContext);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.StoredFieldVisitor;

import java.io.IOException;
import java.util.function.BiFunction;

@AllArgsConstructor
public class SyntheticSourceStoredFieldsReader extends StoredFieldsReader {
private final StoredFieldsReader delegate;
// Given docId and source, process source
private final BiFunction<Integer, byte[], byte[]> sourceModifier;

@Override
public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException {
delegate.document(docId, new SyntheticSourceStoredFieldVisitor(storedFieldVisitor, bytes -> sourceModifier.apply(docId, bytes)));
}

@Override
public StoredFieldsReader clone() {
return new SyntheticSourceStoredFieldsReader(delegate.clone(), sourceModifier);
}

@Override
public void checkIntegrity() throws IOException {
delegate.checkIntegrity();
}

@Override
public void close() throws IOException {
delegate.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;

@Log4j2
@AllArgsConstructor
public class SyntheticVectorInjectionConsumer implements BiFunction<Integer, byte[], byte[]> {

private final Map<String, Supplier<KNNVectorValues<?>>> vectorValuesSuppliers;

@Override
public byte[] apply(Integer integer, byte[] bytes) {
try {
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)),
true,
MediaTypeRegistry.getDefaultMediaType()
);
Map<String, Object> sourceAsMap = new HashMap<>(mapTuple.v2());
for (Map.Entry<String, Supplier<KNNVectorValues<?>>> entry : vectorValuesSuppliers.entrySet()) {
log.info("Injecting vector values for field: " + entry.getKey());
KNNVectorValues<?> vectorValues = entry.getValue().get();
vectorValues.advance(integer);
sourceAsMap.put(entry.getKey(), vectorValues.getVector());
}
BytesStreamOutput bStream = new BytesStreamOutput(1024);
MediaType actualContentType = mapTuple.v1();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap);
builder.close();
log.info("Built the following source: " + builder);
return BytesReference.toBytes(BytesReference.bytes(builder));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand Down Expand Up @@ -72,6 +74,35 @@ public static <T> KNNVectorValues<T> getVectorValues(final FieldInfo fieldInfo,
return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator);
}

/**
* Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader}
*
* @param fieldInfo {@link FieldInfo}
* @param docValuesProducer {@link DocValuesProducer}
* @param knnVectorsReader {@link KnnVectorsReader}
* @return {@link KNNVectorValues}
*/
public static <T> KNNVectorValues<T> getVectorValues(
final FieldInfo fieldInfo,
final DocValuesProducer docValuesProducer,
final KnnVectorsReader knnVectorsReader
) throws IOException {
final DocIdSetIterator docIdSetIterator;
if (fieldInfo.hasVectorValues()) {
if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) {
docIdSetIterator = knnVectorsReader.getByteVectorValues(fieldInfo.getName());
} else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) {
docIdSetIterator = knnVectorsReader.getFloatVectorValues(fieldInfo.getName());
} else {
throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues");
}
} else {
docIdSetIterator = docValuesProducer.getBinary(fieldInfo);
}
final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator);
return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator);
}

@SuppressWarnings("unchecked")
private static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.SneakyThrows;
import org.apache.lucene.index.FloatVectorValues;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;

public class SyntheticVectorInjectionConsumerTests extends KNNTestCase {

@SneakyThrows
public void testVectorInjection() {
FloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
List.of(new float[] { 1.0f, 2.0f }, new float[] { 2.0f, 3.0f }, new float[] { 3.0f, 4.0f }, new float[] { 4.0f, 5.0f })
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues);

final XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
builder.field("test_text", "text-field");
builder.endObject();

BytesReference bytesReference = BytesReference.bytes(builder);
toMap(bytesReference);

SyntheticVectorInjectionConsumer consumer = new SyntheticVectorInjectionConsumer(Map.of("test_vector", () -> knnVectorValues));
logger.info(bytesReference.length());
byte[] modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes);
BytesReference modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes));
toMap(modifiedBytesReference);

modifiedBytes = consumer.apply(1, bytesReference.toBytesRef().bytes);
modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes));
toMap(modifiedBytesReference);

modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes);
modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes));
toMap(modifiedBytesReference);

fail("On purpose");
}

private void toMap(BytesReference source) {
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(source, true, MediaTypeRegistry.JSON);
logger.info(mapTuple.v2().toString());
}

}

0 comments on commit ac5e3f8

Please sign in to comment.