diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index a370197ec..e629a1aa3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -9,11 +9,22 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CompoundFormat; import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.StoredFieldsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.opensearch.knn.index.codec.KNN990Codec.SyntheticSourceStoredFieldsFormat; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.io.IOException; +import java.util.function.Function; /** * KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12 @@ -58,4 +69,24 @@ public CompoundFormat compoundFormat() { public KnnVectorsFormat knnVectorsFormat() { return perFieldKnnVectorsFormat; } + + @Override + public StoredFieldsFormat storedFieldsFormat() { + Function>> vectorValuesSupplierByField = (segmentReadState) -> { + try { + KnnVectorsReader knnVectorsReader = knnVectorsFormat().fieldsReader(segmentReadState); + DocValuesProducer docValuesProducer = docValuesFormat().fieldsProducer(segmentReadState); + return fieldInfo -> { + try { + return KNNVectorValuesFactory.getVectorValues(fieldInfo, docValuesProducer, knnVectorsReader); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + return new SyntheticSourceStoredFieldsFormat(delegate.storedFieldsFormat(), vectorValuesSupplierByField); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldVisitor.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldVisitor.java new file mode 100644 index 000000000..4b68c3bdf --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldVisitor.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.StoredFieldVisitor; + +import java.io.IOException; +import java.util.function.Function; + +@AllArgsConstructor +public class SyntheticSourceStoredFieldVisitor extends StoredFieldVisitor { + + private final StoredFieldVisitor delegate; + private final Function sourceModifier; + + @Override + public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { + if (fieldInfo.name.equals("_source")) { + delegate.binaryField(fieldInfo, sourceModifier.apply(value)); + return; + } + delegate.binaryField(fieldInfo, value); + } + + @Override + public Status needsField(FieldInfo fieldInfo) throws IOException { + return delegate.needsField(fieldInfo); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsFormat.java new file mode 100644 index 000000000..0ccacc157 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsFormat.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsFormat; +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.codecs.StoredFieldsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +@AllArgsConstructor +public class SyntheticSourceStoredFieldsFormat extends StoredFieldsFormat { + + private final StoredFieldsFormat delegate; + private final Function>> vectorValuesSupplierByFieldSupplier; + + @Override + public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) + throws IOException { + // If any field has this value set, than get its supplier + Function> vectorValuesSupplierByField = vectorValuesSupplierByFieldSupplier.apply( + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + ); + Map>> vectorValuesSuppliers = new HashMap<>(); + for (FieldInfo fieldInfo : fieldInfos) { + if (Boolean.parseBoolean(fieldInfo.attributes().get(KNNVectorFieldMapper.KNN_FIELD))) { + vectorValuesSuppliers.put(fieldInfo.name, () -> vectorValuesSupplierByField.apply(fieldInfo)); + } + } + + // Build the processor and create the reader + SyntheticVectorInjectionConsumer syntheticVectorInjectionConsumer = new SyntheticVectorInjectionConsumer(vectorValuesSuppliers); + return new SyntheticSourceStoredFieldsReader( + delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), + syntheticVectorInjectionConsumer + ); + } + + @Override + public StoredFieldsWriter fieldsWriter(Directory directory, SegmentInfo segmentInfo, IOContext ioContext) throws IOException { + return delegate.fieldsWriter(directory, segmentInfo, ioContext); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsReader.java new file mode 100644 index 000000000..450979948 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticSourceStoredFieldsReader.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.StoredFieldVisitor; + +import java.io.IOException; +import java.util.function.BiFunction; + +@AllArgsConstructor +public class SyntheticSourceStoredFieldsReader extends StoredFieldsReader { + private final StoredFieldsReader delegate; + // Given docId and source, process source + private final BiFunction sourceModifier; + + @Override + public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException { + delegate.document(docId, new SyntheticSourceStoredFieldVisitor(storedFieldVisitor, bytes -> sourceModifier.apply(docId, bytes))); + } + + @Override + public StoredFieldsReader clone() { + return new SyntheticSourceStoredFieldsReader(delegate.clone(), sourceModifier); + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumer.java new file mode 100644 index 000000000..53f20d437 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumer.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +@Log4j2 +@AllArgsConstructor +public class SyntheticVectorInjectionConsumer implements BiFunction { + + private final Map>> vectorValuesSuppliers; + + @Override + public byte[] apply(Integer integer, byte[] bytes) { + try { + Tuple> mapTuple = XContentHelper.convertToMap( + BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)), + true, + MediaTypeRegistry.getDefaultMediaType() + ); + Map sourceAsMap = new HashMap<>(mapTuple.v2()); + for (Map.Entry>> entry : vectorValuesSuppliers.entrySet()) { + log.info("Injecting vector values for field: " + entry.getKey()); + KNNVectorValues vectorValues = entry.getValue().get(); + vectorValues.advance(integer); + sourceAsMap.put(entry.getKey(), vectorValues.getVector()); + } + BytesStreamOutput bStream = new BytesStreamOutput(1024); + MediaType actualContentType = mapTuple.v1(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(actualContentType, bStream).map(sourceAsMap); + builder.close(); + log.info("Built the following source: " + builder); + return BytesReference.toBytes(BytesReference.bytes(builder)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 41408e217..699d62843 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -72,6 +74,35 @@ public static KNNVectorValues getVectorValues(final FieldInfo fieldInfo, return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); } + /** + * Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader} + * + * @param fieldInfo {@link FieldInfo} + * @param docValuesProducer {@link DocValuesProducer} + * @param knnVectorsReader {@link KnnVectorsReader} + * @return {@link KNNVectorValues} + */ + public static KNNVectorValues getVectorValues( + final FieldInfo fieldInfo, + final DocValuesProducer docValuesProducer, + final KnnVectorsReader knnVectorsReader + ) throws IOException { + final DocIdSetIterator docIdSetIterator; + if (fieldInfo.hasVectorValues()) { + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + docIdSetIterator = knnVectorsReader.getByteVectorValues(fieldInfo.getName()); + } else if (fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32) { + docIdSetIterator = knnVectorsReader.getFloatVectorValues(fieldInfo.getName()); + } else { + throw new IllegalArgumentException("Invalid Vector encoding provided, hence cannot return VectorValues"); + } + } else { + docIdSetIterator = docValuesProducer.getBinary(fieldInfo); + } + final KNNVectorValuesIterator vectorValuesIterator = new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator); + return getVectorValues(FieldInfoExtractor.extractVectorDataType(fieldInfo), vectorValuesIterator); + } + @SuppressWarnings("unchecked") private static KNNVectorValues getVectorValues( final VectorDataType vectorDataType, diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumerTests.java new file mode 100644 index 000000000..2f4f731a4 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/SyntheticVectorInjectionConsumerTests.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FloatVectorValues; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +public class SyntheticVectorInjectionConsumerTests extends KNNTestCase { + + @SneakyThrows + public void testVectorInjection() { + FloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + List.of(new float[] { 1.0f, 2.0f }, new float[] { 2.0f, 3.0f }, new float[] { 3.0f, 4.0f }, new float[] { 4.0f, 5.0f }) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + final XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("test_text", "text-field"); + builder.endObject(); + + BytesReference bytesReference = BytesReference.bytes(builder); + toMap(bytesReference); + + SyntheticVectorInjectionConsumer consumer = new SyntheticVectorInjectionConsumer(Map.of("test_vector", () -> knnVectorValues)); + logger.info(bytesReference.length()); + byte[] modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); + BytesReference modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + toMap(modifiedBytesReference); + + modifiedBytes = consumer.apply(1, bytesReference.toBytesRef().bytes); + modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + toMap(modifiedBytesReference); + + modifiedBytes = consumer.apply(0, bytesReference.toBytesRef().bytes); + modifiedBytesReference = BytesReference.fromByteBuffer(ByteBuffer.wrap(modifiedBytes)); + toMap(modifiedBytesReference); + + fail("On purpose"); + } + + private void toMap(BytesReference source) { + Tuple> mapTuple = XContentHelper.convertToMap(source, true, MediaTypeRegistry.JSON); + logger.info(mapTuple.v2().toString()); + } + +}