From 428fdb529117e107d5fa225d8ec23360e1225c02 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 10 Jul 2024 10:28:48 -0400 Subject: [PATCH] Reduce heap usage for knn index writers (#13538) * Reduce heap usage for knn index writers * iter * fixing heap usage & adding changes * javadocs --- lucene/CHANGES.txt | 2 + .../codecs/hnsw/FlatFieldVectorsWriter.java | 26 ++-- .../lucene/codecs/hnsw/FlatVectorsWriter.java | 12 +- .../lucene99/Lucene99FlatVectorsWriter.java | 58 +++++--- .../lucene99/Lucene99HnswVectorsWriter.java | 79 +++++++---- .../Lucene99ScalarQuantizedVectorsWriter.java | 132 ++++++++++-------- 6 files changed, 182 insertions(+), 127 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index b75eaa73ebf9..26a7c06e4833 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -280,6 +280,8 @@ Optimizations * GITHUB#13175: Stop double-checking priority queue inserts in some FacetCount classes. (Jakub Slowinski) +* GITHUB#13538: Slightly reduce heap usage for HNSW and scalar quantized vector writers. (Ben Trent) + Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java index 313ccccd4eb8..fc71bb729db6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java @@ -17,7 +17,10 @@ package org.apache.lucene.codecs.hnsw; +import java.io.IOException; +import java.util.List; import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; /** * Vectors' writer for a field @@ -26,20 +29,25 @@ * @lucene.experimental */ public abstract class FlatFieldVectorsWriter extends KnnFieldVectorsWriter { + /** + * @return a list of vectors to be written + */ + public abstract List getVectors(); /** - * The delegate to write to, can be null When non-null, all vectors seen should be written to the - * delegate along with being written to the flat vectors. + * @return the docsWithFieldSet for the field writer */ - protected final KnnFieldVectorsWriter indexingDelegate; + public abstract DocsWithFieldSet getDocsWithFieldSet(); /** - * Sole constructor that expects some indexingDelegate. All vectors seen should be written to the - * delegate along with being written to the flat vectors. + * indicates that this writer is done and no new vectors are allowed to be added * - * @param indexingDelegate the delegate to write to, can be null + * @throws IOException if an I/O error occurs + */ + public abstract void finish() throws IOException; + + /** + * @return true if the writer is done and no new vectors are allowed to be added */ - protected FlatFieldVectorsWriter(KnnFieldVectorsWriter indexingDelegate) { - this.indexingDelegate = indexingDelegate; - } + public abstract boolean isFinished(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java index 3a7803011aad..37c4f546bab9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java @@ -18,7 +18,6 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; -import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; @@ -46,21 +45,14 @@ public FlatVectorsScorer getFlatVectorScorer() { } /** - * Add a new field for indexing, allowing the user to provide a writer that the flat vectors - * writer can delegate to if additional indexing logic is required. + * Add a new field for indexing * * @param fieldInfo fieldInfo of the field to add - * @param indexWriter the writer to delegate to, can be null * @return a writer for the field * @throws IOException if an I/O error occurs when adding the field */ - public abstract FlatFieldVectorsWriter addField( - FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException; - @Override - public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - return addField(fieldInfo, null); - } + public abstract FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException; /** * Write the field for merging, providing a scorer over the newly merged flat vectors. This way diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index b80c9f4d7f5c..5643752796c2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -27,7 +27,6 @@ import java.util.ArrayList; import java.util.List; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; @@ -111,18 +110,12 @@ public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scor } @Override - public FlatFieldVectorsWriter addField( - FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException { - FieldWriter newField = FieldWriter.create(fieldInfo, indexWriter); + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo); fields.add(newField); return newField; } - @Override - public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - return addField(fieldInfo, null); - } - @Override public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { for (FieldWriter field : fields) { @@ -131,6 +124,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } else { writeSortingField(field, maxDoc, sortMap); } + field.finish(); } } @@ -403,22 +397,20 @@ private abstract static class FieldWriter extends FlatFieldVectorsWriter { private final int dim; private final DocsWithFieldSet docsWithField; private final List vectors; + private boolean finished; private int lastDocID = -1; - @SuppressWarnings("unchecked") - static FieldWriter create(FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) { + static FieldWriter create(FieldInfo fieldInfo) { int dim = fieldInfo.getVectorDimension(); return switch (fieldInfo.getVectorEncoding()) { - case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>( - fieldInfo, (KnnFieldVectorsWriter) indexWriter) { + case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter(fieldInfo) { @Override public byte[] copyValue(byte[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); } }; - case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>( - fieldInfo, (KnnFieldVectorsWriter) indexWriter) { + case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter(fieldInfo) { @Override public float[] copyValue(float[] value) { return ArrayUtil.copyOfSubArray(value, 0, dim); @@ -427,8 +419,8 @@ public float[] copyValue(float[] value) { }; } - FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) { - super(indexWriter); + FieldWriter(FieldInfo fieldInfo) { + super(); this.fieldInfo = fieldInfo; this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); @@ -437,6 +429,9 @@ public float[] copyValue(float[] value) { @Override public void addValue(int docID, T vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } if (docID == lastDocID) { throw new IllegalArgumentException( "VectorValuesField \"" @@ -448,17 +443,11 @@ public void addValue(int docID, T vectorValue) throws IOException { docsWithField.add(docID); vectors.add(copy); lastDocID = docID; - if (indexingDelegate != null) { - indexingDelegate.addValue(docID, copy); - } } @Override public long ramBytesUsed() { long size = SHALLOW_RAM_BYTES_USED; - if (indexingDelegate != null) { - size += indexingDelegate.ramBytesUsed(); - } if (vectors.size() == 0) return size; return size + docsWithField.ramBytesUsed() @@ -468,6 +457,29 @@ public long ramBytesUsed() { * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize; } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } } static final class FlatCloseableRandomVectorScorerSupplier diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index 949507848bf0..bf97426738b6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -24,9 +24,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; @@ -130,12 +132,13 @@ public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException FieldWriter newField = FieldWriter.create( flatVectorWriter.getFlatVectorScorer(), + flatVectorWriter.addField(fieldInfo), fieldInfo, M, beamWidth, segmentWriteState.infoStream); fields.add(newField); - return flatVectorWriter.addField(fieldInfo, newField); + return newField; } @Override @@ -171,8 +174,10 @@ public void finish() throws IOException { @Override public long ramBytesUsed() { long total = SHALLOW_RAM_BYTES_USED; - // The vector delegate will also account for this writer's KnnFieldVectorsWriter objects - total += flatVectorWriter.ramBytesUsed(); + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } return total; } @@ -187,17 +192,19 @@ private void writeField(FieldWriter fieldData) throws IOException { fieldData.fieldInfo, vectorIndexOffset, vectorIndexLength, - fieldData.docsWithField.cardinality(), + fieldData.getDocsWithFieldSet().cardinality(), graph, graphLevelNodeOffsets); } private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException { - final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord - final int[] oldOrdMap = new int[fieldData.docsWithField.cardinality()]; // old ord to new ord + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + final int[] oldOrdMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // old ord to new ord - mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, oldOrdMap, ordMap, null); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, oldOrdMap, ordMap, null); // write graph long vectorIndexOffset = vectorIndex.getFilePointer(); OnHeapHnswGraph graph = fieldData.getGraph(); @@ -209,7 +216,7 @@ private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) fieldData.fieldInfo, vectorIndexOffset, vectorIndexLength, - fieldData.docsWithField.cardinality(), + fieldData.getDocsWithFieldSet().cardinality(), mockGraph, graphLevelNodeOffsets); } @@ -521,42 +528,65 @@ private static class FieldWriter extends KnnFieldVectorsWriter { RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); private final FieldInfo fieldInfo; - private final DocsWithFieldSet docsWithField; - private final List vectors; private final HnswGraphBuilder hnswGraphBuilder; private int lastDocID = -1; private int node = 0; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + @SuppressWarnings("unchecked") static FieldWriter create( - FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) + FlatVectorsScorer scorer, + FlatFieldVectorsWriter flatFieldVectorsWriter, + FieldInfo fieldInfo, + int M, + int beamWidth, + InfoStream infoStream) throws IOException { return switch (fieldInfo.getVectorEncoding()) { - case BYTE -> new FieldWriter(scorer, fieldInfo, M, beamWidth, infoStream); - case FLOAT32 -> new FieldWriter(scorer, fieldInfo, M, beamWidth, infoStream); + case BYTE -> new FieldWriter<>( + scorer, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + fieldInfo, + M, + beamWidth, + infoStream); + case FLOAT32 -> new FieldWriter<>( + scorer, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + fieldInfo, + M, + beamWidth, + infoStream); }; } @SuppressWarnings("unchecked") FieldWriter( - FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) + FlatVectorsScorer scorer, + FlatFieldVectorsWriter flatFieldVectorsWriter, + FieldInfo fieldInfo, + int M, + int beamWidth, + InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; - this.docsWithField = new DocsWithFieldSet(); - vectors = new ArrayList<>(); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { case BYTE -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), RandomAccessVectorValues.fromBytes( - (List) vectors, fieldInfo.getVectorDimension())); + (List) flatFieldVectorsWriter.getVectors(), + fieldInfo.getVectorDimension())); case FLOAT32 -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), RandomAccessVectorValues.fromFloats( - (List) vectors, fieldInfo.getVectorDimension())); + (List) flatFieldVectorsWriter.getVectors(), + fieldInfo.getVectorDimension())); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(infoStream); + this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter); } @Override @@ -567,20 +597,23 @@ public void addValue(int docID, T vectorValue) throws IOException { + fieldInfo.name + "\" appears more than once in this document (only one value is allowed per field)"); } - assert docID > lastDocID; - vectors.add(vectorValue); - docsWithField.add(docID); + flatFieldVectorsWriter.addValue(docID, vectorValue); hnswGraphBuilder.addGraphNode(node); node++; lastDocID = docID; } + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + @Override public T copyValue(T vectorValue) { throw new UnsupportedOperationException(); } OnHeapHnswGraph getGraph() { + assert flatFieldVectorsWriter.isFinished(); if (node > 0) { return hnswGraphBuilder.getGraph(); } else { @@ -591,9 +624,7 @@ OnHeapHnswGraph getGraph() { @Override public long ramBytesUsed() { return SHALLOW_SIZE - + docsWithField.ramBytesUsed() - + (long) vectors.size() - * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + flatFieldVectorsWriter.ramBytesUsed() + hnswGraphBuilder.getGraph().ramBytesUsed(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index beb1af19ca1c..311f2df435e8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -30,8 +30,8 @@ import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; @@ -56,7 +56,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; -import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -195,8 +194,8 @@ private Lucene99ScalarQuantizedVectorsWriter( } @Override - public FlatFieldVectorsWriter addField( - FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException { + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) { throw new IllegalArgumentException( @@ -205,6 +204,7 @@ public FlatFieldVectorsWriter addField( + " is not supported for odd vector dimensions; vector dimension=" + fieldInfo.getVectorDimension()); } + @SuppressWarnings("unchecked") FieldWriter quantizedWriter = new FieldWriter( confidenceInterval, @@ -212,11 +212,11 @@ public FlatFieldVectorsWriter addField( compress, fieldInfo, segmentWriteState.infoStream, - indexWriter); + (FlatFieldVectorsWriter) rawVectorDelegate); fields.add(quantizedWriter); - indexWriter = quantizedWriter; + return quantizedWriter; } - return rawVectorDelegate.addField(fieldInfo, indexWriter); + return rawVectorDelegate; } @Override @@ -270,12 +270,13 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorDelegate.flush(maxDoc, sortMap); for (FieldWriter field : fields) { - field.finish(); + ScalarQuantizer quantizer = field.createQuantizer(); if (sortMap == null) { - writeField(field, maxDoc); + writeField(field, maxDoc, quantizer); } else { - writeSortingField(field, maxDoc, sortMap); + writeSortingField(field, maxDoc, sortMap, quantizer); } + field.finish(); } } @@ -299,15 +300,18 @@ public void finish() throws IOException { @Override public long ramBytesUsed() { long total = SHALLOW_RAM_BYTES_USED; - // The vector delegate will also account for this writer's KnnFieldVectorsWriter objects - total += rawVectorDelegate.ramBytesUsed(); + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } return total; } - private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + private void writeField(FieldWriter fieldData, int maxDoc, ScalarQuantizer scalarQuantizer) + throws IOException { // write vector values long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); - writeQuantizedVectors(fieldData); + writeQuantizedVectors(fieldData, scalarQuantizer); long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; writeMeta( @@ -318,9 +322,9 @@ private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { confidenceInterval, bits, compress, - fieldData.minQuantile, - fieldData.maxQuantile, - fieldData.docsWithField); + scalarQuantizer.getLowerQuantile(), + scalarQuantizer.getUpperQuantile(), + fieldData.getDocsWithFieldSet()); } private void writeMeta( @@ -365,8 +369,8 @@ private void writeMeta( DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField); } - private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { - ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + private void writeQuantizedVectors(FieldWriter fieldData, ScalarQuantizer scalarQuantizer) + throws IOException { byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; byte[] compressedVector = fieldData.compress @@ -375,7 +379,8 @@ private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { : null; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; - for (float[] v : fieldData.floatVectors) { + assert fieldData.getVectors().isEmpty() || scalarQuantizer != null; + for (float[] v : fieldData.getVectors()) { if (fieldData.normalize) { System.arraycopy(v, 0, copy, 0, copy.length); VectorUtil.l2normalize(copy); @@ -396,16 +401,18 @@ private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { } } - private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) + private void writeSortingField( + FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap, ScalarQuantizer scalarQuantizer) throws IOException { - final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); - mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); // write vector values long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); - writeSortedQuantizedVectors(fieldData, ordMap); + writeSortedQuantizedVectors(fieldData, ordMap, scalarQuantizer); long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; writeMeta( fieldData.fieldInfo, @@ -415,13 +422,13 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap confidenceInterval, bits, compress, - fieldData.minQuantile, - fieldData.maxQuantile, + scalarQuantizer.getLowerQuantile(), + scalarQuantizer.getUpperQuantile(), newDocsWithField); } - private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException { - ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + private void writeSortedQuantizedVectors( + FieldWriter fieldData, int[] ordMap, ScalarQuantizer scalarQuantizer) throws IOException { byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; byte[] compressedVector = fieldData.compress @@ -431,7 +438,7 @@ private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) th final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; for (int ordinal : ordMap) { - float[] v = fieldData.floatVectors.get(ordinal); + float[] v = fieldData.getVectors().get(ordinal); if (fieldData.normalize) { System.arraycopy(v, 0, copy, 0, copy.length); VectorUtil.l2normalize(copy); @@ -744,44 +751,51 @@ public void close() throws IOException { static class FieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); - private final List floatVectors; private final FieldInfo fieldInfo; private final Float confidenceInterval; private final byte bits; private final boolean compress; private final InfoStream infoStream; private final boolean normalize; - private float minQuantile = Float.POSITIVE_INFINITY; - private float maxQuantile = Float.NEGATIVE_INFINITY; private boolean finished; - private final DocsWithFieldSet docsWithField; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; - @SuppressWarnings("unchecked") FieldWriter( Float confidenceInterval, byte bits, boolean compress, FieldInfo fieldInfo, InfoStream infoStream, - KnnFieldVectorsWriter indexWriter) { - super((KnnFieldVectorsWriter) indexWriter); + FlatFieldVectorsWriter indexWriter) { + super(); this.confidenceInterval = confidenceInterval; this.bits = bits; this.fieldInfo = fieldInfo; this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE; - this.floatVectors = new ArrayList<>(); this.infoStream = infoStream; - this.docsWithField = new DocsWithFieldSet(); this.compress = compress; + this.flatFieldVectorsWriter = Objects.requireNonNull(indexWriter); } - void finish() throws IOException { + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void finish() throws IOException { if (finished) { return; } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + ScalarQuantizer createQuantizer() throws IOException { + assert flatFieldVectorsWriter.isFinished(); + List floatVectors = flatFieldVectorsWriter.getVectors(); if (floatVectors.size() == 0) { - finished = true; - return; + return new ScalarQuantizer(0, 0, bits); } FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize); ScalarQuantizer quantizer = @@ -791,8 +805,6 @@ void finish() throws IOException { fieldInfo.getVectorSimilarityFunction(), confidenceInterval, bits); - minQuantile = quantizer.getLowerQuantile(); - maxQuantile = quantizer.getUpperQuantile(); if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { infoStream.message( QUANTIZED_VECTOR_COMPONENT, @@ -802,41 +814,39 @@ void finish() throws IOException { + " bits=" + bits + " minQuantile=" - + minQuantile + + quantizer.getLowerQuantile() + " maxQuantile=" - + maxQuantile); + + quantizer.getUpperQuantile()); } - finished = true; - } - - ScalarQuantizer createQuantizer() { - assert finished; - return new ScalarQuantizer(minQuantile, maxQuantile, bits); + return quantizer; } @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; - if (indexingDelegate != null) { - size += indexingDelegate.ramBytesUsed(); - } - if (floatVectors.size() == 0) return size; - return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + size += flatFieldVectorsWriter.ramBytesUsed(); + return size; } @Override public void addValue(int docID, float[] vectorValue) throws IOException { - docsWithField.add(docID); - floatVectors.add(vectorValue); - if (indexingDelegate != null) { - indexingDelegate.addValue(docID, vectorValue); - } + flatFieldVectorsWriter.addValue(docID, vectorValue); } @Override public float[] copyValue(float[] vectorValue) { throw new UnsupportedOperationException(); } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } } static class FloatVectorWrapper extends FloatVectorValues {