diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d547d7bdc656..7ba4e1321bc8 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -43,7 +43,12 @@ API Changes New Features --------------------- -(No changes) + +* GITHUB#13651: New binary quantized vector formats `Lucene101HnswBinaryQuantizedVectorsFormat` and + `Lucene101BinaryQuantizedVectorsFormat`. This results in a 32x reduction in memory requirements for fast vector search + while achieving nice recall properties only requiring about 5x oversampling with rescoring on larger dimensional vectors. + The format is based on the RaBitQ algorithm & paper: https://arxiv.org/abs/2405.12497. + (John Wagster, Mayya Sharipova, Chris Hegarty, Tom Veasey, Ben Trent) Improvements --------------------- diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 85aff5722498..9a8df45312c6 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -76,7 +76,9 @@ provides org.apache.lucene.codecs.KnnVectorsFormat with org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat, org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat, - org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat; + org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat, + org.apache.lucene.codecs.lucene101.Lucene101HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java new file mode 100644 index 000000000000..7121432d9dfd --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.util.quantization.BQSpaceUtils.constSqrt; + +import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +/** + * A version of {@link ByteVectorValues}, but additionally retrieving score correction values offset + * for binarization quantization scores. + * + * @lucene.experimental + */ +public abstract class BinarizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. 
For the dot-product family of + * distances, the corrective terms are, in order + * + * + * + * For euclidean: + * + * + * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract float[] getCorrectiveTerms(int vectorOrd) throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract BinaryQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + int discretizedDimensions() { + return BQSpaceUtils.discretize(dimension(), 64); + } + + float sqrtDimensions() { + return (float) constSqrt(dimension()); + } + + float maxX1() { + return (float) (1.9 / constSqrt(discretizedDimensions() - 1.0)); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java new file mode 100644 index 000000000000..4453e5c3d915 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +/** Vector scorer over binarized vector values */ +public class Lucene101BinaryFlatVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene101BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new 
UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + BinaryQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + if (similarityFunction == COSINE) { + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + VectorUtil.l2normalize(copy); + target = copy; + } + byte[] quantized = + new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; + float[] queryCorrections = quantizer.quantizeForQuery(target, quantized, centroid); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors) { + return new BinarizedRandomVectorScorerSupplier( + scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return 
"Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues + queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + float[] correctiveTerms = queryVectors.getCorrectiveTerms(ord); + BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier( + queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + public record BinaryQueryVector(byte[] vector, float[] factors) {} + + /** Vector scorer over binarized vector values */ + public static class BinarizedRandomVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final float sqrtDimensions; + private final float maxX1; + + public BinarizedRandomVectorScorer( + 
BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + this.sqrtDimensions = targetVectors.sqrtDimensions(); + this.maxX1 = targetVectors.maxX1(); + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = VectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + float xbSum = (float) VectorUtil.popCount(binaryCode); + float[] correctiveTerms = targetVectors.getCorrectiveTerms(targetOrd); + if (similarityFunction == EUCLIDEAN) { + return euclideanScore(xbSum, qcDist, correctiveTerms, queryVector.factors); + } + return dotProductScore( + xbSum, qcDist, targetVectors.getCentroidDP(), correctiveTerms, queryVector.factors); + } + + private float dotProductScore( + float xbSum, + float qcDist, + float cDotC, + float[] vectorCorrectiveTerms, + float[] queryCorrectiveTerms) { + assert vectorCorrectiveTerms.length == 3; + assert queryCorrectiveTerms.length == 5; + float lower = queryCorrectiveTerms[0] / sqrtDimensions; + float width = queryCorrectiveTerms[1] / sqrtDimensions; + float vmC = queryCorrectiveTerms[2]; + float vDotC = queryCorrectiveTerms[3]; + float quantizedSum = queryCorrectiveTerms[4]; + float ooq = vectorCorrectiveTerms[0]; + float vmcNormOC = vectorCorrectiveTerms[1] * vmC; + float oDotC = vectorCorrectiveTerms[2]; + + final float dist; + // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away + // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector + if (vmcNormOC == 0 || ooq == 0) { + dist = oDotC + vDotC - cDotC; + } else { + // If ||o-c|| != 0, we should assume that `ooq` is finite + assert Float.isFinite(ooq); + float estimatedDot = + (2 * width * qcDist + + 2 * lower * 
xbSum + - width * quantizedSum + - targetVectors.dimension() * lower) + / ooq; + dist = vmcNormOC * estimatedDot + oDotC + vDotC - cDotC; + } + assert Float.isFinite(dist); + + float ooqSqr = ooq * ooq; + float errorBound = (float) (vmcNormOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr))); + float score = Float.isFinite(errorBound) ? dist - errorBound : dist; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + + private float euclideanScore( + float xbSum, float qcDist, float[] vectorCorrectiveTerms, float[] queryCorrectiveTerms) { + assert vectorCorrectiveTerms.length == 2; + assert queryCorrectiveTerms.length == 4; + float distanceToCentroid = queryCorrectiveTerms[0]; + float lower = queryCorrectiveTerms[1]; + float width = queryCorrectiveTerms[2]; + float quantizedSum = queryCorrectiveTerms[3]; + + float targetDistToC = vectorCorrectiveTerms[0]; + float x0 = vectorCorrectiveTerms[1]; + float sqrX = targetDistToC * targetDistToC; + double xX0 = targetDistToC / x0; + + float factorPPC = + (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension())); + float factorIP = (float) (-2.0 / sqrtDimensions * xX0); + + float score = + sqrX + + distanceToCentroid + + factorPPC * lower + + (qcDist * 2 - quantizedSum) * factorIP * width; + float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); + float error = 2.0f * maxX1 * projectionDist; + float y = (float) Math.sqrt(distanceToCentroid); + float errorBound = y * error; + if (Float.isFinite(errorBound)) { + score = score + errorBound; + } + return Math.max(1 / (1f + score), 0); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..33cc988670a9 --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene101; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * Codec for encoding/decoding binary quantized vectors The binary quantization format used here + * reflects RaBitQ. Also see {@link + * org.apache.lucene.util.quantization.BinaryQuantizer}. Some of key features of RabitQ are: + * + * + * + * The format is stored in two files: + * + *

.veb (vector data) file

+ * + *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's + * corrective factors. At the end of the file, additional information is stored for vector ordinal + * to centroid ordinal mapping and sparse vector information. + * + *

+ * + *

.vemb (vector metadata) file

+ * + *

Stores the metadata for the vectors. This includes the number of vectors, the number of + * dimensions, and file offset information. + * + *

+ */ +public class Lucene101BinaryQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "Lucene101BinaryQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "Lucene101BinaryQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene101BinaryQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private final Lucene101BinaryFlatVectorsScorer scorer = + new Lucene101BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + /** Creates a new instance with the default number of vectors per cluster. 
*/ + public Lucene101BinaryQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene101BinaryQuantizedVectorsWriter( + scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene101BinaryQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public String toString() { + return "Lucene101BinaryQuantizedVectorsFormat(name=" + + NAME + + ", flatVectorScorer=" + + scorer + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java new file mode 100644 index 000000000000..db54b857f202 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +/** + * Reads raw and binarized vectors from the index segments for KNN search. 
+ * + * @lucene.experimental + */ +public class Lucene101BinaryQuantizedVectorsReader extends FlatVectorsReader { + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene101BinaryQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final Lucene101BinaryFlatVectorsScorer vectorScorer; + + public Lucene101BinaryQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + Lucene101BinaryFlatVectorsScorer vectorsScorer) + throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene101BinaryQuantizedVectorsFormat.META_EXTENSION); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene101BinaryQuantizedVectorsFormat.META_CODEC_NAME, + Lucene101BinaryQuantizedVectorsFormat.VERSION_START, + Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = + openDataInput( + state, + versionMeta, + Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. 
+ state.context.withReadAdvice(ReadAdvice.RANDOM)); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + int binaryDims = BQSpaceUtils.discretize(dimension, 64) / 8; + int correctionsCount = + fieldEntry.similarityFunction != VectorSimilarityFunction.EUCLIDEAN ? 
3 : 2; + long numQuantizedVectorBytes = + Math.multiplyExact((binaryDims + (Float.BYTES * correctionsCount)), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 8" + + ") = " + + numQuantizedVectorBytes); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new BinaryQuantizer(fi.dimension, fi.descritizedDimension, fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData), + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fi.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + OffHeapBinarizedVectorValues bvv = + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new BinaryQuantizer(fi.dimension, fi.descritizedDimension, 
fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData); + return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context) + throws IOException { + String fileName = + 
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene101BinaryQuantizedVectorsFormat.VERSION_START, + Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) { + + static FieldEntry create( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = 
input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = + OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + BQSpaceUtils.discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf); + } + } + + /** Binarized vector values holding row and quantized vector values */ + protected static final class BinarizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; + + BinarizedVectorValues( + FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + BinarizedByteVectorValues getQuantizedVectorValues() { + return 
quantizedVectorValues; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java new file mode 100644 index 000000000000..396bb78b94f9 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java @@ -0,0 +1,966 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; 
+import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +/** + * Writes raw and binarized vector values to index segments for KNN search. + * + * @lucene.experimental + */ +public class Lucene101BinaryQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene101BinaryQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final Lucene101BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected Lucene101BinaryQuantizedVectorsWriter( + Lucene101BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state) + throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene101BinaryQuantizedVectorsFormat.META_EXTENSION); + + String binarizedVectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = + state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( 
+ meta, + Lucene101BinaryQuantizedVectorsFormat.META_CODEC_NAME, + Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + binarizedVectorData, + Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } 
+ int descritizedDimension = BQSpaceUtils.discretize(field.fieldInfo.getVectorDimension(), 64); + BinaryQuantizer quantizer = + new BinaryQuantizer( + field.fieldInfo.getVectorDimension(), + descritizedDimension, + field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField( + FieldWriter fieldData, float[] clusterCenter, int maxDoc, BinaryQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet()); + } + + private void writeBinarizedVectors( + FieldWriter fieldData, float[] clusterCenter, BinaryQuantizer scalarQuantizer) + throws IOException { + byte[] vector = + new byte[BQSpaceUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64) / 8]; + int correctionsCount = scalarQuantizer.getSimilarity() != EUCLIDEAN ? 
3 : 2; + final ByteBuffer correctionsBuffer = + ByteBuffer.allocate(Float.BYTES * correctionsCount).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter); + binarizedVectorData.writeBytes(vector, vector.length); + for (float correction : corrections) { + correctionsBuffer.putFloat(correction); + } + binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length); + correctionsBuffer.rewind(); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + BinaryQuantizer scalarQuantizer) + throws IOException { + final int[] ordMap = + new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + clusterCenter, + centroidDp, + newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, float[] clusterCenter, int[] ordMap, BinaryQuantizer scalarQuantizer) + throws IOException { + byte[] vector = + new byte[BQSpaceUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64) / 8]; + int correctionsCount = scalarQuantizer.getSimilarity() != EUCLIDEAN ? 
3 : 2; + final ByteBuffer correctionsBuffer = + ByteBuffer.allocate(Float.BYTES * correctionsCount).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter); + binarizedVectorData.writeBytes(vector, vector.length); + for (float correction : corrections) { + correctionsBuffer.putFloat(correction); + } + binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length); + correctionsBuffer.rewind(); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = + ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, binarizedVectorData, count, maxDoc, docsWithField); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void 
mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + int descritizedDimension = BQSpaceUtils.discretize(fieldInfo.getVectorDimension(), 64); + FloatVectorValues floatVectorValues = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + BinarizedFloatVectorValues binarizedVectorValues = + new BinarizedFloatVectorValues( + floatVectorValues, + new BinaryQuantizer( + fieldInfo.getVectorDimension(), + descritizedDimension, + fieldInfo.getVectorSimilarityFunction()), + centroid); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = + writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = + docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + BinaryQuantizer binaryQuantizer) + throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + byte[] toIndex = new byte[BQSpaceUtils.discretize(floatVectorValues.dimension(), 64) / 8]; + byte[] toQuery = + new byte + [(BQSpaceUtils.discretize(floatVectorValues.dimension(), 64) / 8) + * BQSpaceUtils.B_QUERY]; + int queryCorrectionCount = binaryQuantizer.getSimilarity() != EUCLIDEAN ? 4 : 3; + final ByteBuffer queryCorrectionsBuffer = + ByteBuffer.allocate(Float.BYTES * queryCorrectionCount + Short.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + BinaryQuantizer.QueryAndIndexResults r = + binaryQuantizer.quantizeQueryAndIndex( + floatVectorValues.vectorValue(iterator.index()), toIndex, toQuery, centroid); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + float[] corrections = r.indexFeatures(); + for (float correction : corrections) { + binarizedVectorData.writeInt(Float.floatToIntBits(correction)); + } + docsWithField.add(docV); + + // write query vector + binarizedQueryData.writeBytes(toQuery, toQuery.length); + assert r.queryFeatures().length == queryCorrectionCount + 1; + float[] queryCorrections = r.queryFeatures(); + for (int i = 0; i < queryCorrectionCount; i++) { + queryCorrectionsBuffer.putFloat(queryCorrections[i]); + } + assert queryCorrections[queryCorrectionCount] >= 0 + && 
queryCorrections[queryCorrectionCount] <= 0xffff; + queryCorrectionsBuffer.putShort((short) queryCorrections[queryCorrectionCount]); + + binarizedQueryData.writeBytes( + queryCorrectionsBuffer.array(), queryCorrectionsBuffer.array().length); + queryCorrectionsBuffer.rewind(); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData( + IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + float[] corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + for (float correction : corrections) { + output.writeInt(Float.floatToIntBits(correction)); + } + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC) + throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + final IndexOutput tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), "temp", segmentWriteState.context); + final IndexOutput tempScoreQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), "score_temp", segmentWriteState.context); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + boolean success = false; + int descritizedDimension = BQSpaceUtils.discretize(fieldInfo.getVectorDimension(), 64); + BinaryQuantizer quantizer = + new BinaryQuantizer( + fieldInfo.getVectorDimension(), + descritizedDimension, + fieldInfo.getVectorSimilarityFunction()); + try { + FloatVectorValues floatVectorValues = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = + writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = + segmentWriteState.directory.openInput( + 
tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes( + binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = + segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), segmentWriteState.context); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = + new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalBinarizedDataInput); + RandomVectorScorerSupplier scorerSupplier = + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + fieldInfo.getVectorSimilarityFunction()), + vectorValues); + return new BinarizedCloseableRandomVectorScorerSupplier( + scorerSupplier, + vectorValues, + () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName()); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + 
binarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName()); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof Lucene101BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids( + MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null + || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + 
VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) + throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = + mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = iterator.nextDoc()) { + float[] vector = vectorValues.vectorValue(iterator.index()); + // TODO Panama sum + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + count += vectorValues.size(); + } + if (count == 0) { + return count; + } + // TODO Panama div + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + 
this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += RamUsageEstimator.sizeOf(dimensionSums); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + // When accessing vectorValue method, targetOrd here means a row ordinal. 
+ static class OffHeapBinarizedQueryVectorValues { + private final IndexInput slice; + private final int dimension; + private final int size; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + private final int byteSize; + protected final float[] correctiveValues; + private int lastOrd = -1; + private final int correctiveValuesSize; + private final VectorSimilarityFunction vectorSimilarityFunction; + + OffHeapBinarizedQueryVectorValues( + IndexInput data, + int dimension, + int size, + VectorSimilarityFunction vectorSimilarityFunction) { + this.slice = data; + this.dimension = dimension; + this.size = size; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.correctiveValuesSize = vectorSimilarityFunction != EUCLIDEAN ? 4 : 3; + // 4x the quantized binary dimensions + int binaryDimensions = (BQSpaceUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY; + this.byteBuffer = ByteBuffer.allocate(binaryDimensions); + this.binaryValue = byteBuffer.array(); + // + 1 for the quantized sum + this.correctiveValues = new float[correctiveValuesSize + 1]; + this.byteSize = binaryDimensions + Float.BYTES * correctiveValuesSize + Short.BYTES; + } + + public float[] getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return correctiveValues; + } + vectorValue(targetOrd); + return correctiveValues; + } + + public int size() { + return size; + } + + public int dimension() { + return dimension; + } + + public OffHeapBinarizedQueryVectorValues copy() throws IOException { + return new OffHeapBinarizedQueryVectorValues( + slice.clone(), dimension, size, vectorSimilarityFunction); + } + + public IndexInput getSlice() { + return slice; + } + + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(binaryValue, 0, binaryValue.length); + slice.readFloats(correctiveValues, 0, 
correctiveValuesSize); + correctiveValues[correctiveValuesSize] = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + } + + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + private float[] corrections; + private final byte[] binarized; + private final float[] centroid; + private final FloatVectorValues values; + private final BinaryQuantizer quantizer; + private int lastOrd = -1; + + BinarizedFloatVectorValues( + FloatVectorValues delegate, BinaryQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[BQSpaceUtils.discretize(delegate.dimension(), 64) / 8]; + this.centroid = centroid; + } + + @Override + public float[] getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public BinaryQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.quantizeForIndex(values.vectorValue(ord), binarized, centroid); + } + + @Override + public DocIndexIterator iterator() { + return 
values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class BinarizedCloseableRandomVectorScorerSupplier + implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + BinarizedCloseableRandomVectorScorerSupplier( + RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..debad93e5bad --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +/** + * Lucene 10.1 vector format, which encodes numeric vector values into an associated graph + * connecting the documents having values. The graph is used to power HNSW search. The format + * consists of two files, and uses {@link Lucene101BinaryQuantizedVectorsFormat} to store the actual + * vectors: For details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}. + * + * @lucene.experimental + */ +public class Lucene101HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "Lucene101HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. 
See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private final FlatVectorsFormat flatVectorsFormat; + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public Lucene101HnswBinaryQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public Lucene101HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. 
If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene101HnswBinaryQuantizedVectorsFormat( + int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + this.flatVectorsFormat = new Lucene101BinaryQuantizedVectorsFormat(); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene101HnswBinaryQuantizedVectorsFormat(name=Lucene101HnswBinaryQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + 
+ flatVectorsFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java new file mode 100644 index 000000000000..411136fbf959 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene101; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.util.quantization.BQSpaceUtils.constSqrt; + +import java.io.IOException; +import java.nio.ByteBuffer; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +/** Binarized vector values loaded from off-heap */ +public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { + + protected final int dimension; + protected final int size; + protected final int numBytes; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer vectorsScorer; + + protected final IndexInput slice; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + protected final int byteSize; + private int lastOrd = -1; + protected final float[] correctiveValues; + protected final BinaryQuantizer binaryQuantizer; + protected final float[] centroid; + protected final float centroidDp; + private final int discretizedDimensions; + private final float maxX1; + private final float sqrtDimensions; + private final int correctionsCount; + + OffHeapBinarizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + BinaryQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) 
{ + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.numBytes = BQSpaceUtils.discretize(dimension, 64) / 8; + this.correctionsCount = similarityFunction != EUCLIDEAN ? 3 : 2; + this.correctiveValues = new float[this.correctionsCount]; + this.byteSize = numBytes + (Float.BYTES * correctionsCount); + this.byteBuffer = ByteBuffer.allocate(numBytes); + this.binaryValue = byteBuffer.array(); + this.binaryQuantizer = quantizer; + this.discretizedDimensions = BQSpaceUtils.discretize(dimension, 64); + this.sqrtDimensions = (float) constSqrt(dimension); + this.maxX1 = (float) (1.9 / constSqrt(discretizedDimensions - 1.0)); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(correctiveValues, 0, correctionsCount); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public int discretizedDimensions() { + return discretizedDimensions; + } + + @Override + public float sqrtDimensions() { + return sqrtDimensions; + } + + @Override + public float maxX1() { + return maxX1; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public float[] getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return correctiveValues; + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, correctionsCount); + return correctiveValues; + } + + @Override + public BinaryQuantizer getQuantizer() { + return binaryQuantizer; + } + + @Override + public float[] 
getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + public static OffHeapBinarizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + BinaryQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData) + throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = + vectorData.slice( + "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice); + } + } + + /** Dense off-heap binarized vector values */ + public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + public DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + BinaryQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) { + super( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public Bits getAcceptOrds(Bits 
acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + BinaryQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice) + throws IOException { + super( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone()); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } 
+ + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = + vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues { + EmptyOffHeapVectorValues( + int dimension, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java index e582f12c3185..f2bec47e18e5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java @@ -14,7 +14,6 @@ * See the License for the specific language governing 
permissions and * limitations under the License. */ - /** * Lucene 10.1 file format. * diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 184403cf48b7..585b8dbb47aa 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -17,6 +17,9 @@ package org.apache.lucene.internal.vectorization; +import static org.apache.lucene.util.VectorUtil.B_QUERY; + +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; @@ -198,6 +201,31 @@ public int squareDistance(byte[] a, byte[] b) { return squareSum; } + @Override + public long ipByteBinByte(byte[] q, byte[] d) { + return ipByteBinByteImpl(q, d); + } + + public static long ipByteBinByteImpl(byte[] q, byte[] d) { + long ret = 0; + int size = d.length; + for (int i = 0; i < B_QUERY; i++) { + int r = 0; + long subRet = 0; + for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + subRet += + Integer.bitCount( + (int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) + & (int) BitUtil.VH_NATIVE_INT.get(d, r)); + } + for (; r < d.length; r++) { + subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF); + } + ret += subRet << i; + } + return ret; + } + @Override public int findNextGEQ(int[] buffer, int target, int from, int to) { for (int i = from; i < to; ++i) { diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index fb94b0e31736..613c1d8c5be4 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ 
b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -45,6 +45,9 @@ public interface VectorUtilSupport { /** Returns the sum of squared differences of the two byte vectors. */ int squareDistance(byte[] a, byte[] b); + /** This does a bit-wise dot-product between two particularly formatted byte arrays. */ + long ipByteBinByte(byte[] q, byte[] d); + /** * Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to} * exclusive, find the first array index whose value is greater than or equal to {@code target}. diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 250c65448703..e08d3a9f0317 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -124,6 +124,16 @@ public static float[] l2normalize(float[] v) { return v; } + /** + * Return the l2Norm of the vector. + * + * @param v the vector + * @return the l2Norm of the vector + */ + public static float l2Norm(float[] v) { + return (float) Math.sqrt(IMPL.dotProduct(v, v)); + } + public static boolean isUnitVector(float[] v) { double l1norm = IMPL.dotProduct(v, v); return Math.abs(l1norm - 1.0d) <= EPSILON; @@ -169,6 +179,30 @@ public static void add(float[] u, float[] v) { } } + /** + * Subtracts the second argument from the first + * + * @param u the destination + * @param v the vector to subtract from the destination + */ + public static void subtract(float[] u, float[] v) { + for (int i = 0; i < u.length; i++) { + u[i] -= v[i]; + } + } + + /** + * Divides the first argument by the second + * + * @param u the destination + * @param v to divide the destination by + */ + public static void divide(float[] u, float v) { + for (int i = 0; i < u.length; i++) { + u[i] /= v; + } + } + /** * Dot product computed over signed bytes. 
* @@ -269,6 +303,44 @@ static int xorBitCountLong(byte[] a, byte[] b) { return distance; } + /** + * The popCount for the given byte array. + * + * @param v the byte array + * @return the number of set bits in the byte array + */ + public static int popCount(byte[] v) { + if (XOR_BIT_COUNT_STRIDE_AS_INT) { + return popCountInt(v); + } else { + return popCountLong(v); + } + } + + static int popCountInt(byte[] d) { + int r = 0; + int cnt = 0; + for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + cnt += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(d, r)); + } + for (; r < d.length; r++) { + cnt += Integer.bitCount(d[r] & 0xFF); + } + return cnt; + } + + static int popCountLong(byte[] d) { + int r = 0; + int cnt = 0; + for (final int upperBound = d.length & -Long.BYTES; r < upperBound; r += Long.BYTES) { + cnt += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(d, r)); + } + for (; r < d.length; r++) { + cnt += Integer.bitCount(d[r] & 0xFF); + } + return cnt; + } + /** * Dot product score computed over signed bytes, scaled to be in [0, 1]. * @@ -309,6 +381,28 @@ public static float[] checkFinite(float[] v) { return v; } + public static final short B_QUERY = 4; + + /** + * This does a dot-product between two particularly formatted byte arrays. It is assumed that q is + * 4 times the size of d and bits for each individual dimension are packed by their order. An + * example encoding for q would be for values 0, 12, 7, 5 which have the binary values of 0000, + * 1100, 0111, 0101 the bits would actually be packed as 0011, 0010, 0111, 0100 or the values 3, + * 2, 7, 4. 
+ * + * @param q an int4 encoded byte array, but where the lower level bits are collected first + * with higher order bits following later + * @param d a bit encoded byte array + * @return the dot product + */ + public static long ipByteBinByte(byte[] q, byte[] d) { + if (q.length != d.length * B_QUERY) { + throw new IllegalArgumentException( + "vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); + } + return IMPL.ipByteBinByte(q, d); + } + + /** * Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to} * exclusive, find the first array index whose value is greater than or equal to {@code target}. diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java b/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java new file mode 100644 index 000000000000..4f89801684e3 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.lucene.util.quantization; + +import org.apache.lucene.util.ArrayUtil; + +/** Utility class for quantization calculations */ +public class BQSpaceUtils { + + public static final short B_QUERY = 4; + // the first four bits masked + private static final int B_QUERY_MASK = 15; + + public static double sqrtNewtonRaphson(double x, double curr, double prev) { + return (curr == prev) ? curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr); + } + + public static double constSqrt(double x) { + return x >= 0 && Double.isInfinite(x) == false ? sqrtNewtonRaphson(x, x, 0) : Double.NaN; + } + + public static int discretize(int value, int bucket) { + return ((value + (bucket - 1)) / bucket) * bucket; + } + + public static float[] pad(float[] vector, int dimensions) { + if (vector.length >= dimensions) { + return vector; + } + return ArrayUtil.growExact(vector, dimensions); + } + + public static byte[] pad(byte[] vector, int dimensions) { + if (vector.length >= dimensions) { + return vector; + } + return ArrayUtil.growExact(vector, dimensions); + } + + /** + * Transpose the query vector into a byte array allowing for efficient bitwise operations with the + * index bit vectors. The idea here is to organize the query vector bits such that the first bit + * of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second, + * third, and fourth bits are in the second, third, and fourth set of dimensions bits, + * respectively. This allows for direct bitwise comparisons with the stored index vectors through + * summing the bitwise results with the relative required bit shifts. 
+ * + * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15 + * @param dimensions the number of dimensions in the query vector + * @param quantQueryByte the byte array to store the transposed query vector + */ + public static void transposeBin(byte[] q, int dimensions, byte[] quantQueryByte) { + int qOffset = 0; + final byte[] v1 = new byte[4]; + final byte[] v = new byte[32]; + for (int i = 0; i < dimensions; i += 32) { + // for every four bytes we shift left (with remainder across those bytes) + for (int j = 0; j < v.length; j += 4) { + v[j] = (byte) (q[qOffset + j] << B_QUERY | ((q[qOffset + j] >>> B_QUERY) & B_QUERY_MASK)); + v[j + 1] = + (byte) + (q[qOffset + j + 1] << B_QUERY | ((q[qOffset + j + 1] >>> B_QUERY) & B_QUERY_MASK)); + v[j + 2] = + (byte) + (q[qOffset + j + 2] << B_QUERY | ((q[qOffset + j + 2] >>> B_QUERY) & B_QUERY_MASK)); + v[j + 3] = + (byte) + (q[qOffset + j + 3] << B_QUERY | ((q[qOffset + j + 3] >>> B_QUERY) & B_QUERY_MASK)); + } + for (int j = 0; j < B_QUERY; j++) { + moveMaskEpi8Byte(v, v1); + for (int k = 0; k < 4; k++) { + quantQueryByte[(B_QUERY - j - 1) * (dimensions / 8) + i / 8 + k] = v1[k]; + v1[k] = 0; + } + for (int k = 0; k < v.length; k += 4) { + v[k] = (byte) (v[k] + v[k]); + v[k + 1] = (byte) (v[k + 1] + v[k + 1]); + v[k + 2] = (byte) (v[k + 2] + v[k + 2]); + v[k + 3] = (byte) (v[k + 3] + v[k + 3]); + } + } + qOffset += 32; + } + } + + private static void moveMaskEpi8Byte(byte[] v, byte[] v1b) { + int m = 0; + for (int k = 0; k < v.length; k++) { + if ((v[k] & 0b10000000) == 0b10000000) { + v1b[m] |= 0b00000001; + } + if (k % 8 == 7) { + m++; + } else { + v1b[m] <<= 1; + } + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java new file mode 100644 index 000000000000..35856104c266 --- /dev/null +++ 
b/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util.quantization; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; + +/** + * Quantizer that quantizes raw vector values to binary. The algorithm is based on the paper RaBitQ.
+ */ +public class BinaryQuantizer { + private final int discretizedDimensions; + + private final VectorSimilarityFunction similarityFunction; + private final float sqrtDimensions; + + public BinaryQuantizer( + int dimensions, int discretizedDimensions, VectorSimilarityFunction similarityFunction) { + if (dimensions <= 0) { + throw new IllegalArgumentException("dimensions must be > 0 but was: " + dimensions); + } + assert discretizedDimensions % 64 == 0 + : "discretizedDimensions must be a multiple of 64 but was: " + discretizedDimensions; + this.discretizedDimensions = discretizedDimensions; + this.similarityFunction = similarityFunction; + this.sqrtDimensions = (float) Math.sqrt(dimensions); + } + + BinaryQuantizer(int dimensions, VectorSimilarityFunction similarityFunction) { + this(dimensions, dimensions, similarityFunction); + } + + private static void removeSignAndDivide(float[] a, float divisor) { + for (int i = 0; i < a.length; i++) { + a[i] = Math.abs(a[i]) / divisor; + } + } + + private static float sumAndNormalize(float[] a, float norm) { + float aDivided = 0f; + + for (int i = 0; i < a.length; i++) { + aDivided += a[i]; + } + + aDivided = aDivided / norm; + if (!Float.isFinite(aDivided)) { + aDivided = 0.8f; // can be anything + } + + return aDivided; + } + + private static void packAsBinary(float[] vector, byte[] packedVector) { + for (int h = 0; h < vector.length; h += 8) { + byte result = 0; + int q = 0; + for (int i = 7; i >= 0; i--) { + if (vector[h + i] > 0) { + result |= (byte) (1 << q); + } + q++; + } + packedVector[h / 8] = result; + } + } + + public VectorSimilarityFunction getSimilarity() { + return this.similarityFunction; + } + + private record SubspaceOutput(float projection) {} + + private SubspaceOutput generateSubSpace( + float[] vector, float[] centroid, byte[] quantizedVector) { + // typically no-op if dimensions/64 + float[] paddedCentroid = BQSpaceUtils.pad(centroid, discretizedDimensions); + float[] paddedVector = 
BQSpaceUtils.pad(vector, discretizedDimensions); + + VectorUtil.subtract(paddedVector, paddedCentroid); + + // The inner product between the data vector and the quantized data vector + float norm = VectorUtil.l2Norm(paddedVector); + + packAsBinary(paddedVector, quantizedVector); + + removeSignAndDivide(paddedVector, sqrtDimensions); + float projection = sumAndNormalize(paddedVector, norm); + + return new SubspaceOutput(projection); + } + + record SubspaceOutputMIP(float OOQ, float normOC, float oDotC) {} + + private SubspaceOutputMIP generateSubSpaceMIP( + float[] vector, float[] centroid, byte[] quantizedVector) { + + // typically no-op if dimensions/64 + float[] paddedCentroid = BQSpaceUtils.pad(centroid, discretizedDimensions); + float[] paddedVector = BQSpaceUtils.pad(vector, discretizedDimensions); + + float oDotC = VectorUtil.dotProduct(paddedVector, paddedCentroid); + VectorUtil.subtract(paddedVector, paddedCentroid); + + float normOC = VectorUtil.l2Norm(paddedVector); + packAsBinary(paddedVector, quantizedVector); + VectorUtil.divide(paddedVector, normOC); // OmC / norm(OmC) + + float OOQ = computerOOQ(vector.length, paddedVector, quantizedVector); + + return new SubspaceOutputMIP(OOQ, normOC, oDotC); + } + + private float computerOOQ(int originalLength, float[] normOMinusC, byte[] packedBinaryVector) { + float OOQ = 0f; + for (int j = 0; j < originalLength / 8; j++) { + for (int r = 0; r < 8; r++) { + int sign = ((packedBinaryVector[j] >> (7 - r)) & 0b00000001); + OOQ += (normOMinusC[j * 8 + r] * (2 * sign - 1)); + } + } + OOQ = OOQ / sqrtDimensions; + return OOQ; + } + + private static float[] range(float[] q) { + float vl = 1e20f; + float vr = -1e20f; + for (int i = 0; i < q.length; i++) { + if (q[i] < vl) { + vl = q[i]; + } + if (q[i] > vr) { + vr = q[i]; + } + } + + return new float[] {vl, vr}; + } + + /** Results of quantizing a vector for both querying and indexing */ + public record QueryAndIndexResults(float[] indexFeatures, float[] queryFeatures) 
{} + + /** + * Quantizes the given vector to both single bits and int4 precision. Also containes two distinct + * float arrays of corrective factors. For details see the individual methods {@link + * #quantizeForIndex(float[], byte[], float[])} and {@link #quantizeForQuery(float[], byte[], + * float[])}. + * + * @param vector the vector to quantize + * @param indexDestination the destination byte array to store the quantized bit vector + * @param queryDestination the destination byte array to store the quantized int4 vector + * @param centroid the centroid to use for quantization + * @return the corrective factors used for scoring error correction. + */ + public QueryAndIndexResults quantizeQueryAndIndex( + float[] vector, byte[] indexDestination, byte[] queryDestination, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64); + + if (this.discretizedDimensions != indexDestination.length * 8) { + throw new IllegalArgumentException( + "vector and quantized vector destination must be compatible dimensions: " + + BQSpaceUtils.discretize(vector.length, 64) + + " [ " + + this.discretizedDimensions + + " ]" + + "!= " + + indexDestination.length + + " * 8"); + } + + if (this.discretizedDimensions != (queryDestination.length * 8) / BQSpaceUtils.B_QUERY) { + throw new IllegalArgumentException( + "vector and quantized vector destination must be compatible dimensions: " + + vector.length + + " [ " + + this.discretizedDimensions + + " ]" + + "!= (" + + queryDestination.length + + " * 8) / " + + BQSpaceUtils.B_QUERY); + } + + if (vector.length != centroid.length) { + throw new IllegalArgumentException( + "vector and centroid dimensions must be the same: " + + vector.length + + "!= " + + centroid.length); + } + vector = ArrayUtil.copyArray(vector); + // only need distToC for euclidean 
+ float distToC = + similarityFunction == EUCLIDEAN ? VectorUtil.squareDistance(vector, centroid) : 0f; + // only need vdotc for dot-products similarity, but not for euclidean + float vDotC = similarityFunction != EUCLIDEAN ? VectorUtil.dotProduct(vector, centroid) : 0f; + VectorUtil.subtract(vector, centroid); + // both euclidean and dot-product need the norm of the vector, just at different times + float normVmC = VectorUtil.l2Norm(vector); + // quantize for index + packAsBinary(BQSpaceUtils.pad(vector, discretizedDimensions), indexDestination); + if (similarityFunction != EUCLIDEAN) { + VectorUtil.divide(vector, normVmC); + } + + // Quantize for query + float[] range = range(vector); + float lower = range[0]; + float upper = range[1]; + // Δ := (𝑣𝑟 − 𝑣𝑙)/(2𝐵𝑞 − 1) + float width = (upper - lower) / ((1 << BQSpaceUtils.B_QUERY) - 1); + + QuantResult quantResult = quantize(vector, lower, width); + byte[] byteQuery = quantResult.result(); + + // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷 + // q¯ is an approximation of q′ (scalar quantized approximation) + // FIXME: vectors need to be padded but that's expensive; update transponseBin to deal + byteQuery = BQSpaceUtils.pad(byteQuery, discretizedDimensions); + BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, queryDestination); + final float[] indexCorrections; + final float[] queryCorrections; + if (similarityFunction == EUCLIDEAN) { + indexCorrections = new float[2]; + indexCorrections[0] = (float) Math.sqrt(distToC); + removeSignAndDivide(vector, sqrtDimensions); + indexCorrections[1] = sumAndNormalize(vector, normVmC); + queryCorrections = new float[] {distToC, lower, width, quantResult.quantizedSum}; + } else { + indexCorrections = new float[3]; + indexCorrections[0] = computerOOQ(vector.length, vector, indexDestination); + indexCorrections[1] = normVmC; + indexCorrections[2] = vDotC; + queryCorrections = new float[] {lower, width, normVmC, vDotC, quantResult.quantizedSum}; + } + return new 
QueryAndIndexResults(indexCorrections, queryCorrections); + } + + /** + * Quantizes the given vector to single bits and returns an array of corrective factors. For the + * dot-product family of distances, the corrective terms are, in order + * + *
    + *
  • the dot-product of the normalized, centered vector with its binarized self + *
  • the norm of the centered vector + *
  • the dot-product of the vector with the centroid + *
+ * + * For euclidean: + * + *
    + *
  • The euclidean distance to the centroid + *
  • The sum of the dimensions divided by the vector norm + *
+ * + * @param vector the vector to quantize + * @param destination the destination byte array to store the quantized vector + * @param centroid the centroid to use for quantization + * @return the corrective factors used for scoring error correction. + */ + public float[] quantizeForIndex(float[] vector, byte[] destination, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64); + + if (this.discretizedDimensions != destination.length * 8) { + throw new IllegalArgumentException( + "vector and quantized vector destination must be compatible dimensions: " + + BQSpaceUtils.discretize(vector.length, 64) + + " [ " + + this.discretizedDimensions + + " ]" + + "!= " + + destination.length + + " * 8"); + } + + if (vector.length != centroid.length) { + throw new IllegalArgumentException( + "vector and centroid dimensions must be the same: " + + vector.length + + "!= " + + centroid.length); + } + + float[] corrections; + + vector = ArrayUtil.copyArray(vector); + + switch (similarityFunction) { + case EUCLIDEAN: + float distToCentroid = (float) Math.sqrt(VectorUtil.squareDistance(vector, centroid)); + + SubspaceOutput subspaceOutput = generateSubSpace(vector, centroid, destination); + corrections = new float[2]; + corrections[0] = distToCentroid; + corrections[1] = subspaceOutput.projection(); + break; + case MAXIMUM_INNER_PRODUCT: + case COSINE: + case DOT_PRODUCT: + SubspaceOutputMIP subspaceOutputMIP = generateSubSpaceMIP(vector, centroid, destination); + corrections = new float[3]; + corrections[0] = subspaceOutputMIP.OOQ(); + corrections[1] = subspaceOutputMIP.normOC(); + corrections[2] = subspaceOutputMIP.oDotC(); + break; + default: + throw new UnsupportedOperationException( + "Unsupported similarity function: " + similarityFunction); + } + + return corrections; + } + + 
private record QuantResult(byte[] result, int quantizedSum) {} + + private static QuantResult quantize(float[] vector, float lower, float width) { + // FIXME: speed up with panama? and/or use existing scalar quantization utils in Lucene? + byte[] result = new byte[vector.length]; + float oneOverWidth = 1.0f / width; + int sumQ = 0; + for (int i = 0; i < vector.length; i++) { + byte res = (byte) ((vector[i] - lower) * oneOverWidth); + result[i] = res; + sumQ += res; + } + + return new QuantResult(result, sumQ); + } + + /** + * Quantizes the given vector to int4 precision and returns an array of corrective factors. + * + *

Corrective factors are used for scoring error correction. For the dot-product family of + * + *

    + *
  • The lower bound for the int4 quantized vector + *
  • The width for int4 quantized vector + *
  • The norm of the centroid centered vector + *
  • The dot-product of the vector with the centroid + *
  • The sum of the quantized dimensions + *
+ * + * For euclidean: + * + *
    + *
  • The euclidean distance to the centroid + *
  • The lower bound for the int4 quantized vector + *
  • The width for int4 quantized vector + *
  • The sum of the quantized dimensions + *
+ * + * @param vector the vector to quantize + * @param destination the destination byte array to store the quantized vector + * @param centroid the centroid to use for quantization + * @return the corrective factors used for scoring error correction. + */ + public float[] quantizeForQuery(float[] vector, byte[] destination, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64); + + if (this.discretizedDimensions != (destination.length * 8) / BQSpaceUtils.B_QUERY) { + throw new IllegalArgumentException( + "vector and quantized vector destination must be compatible dimensions: " + + vector.length + + " [ " + + this.discretizedDimensions + + " ]" + + "!= (" + + destination.length + + " * 8) / " + + BQSpaceUtils.B_QUERY); + } + + if (vector.length != centroid.length) { + throw new IllegalArgumentException( + "vector and centroid dimensions must be the same: " + + vector.length + + "!= " + + centroid.length); + } + + // FIXME: make a copy of vector so we don't overwrite it here? + // ... 
(could subtractInPlace but the passed vector is modified) <<--- + float[] vmC = ArrayUtil.copyArray(vector); + VectorUtil.subtract(vmC, centroid); + + // FIXME: should other similarity functions behave like MIP on query like COSINE + float normVmC = 0f; + if (similarityFunction != EUCLIDEAN) { + normVmC = VectorUtil.l2Norm(vmC); + VectorUtil.divide(vmC, normVmC); + } + float[] range = range(vmC); + float lower = range[0]; + float upper = range[1]; + // Δ := (𝑣𝑟 − 𝑣𝑙)/(2𝐵𝑞 − 1) + float width = (upper - lower) / ((1 << BQSpaceUtils.B_QUERY) - 1); + + QuantResult quantResult = quantize(vmC, lower, width); + byte[] byteQuery = quantResult.result(); + + // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷 + // q¯ is an approximation of q′ (scalar quantized approximation) + // FIXME: vectors need to be padded but that's expensive; update transponseBin to deal + byteQuery = BQSpaceUtils.pad(byteQuery, discretizedDimensions); + BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, destination); + + final float[] corrections; + if (similarityFunction != EUCLIDEAN) { + float vDotC = VectorUtil.dotProduct(vector, centroid); + corrections = new float[] {lower, width, normVmC, vDotC, quantResult.quantizedSum}; + } else { + float distToCentroid = (float) Math.sqrt(VectorUtil.squareDistance(vector, centroid)); + corrections = new float[] {distToCentroid, lower, width, quantResult.quantizedSum}; + } + + return corrections; + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 9273f7c5a813..8f7fa6048b80 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -29,6 +29,7 @@ import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; 
+import jdk.incubator.vector.LongVector; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.Vector; import jdk.incubator.vector.VectorMask; @@ -764,6 +765,122 @@ private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int l return acc1.add(acc2).reduceLanes(ADD); } + @Override + public long ipByteBinByte(byte[] q, byte[] d) { + // 128 / 8 == 16 + if (d.length >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { + if (VECTOR_BITSIZE >= 256) { + return ipByteBin256(q, d); + } else if (VECTOR_BITSIZE == 128) { + return ipByteBin128(q, d); + } + } + return DefaultVectorUtilSupport.ipByteBinByteImpl(q, d); + } + + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + static long ipByteBin256(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(d.length); + var sum0 = LongVector.zero(LongVector.SPECIES_256); + var sum1 = LongVector.zero(LongVector.SPECIES_256); + var sum2 = LongVector.zero(LongVector.SPECIES_256); + var sum3 = LongVector.zero(LongVector.SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = 
sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LongVector.SPECIES_128); + var sum1 = LongVector.zero(LongVector.SPECIES_128); + var sum2 = LongVector.zero(LongVector.SPECIES_128); + var sum3 = LongVector.zero(LongVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs(); + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs(); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + for (; i < d.length; i++) { + subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF); + subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF); + } + return subRet0 + 
(subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + public static long ipByteBin128(byte[] q, byte[] d) { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + + var sum0 = IntVector.zero(IntVector.SPECIES_128); + var sum1 = IntVector.zero(IntVector.SPECIES_128); + var sum2 = IntVector.zero(IntVector.SPECIES_128); + var sum3 = IntVector.zero(IntVector.SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(d.length); + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts(); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + for (; i < d.length; i++) { + int dValue = d[i]; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + // Experiments suggest that we need at least 8 lanes so that the overhead of going with the vector // approach and counting trues on vector 
masks pays off. private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8; diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index cb5fee62aeec..bd0c97688f54 100644 --- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -16,3 +16,5 @@ org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat +org.apache.lucene.codecs.lucene101.Lucene101HnswBinaryQuantizedVectorsFormat +org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java new file mode 100644 index 000000000000..3c0495102fd9 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene101; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +public class TestLucene101BinaryFlatVectorsScorer extends LuceneTestCase { + + public void testScore() throws IOException { + int dimensions = random().nextInt(1, 4097); + int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64); + + int randIdx = random().nextInt(VectorSimilarityFunction.values().length); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; + + float[] centroid = new float[dimensions]; + for (int j = 0; j < dimensions; j++) { + centroid[j] = random().nextFloat(-50f, 50f); + } + if (similarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(centroid); + } + + byte[] vector = new byte[discretizedDimensions / 8 * BQSpaceUtils.B_QUERY]; + random().nextBytes(vector); + float distanceToCentroid = random().nextFloat(0f, 10_000.0f); + float vl = random().nextFloat(-1000f, 1000f); + float width = random().nextFloat(0f, 1000f); + short quantizedSum = (short) random().nextInt(0, 4097); + float normVmC = random().nextFloat(-1000f, 1000f); + float vDotC = random().nextFloat(-1000f, 1000f); + final float[] queryCorrections = + similarityFunction != VectorSimilarityFunction.EUCLIDEAN + ? 
new float[] {vl, width, normVmC, vDotC, quantizedSum} + : new float[] {distanceToCentroid, vl, width, quantizedSum}; + Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector = + new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections); + + BinarizedByteVectorValues targetVectors = + new BinarizedByteVectorValues() { + + @Override + public BinaryQuantizer getQuantizer() { + int dimensions = 128; + return new BinaryQuantizer(dimensions, dimensions, VectorSimilarityFunction.EUCLIDEAN); + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public BinarizedByteVectorValues copy() { + return null; + } + + @Override + public byte[] vectorValue(int targetOrd) { + byte[] vectorBytes = new byte[discretizedDimensions / 8]; + random().nextBytes(vectorBytes); + return vectorBytes; + } + + @Override + public int size() { + return 1; + } + + @Override + public int dimension() { + return dimensions; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) { + if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) { + return new float[] {random().nextFloat(0f, 1000f), random().nextFloat(0f, 100f)}; + } + return new float[] { + random().nextFloat(-1000f, 1000f), + random().nextFloat(-1000f, 1000f), + random().nextFloat(-1000f, 1000f) + }; + } + }; + + Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = + new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + queryVector, targetVectors, similarityFunction); + + float score = scorer.score(0); + + assertTrue(score >= 0f); + } + + public void testScoreEuclidean() throws IOException { + int dimensions = 128; + + byte[] vector = + new byte[] { + -8, 10, -27, 112, -83, 36, -36, -122, -114, 82, 55, 33, -33, 120, 55, -99, -93, -86, -55, + 21, -121, 30, 111, 30, 0, 82, 21, 38, -120, -127, 40, -32, 78, -37, 42, -43, 122, 115, 30, + 115, 123, 108, 
-13, -65, 123, 124, -33, -68, 49, 5, 20, 58, 0, 12, 30, 30, 4, 97, 10, 66, + 4, 35, 1, 67 + }; + float distanceToCentroid = 157799.12f; + float vl = -57.883f; + float width = 9.972266f; + short quantizedSum = 795; + float[] queryCorrections = new float[] {distanceToCentroid, vl, width, quantizedSum}; + Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector = + new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections); + + BinarizedByteVectorValues targetVectors = + new BinarizedByteVectorValues() { + @Override + public BinaryQuantizer getQuantizer() { + int dimensions = 128; + return new BinaryQuantizer(dimensions, dimensions, VectorSimilarityFunction.EUCLIDEAN); + } + + @Override + public float[] getCentroid() { + return new float[] { + 26.7f, 16.2f, 10.913f, 10.314f, 12.12f, 14.045f, 15.887f, 16.864f, 32.232f, 31.567f, + 34.922f, 21.624f, 16.349f, 29.625f, 31.994f, 22.044f, 37.847f, 24.622f, 36.299f, + 27.966f, 14.368f, 19.248f, 30.778f, 35.927f, 27.019f, 16.381f, 17.325f, 16.517f, + 13.272f, 9.154f, 9.242f, 17.995f, 53.777f, 23.011f, 12.929f, 16.128f, 22.16f, 28.643f, + 25.861f, 27.197f, 59.883f, 40.878f, 34.153f, 22.795f, 24.402f, 37.427f, 34.19f, + 29.288f, 61.812f, 26.355f, 39.071f, 37.789f, 23.33f, 22.299f, 28.64f, 47.828f, + 52.457f, 21.442f, 24.039f, 29.781f, 27.707f, 19.484f, 14.642f, 28.757f, 54.567f, + 20.936f, 25.112f, 25.521f, 22.077f, 18.272f, 14.526f, 29.054f, 61.803f, 24.509f, + 37.517f, 35.906f, 24.106f, 22.64f, 32.1f, 48.788f, 60.102f, 39.625f, 34.766f, 22.497f, + 24.397f, 41.599f, 38.419f, 30.99f, 55.647f, 25.115f, 14.96f, 18.882f, 26.918f, + 32.442f, 26.231f, 27.107f, 26.828f, 15.968f, 18.668f, 14.071f, 10.906f, 8.989f, + 9.721f, 17.294f, 36.32f, 21.854f, 35.509f, 27.106f, 14.067f, 19.82f, 33.582f, 35.997f, + 33.528f, 30.369f, 36.955f, 21.23f, 15.2f, 30.252f, 34.56f, 22.295f, 29.413f, 16.576f, + 11.226f, 10.754f, 12.936f, 15.525f, 15.868f, 16.43f + }; + } + + @Override + public BinarizedByteVectorValues copy() { + 
return null; + } + + @Override + public byte[] vectorValue(int targetOrd) { + return new byte[] { + 44, 108, 120, -15, -61, -32, 124, 25, -63, -57, 6, 24, 1, -61, 1, 14 + }; + } + + @Override + public int size() { + return 1; + } + + @Override + public int dimension() { + return dimensions; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) { + return new float[] {355.78073f, 0.7636705f}; + } + }; + + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + + Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = + new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + queryVector, targetVectors, similarityFunction); + + assertEquals(1f / (1f + 245482.47f), scorer.score(0), 0.1f); + } + + public void testScoreMIP() throws IOException { + int dimensions = 768; + + byte[] vector = + new byte[] { + -76, 44, 81, 31, 30, -59, 56, -118, -36, 45, -11, 8, -61, 95, -100, 18, -91, -98, -46, 31, + -8, 82, -42, 121, 75, -61, 125, -21, -82, 16, 21, 40, -1, 12, -92, -22, -49, -92, -19, + -32, -56, -34, 60, -100, 69, 13, 60, -51, 90, 4, -77, 63, 124, 69, 88, 73, -72, 29, -96, + 44, 69, -123, -59, -94, 84, 80, -61, 27, -37, -92, -51, -86, 19, -55, -36, -2, 68, -37, + -128, 59, -47, 119, -53, 56, -12, 37, 27, 119, -37, 125, 78, 19, 15, -9, 94, 100, -72, 55, + 86, -48, 26, 10, -112, 28, -15, -64, -34, 55, -42, -31, -96, -18, 60, -44, 69, 106, -20, + 15, 47, 49, -122, -45, 119, 101, 22, 77, 108, -15, -71, -28, -43, -68, -127, -86, -118, + -51, 121, -65, -10, -49, 115, -6, -61, -98, 21, 41, 56, 29, -16, -82, 4, 72, -77, 23, 23, + -32, -98, 112, 27, -4, 91, -69, 102, -114, 16, -20, -76, -124, 43, 12, 3, -30, 42, -44, + -88, -72, -76, -94, -73, 46, -17, 4, -74, -44, 53, -11, -117, -105, -113, -37, -43, -128, + -70, 56, -68, -100, 56, -20, 77, 12, 17, -119, -17, 59, -10, -26, 29, 42, -59, -28, -28, + 60, -34, 60, -24, 80, -81, 24, 122, 
127, 62, 124, -5, -11, 59, -52, 74, -29, -116, 3, -40, + -99, -24, 11, -10, 95, 21, -38, 59, -52, 29, 58, 112, 100, -106, -90, 71, 72, 57, 95, 98, + 96, -41, -16, 50, -18, 123, -36, 74, -101, 17, 50, 48, 96, 57, 7, 81, -16, -32, -102, -24, + -71, -10, 37, -22, 94, -36, -52, -71, -47, 47, -1, -31, -10, -126, -15, -123, -59, 71, + -49, 67, 99, -57, 21, -93, -13, -18, 54, -112, -60, 9, 25, -30, -47, 26, 27, 26, -63, 1, + -63, 18, -114, 80, 110, -123, 0, -63, -126, -128, 10, -60, 51, -71, 28, 114, -4, 53, 10, + 23, -96, 9, 32, -22, 5, -108, 33, 98, -59, -106, -126, 73, 72, -72, -73, -60, -96, -99, + 31, 40, 15, -19, 17, -128, 33, -75, 96, -18, -47, 75, 27, -60, -16, -82, 13, 21, 37, 23, + 70, 9, -39, 16, -127, 35, -78, 64, 99, -46, 1, 28, 65, 125, 14, 42, 26 + }; + float vl = -0.10079563f; + float width = 0.014609014f; + short quantizedSum = 5306; + float normVmC = 9.766797f; + float vDotC = 133.56123f; + float cDotC = 132.20227f; + float[] queryCorrections = new float[] {vl, width, normVmC, vDotC, quantizedSum}; + Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector = + new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections); + + BinarizedByteVectorValues targetVectors = + new BinarizedByteVectorValues() { + + @Override + public float getCentroidDP() { + return cDotC; + } + + @Override + public BinaryQuantizer getQuantizer() { + int dimensions = 768; + return new BinaryQuantizer( + dimensions, dimensions, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + } + + @Override + public float[] getCentroid() { + return new float[] { + 0.16672021f, 0.11700719f, 0.013227397f, 0.09305186f, -0.029422699f, 0.17622353f, + 0.4267106f, -0.297038f, 0.13915674f, 0.38441318f, -0.486725f, -0.15987667f, + -0.19712289f, 0.1349074f, -0.19016947f, -0.026179956f, 0.4129807f, 0.14325741f, + -0.09106042f, 0.06876218f, -0.19389102f, 0.4467732f, 0.03169017f, -0.066950575f, + -0.044301506f, -0.0059755715f, -0.33196586f, 0.18213534f, -0.25065416f, 0.30251458f, + 
0.3448419f, -0.14900115f, -0.07782894f, 0.3568707f, -0.46595258f, 0.37295088f, + -0.088741764f, 0.17248306f, -0.0072736046f, 0.32928637f, 0.13216197f, 0.032092985f, + 0.21553043f, 0.016091486f, 0.31958902f, 0.0133126f, 0.1579258f, 0.018537233f, + 0.046248164f, -0.0048194043f, -0.2184672f, -0.26273906f, -0.110678785f, -0.04542999f, + -0.41625032f, 0.46025568f, -0.16116948f, 0.4091706f, 0.18427321f, 0.004736977f, + 0.16289745f, -0.05330932f, -0.2694863f, -0.14762327f, 0.17744702f, 0.2445075f, + 0.14377175f, 0.37390858f, 0.16165806f, 0.17177118f, 0.097307935f, 0.36326465f, + 0.23221572f, 0.15579978f, -0.065486655f, -0.29006517f, -0.009194494f, 0.009019374f, + 0.32154799f, -0.23186184f, 0.46485493f, -0.110756285f, -0.18604982f, 0.35027295f, + 0.19815539f, 0.47386464f, -0.031379268f, 0.124035835f, 0.11556784f, 0.4304302f, + -0.24455063f, 0.1816723f, 0.034300473f, -0.034347706f, 0.040140998f, 0.1389901f, + 0.22840638f, -0.19911191f, 0.07563166f, -0.2744902f, 0.13114859f, -0.23862572f, + -0.31404558f, 0.41355187f, 0.12970817f, -0.35403475f, -0.2714075f, 0.07231573f, + 0.043893218f, 0.30324167f, 0.38928393f, -0.1567055f, -0.0083288215f, 0.0487653f, + 0.12073729f, -0.01582117f, 0.13381198f, -0.084824145f, -0.15329859f, -1.120622f, + 0.3972598f, 0.36022213f, -0.29826534f, -0.09468781f, 0.03550699f, -0.21630692f, + 0.55655843f, -0.14842057f, 0.5924833f, 0.38791573f, 0.1502777f, 0.111737385f, + 0.1926823f, 0.66021144f, 0.25601995f, 0.28220543f, 0.10194068f, 0.013066262f, + -0.09348819f, -0.24085014f, -0.17843121f, -0.012598432f, 0.18757571f, 0.48543528f, + -0.059388146f, 0.1548026f, 0.041945867f, 0.3322589f, 0.012830887f, 0.16621992f, + 0.22606649f, 0.13959105f, -0.16688728f, 0.47194278f, -0.12767595f, 0.037815034f, + 0.441938f, 0.07875027f, 0.08625042f, 0.053454693f, 0.74093896f, 0.34662113f, + 0.009829135f, -0.033400282f, 0.030965377f, 0.17645596f, 0.083803624f, 0.32578796f, + 0.49538168f, -0.13212465f, -0.39596975f, 0.109529115f, 0.2815771f, -0.051440604f, + 0.21889819f, 
0.25598505f, 0.012208843f, -0.012405662f, 0.3248759f, 0.00997502f, + 0.05999008f, 0.03562817f, 0.19007418f, 0.24805716f, 0.5926766f, 0.26937613f, + 0.25856f, -0.05798439f, -0.29168302f, 0.14050555f, 0.084851265f, -0.03763504f, + 0.8265359f, -0.23383066f, -0.042164285f, 0.19120507f, -0.12189065f, 0.3864055f, + -0.19823311f, 0.30280992f, 0.10814344f, -0.164514f, -0.22905481f, 0.13680641f, + 0.4513772f, -0.514546f, -0.061746247f, 0.11598224f, -0.23093395f, -0.09735358f, + 0.02767051f, 0.11594536f, 0.17106244f, 0.21301728f, -0.048222974f, 0.2212131f, + -0.018857865f, -0.09783516f, 0.42156664f, -0.14032331f, -0.103861615f, 0.4190284f, + 0.068923555f, -0.015083771f, 0.083590426f, -0.15759592f, -0.19096768f, -0.4275228f, + 0.12626286f, 0.12192557f, 0.4157616f, 0.048780657f, 0.008426048f, -0.0869124f, + 0.054927208f, 0.28417027f, 0.29765493f, 0.09203619f, -0.14446871f, -0.117514975f, + 0.30662632f, 0.24904715f, -0.19551662f, -0.0045785015f, 0.4217626f, -0.31457824f, + 0.23381722f, 0.089111514f, -0.27170828f, -0.06662652f, 0.10011391f, -0.090274535f, + 0.101849966f, 0.26554734f, -0.1722843f, 0.23296228f, 0.25112453f, -0.16790418f, + 0.010348314f, 0.05061285f, 0.38003662f, 0.0804625f, 0.3450673f, 0.364368f, + -0.2529952f, -0.034065288f, 0.22796603f, 0.5457553f, 0.11120353f, 0.24596325f, + 0.42822433f, -0.19215727f, -0.06974534f, 0.19388479f, -0.17598474f, -0.08769705f, + 0.12769659f, 0.1371616f, -0.4636819f, 0.16870509f, 0.14217548f, 0.04412187f, + -0.20930687f, 0.0075530168f, 0.10065227f, 0.45334083f, -0.1097471f, -0.11139921f, + -0.31835595f, -0.057386875f, 0.16285825f, 0.5088513f, -0.06318843f, -0.34759882f, + 0.21132466f, 0.33609292f, 0.04858872f, -0.058759f, 0.22845529f, -0.07641319f, + 0.5452827f, -0.5050389f, 0.1788054f, 0.37428045f, 0.066334985f, -0.28162515f, + -0.15629752f, 0.33783385f, -0.0832242f, 0.29144394f, 0.47892854f, -0.47006592f, + -0.07867588f, 0.3872869f, 0.28053126f, 0.52399015f, 0.21979983f, 0.076880336f, + 0.47866163f, 0.252952f, -0.1323851f, 
-0.22225754f, -0.38585815f, 0.12967427f, + 0.20340872f, -0.326928f, 0.09636557f, -0.35929212f, 0.5413311f, 0.019960884f, + 0.33512768f, 0.15133342f, -0.14124066f, -0.1868793f, -0.07862198f, 0.22739467f, + 0.19598985f, 0.34314656f, -0.05071516f, -0.21107961f, 0.19934991f, 0.04822684f, + 0.15060754f, 0.26586458f, -0.15528078f, 0.123646654f, 0.14450715f, -0.12574252f, + 0.30608323f, 0.018549249f, 0.36323825f, 0.06762097f, 0.08562406f, -0.07863075f, + 0.15975896f, 0.008347004f, 0.37931192f, 0.22957338f, 0.33606857f, -0.25204057f, + 0.18126069f, 0.41903302f, 0.20244692f, -0.053850617f, 0.23088565f, 0.16085246f, + 0.1077502f, -0.12445943f, 0.115779735f, 0.124704875f, 0.13076028f, -0.11628619f, + -0.12580182f, 0.065204754f, -0.26290357f, -0.23539798f, -0.1855292f, 0.39872098f, + 0.44495568f, 0.05491784f, 0.05135692f, 0.624011f, 0.22839564f, 0.0022447354f, + -0.27169296f, -0.1694988f, -0.19106841f, 0.0110123325f, 0.15464798f, -0.16269256f, + 0.04033836f, -0.11792753f, 0.17172396f, -0.08912173f, -0.30929542f, -0.03446989f, + -0.21738084f, 0.39657044f, 0.33550346f, -0.06839139f, 0.053675443f, 0.33783767f, + 0.22576828f, 0.38280004f, 4.1448855f, 0.14225426f, 0.24038498f, 0.072373435f, + -0.09465926f, -0.016144043f, 0.40864578f, -0.2583055f, 0.031816103f, 0.062555805f, + 0.06068663f, 0.25858644f, -0.10598804f, 0.18201788f, -0.00090025424f, 0.085680895f, + 0.4304161f, 0.028686283f, 0.027298616f, 0.27473378f, -0.3888415f, 0.44825438f, + 0.3600378f, 0.038944595f, 0.49292335f, 0.18556066f, 0.15779617f, 0.29989767f, + 0.39233804f, 0.39759228f, 0.3850708f, -0.0526475f, 0.18572918f, 0.09667526f, + -0.36111078f, 0.3439669f, 0.1724522f, 0.14074509f, 0.26097745f, 0.16626832f, + -0.3062964f, -0.054877423f, 0.21702516f, 0.4736452f, 0.2298038f, -0.2983771f, + 0.118479654f, 0.35940516f, 0.12212727f, 0.17234904f, 0.30632678f, 0.09207966f, + -0.14084268f, -0.19737118f, 0.12442629f, 0.52454203f, 0.1266684f, 0.3062802f, + 0.121598125f, -0.09156268f, 0.11491686f, -0.105715364f, 0.19831072f, 
0.061421417f, + -0.41778997f, 0.14488487f, 0.023310646f, 0.27257463f, 0.16821945f, -0.16702746f, + 0.263203f, 0.33512688f, 0.35117313f, -0.31740817f, -0.14203706f, 0.061256267f, + -0.19764185f, 0.04822579f, -0.0016218472f, -0.025792575f, 0.4885193f, -0.16942391f, + -0.04156327f, 0.15908112f, -0.06998626f, 0.53907114f, 0.10317832f, -0.365468f, + 0.4729886f, 0.14291425f, 0.32812154f, -0.0273262f, 0.31760117f, 0.16925456f, + 0.21820979f, 0.085142255f, 0.16118735f, -3.7089362f, 0.251577f, 0.18394576f, + 0.027926167f, 0.15720351f, 0.13084261f, 0.16240814f, 0.23045056f, -0.3966458f, + 0.22822891f, -0.061541352f, 0.028320132f, -0.14736478f, 0.184569f, 0.084853746f, + 0.15172474f, 0.08277542f, 0.27751622f, 0.23450488f, -0.15349835f, 0.29665688f, + 0.32045734f, 0.20012043f, -0.2749372f, 0.011832386f, 0.05976605f, 0.018300122f, + -0.07855043f, -0.075900674f, 0.0384252f, -0.15101928f, 0.10922137f, 0.47396383f, + -0.1771141f, 0.2203417f, 0.33174303f, 0.36640546f, 0.10906258f, 0.13765177f, + 0.2488032f, -0.061588854f, 0.20347528f, 0.2574979f, 0.22369152f, 0.18777567f, + -0.0772263f, -0.1353299f, 0.087077625f, -0.05409276f, 0.027534787f, 0.08053508f, + 0.3403908f, -0.15362988f, 0.07499862f, 0.54367846f, -0.045938436f, 0.12206868f, + 0.031069376f, 0.2972343f, 0.3235321f, -0.053970363f, -0.0042564687f, 0.21447177f, + 0.023565233f, -0.1286087f, -0.047359955f, 0.23021339f, 0.059837278f, 0.19709614f, + -0.17340347f, 0.11572943f, 0.21720429f, 0.29375625f, -0.045433592f, 0.033339307f, + 0.24594454f, -0.021661613f, -0.12823369f, 0.41809165f, 0.093840264f, -0.007481906f, + 0.22441079f, -0.45719734f, 0.2292629f, 2.675806f, 0.3690025f, 2.1311781f, + 0.07818368f, -0.17055893f, 0.3162922f, -0.2983149f, 0.21211359f, 0.037087034f, + 0.021580033f, 0.086415835f, 0.13541797f, -0.12453424f, 0.04563163f, -0.082379065f, + -0.15938349f, 0.38595748f, -0.8796574f, -0.080991246f, 0.078572094f, 0.20274459f, + 0.009252143f, -0.12719384f, 0.105845824f, 0.1592398f, -0.08656061f, -0.053054806f, + 
0.090986334f, -0.02223379f, -0.18215932f, -0.018316114f, 0.1806707f, 0.24788831f, + -0.041049056f, 0.01839475f, 0.19160001f, -0.04827654f, 4.4070687f, 0.12640671f, + -0.11171499f, -0.015480781f, 0.14313947f, 0.10024215f, 0.4129662f, 0.038836367f, + -0.030228542f, 0.2948598f, 0.32946473f, 0.2237934f, 0.14260699f, -0.044821896f, + 0.23791742f, 0.079720296f, 0.27059034f, 0.32129505f, 0.2725177f, 0.06883333f, + 0.1478041f, 0.07598411f, 0.27230525f, -0.04704308f, 0.045167264f, 0.215413f, + 0.20359069f, -0.092178136f, -0.09523752f, 0.21427691f, 0.10512272f, 5.1295033f, + 0.040909242f, 0.007160441f, -0.192866f, -0.102640584f, 0.21103396f, -0.006780398f, + -0.049653083f, -0.29426834f, -0.0038102255f, -0.13842082f, 0.06620181f, -0.3196518f, + 0.33279592f, 0.13845938f, 0.16162738f, -0.24798508f, -0.06672485f, 0.195944f, + -0.11957207f, 0.44237947f, -0.07617347f, 0.13575341f, -0.35074243f, -0.093798876f, + 0.072853446f, -0.20490398f, 0.26504788f, -0.046076056f, 0.16488416f, 0.36007464f, + 0.20955376f, -0.3082038f, 0.46533757f, -0.27326992f, -0.14167665f, 0.25017953f, + 0.062622115f, 0.14057694f, -0.102370486f, 0.33898357f, 0.36456722f, -0.10120469f, + -0.27838466f, -0.11779602f, 0.18517569f, -0.05942488f, 0.076405466f, 0.007960496f, + 0.0443746f, 0.098998964f, -0.01897129f, 0.8059487f, 0.06991939f, 0.26562217f, + 0.26942885f, 0.11432197f, -0.0055776504f, 0.054493718f, -0.13086213f, 0.6841702f, + 0.121975765f, 0.02787146f, 0.29039973f, 0.30943078f, 0.21762547f, 0.28751117f, + 0.027524523f, 0.5315654f, -0.22451901f, -0.13782433f, 0.08228316f, 0.07808882f, + 0.17445615f, -0.042489477f, 0.13232234f, 0.2756272f, -0.18824948f, 0.14326479f, + -0.119312495f, 0.011788091f, -0.22103515f, -0.2477118f, -0.10513839f, 0.034028634f, + 0.10693818f, 0.03057979f, 0.04634646f, 0.2289361f, 0.09981585f, 0.26901972f, + 0.1561221f, -0.10639886f, 0.36466748f, 0.06350991f, 0.027927283f, 0.11919768f, + 0.23290513f, -0.03417105f, 0.16698854f, -0.19243467f, 0.28430334f, 0.03754995f, + -0.08697018f, 
0.20413163f, -0.27218238f, 0.13707504f, -0.082289375f, 0.03479585f, + 0.2298305f, 0.4983682f, 0.34522808f, -0.05711886f, -0.10568684f, -0.07771385f + }; + } + + @Override + public BinarizedByteVectorValues copy() { + return null; + } + + @Override + public byte[] vectorValue(int targetOrd) { + return new byte[] { + -88, -3, 60, -75, -38, 79, 84, -53, -116, -126, 19, -19, -21, -80, 69, 101, -71, 53, + 101, -124, -24, -76, 92, -45, 108, -107, -18, 102, 23, -80, -47, 116, 87, -50, 27, + -31, -10, -13, 117, -88, -27, -93, -98, -39, 30, -109, -114, 5, -15, 98, -82, 81, 83, + 118, 30, -118, -12, -95, 121, 125, -13, -88, 75, -85, -56, -126, 82, -59, 48, -81, 67, + -63, 81, 24, -83, 95, -44, 103, 3, -40, -13, -41, -29, -60, 1, 65, -4, -110, -40, 34, + 118, 51, -76, 75, 70, -51 + }; + } + + @Override + public int size() { + return 1; + } + + @Override + public int dimension() { + return dimensions; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + + @Override + public float[] getCorrectiveTerms(int vectorOrd) { + return new float[] {0.7882396f, 5.0889387f, 131.485660f}; + } + }; + + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + + Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = + new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + queryVector, targetVectors, similarityFunction); + + assertEquals(129.64046f, scorer.score(0), 0.0001f); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..011b42d9c664 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene101; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import 
org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +public class TestLucene101BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene101Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene101BinaryQuantizedVectorsFormat(); + } + }; + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene101BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = + "Lucene101BinaryQuantizedVectorsFormat(name=Lucene101BinaryQuantizedVectorsFormat, 
flatVectorScorer=Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = + ((Lucene101BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + int descritizedDimension = BQSpaceUtils.discretize(dims, 64); + BinaryQuantizer quantizer = + new BinaryQuantizer(dims, descritizedDimension, similarityFunction); + byte[] expectedVector = new 
byte[BQSpaceUtils.discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = + new Lucene101BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + float[] corrections = + quantizer.quantizeForIndex( + vectorValues.vectorValue(docIndexIterator.index()), expectedVector, centroid); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals( + corrections.length, + qvectorValues.getCorrectiveTerms(docIndexIterator.index()).length); + for (int i = 0; i < corrections.length; i++) { + assertEquals( + corrections[i], + qvectorValues.getCorrectiveTerms(docIndexIterator.index())[i], + 0.00001f); + } + } + } + } + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 000000000000..a3f7f1827520 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene101; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.SameThreadExecutorService; + +public class TestLucene101HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene101Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene101HnswBinaryQuantizedVectorsFormat(); + } + }; + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", 
Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new Lucene101HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "Lucene101HnswBinaryQuantizedVectorsFormat(name=Lucene101HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene101BinaryQuantizedVectorsFormat(name=Lucene101BinaryQuantizedVectorsFormat," + + " flatVectorScorer=Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = + format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + TopDocs td = + r.searchNearestVectors("f", randomVector(vector.length), 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + } + } + } + } + + public void testLimits() { + expectThrows( + IllegalArgumentException.class, + () -> new 
Lucene101HnswBinaryQuantizedVectorsFormat(-1, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene101HnswBinaryQuantizedVectorsFormat(0, 20)); + expectThrows( + IllegalArgumentException.class, () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, -1)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene101HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> + new Lucene101HnswBinaryQuantizedVectorsFormat( + 20, 100, 1, new SameThreadExecutorService())); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. 
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java index 34a0e5230022..ea6c183c7e7a 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java @@ -21,8 +21,8 @@ public abstract class BaseVectorizationTestCase extends LuceneTestCase { - protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider(); - protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true); + protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider(); + protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider(); @BeforeClass public static void beforeClass() throws Exception { @@ -30,4 +30,12 @@ public static void beforeClass() throws Exception { "Test only works when JDK's vector incubator module is enabled.", PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass()); } + + public static VectorizationProvider defaultProvider() { + return new DefaultVectorizationProvider(); + } + + public static VectorizationProvider maybePanamaProvider() { + return VectorizationProvider.lookup(true); + } } diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 7064955cb5f3..3355962d7f1a 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -16,10 +16,13 
@@ */ package org.apache.lucene.internal.vectorization; +import static org.apache.lucene.util.VectorUtil.B_QUERY; + import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import java.util.Arrays; import java.util.function.ToDoubleFunction; import java.util.function.ToIntFunction; +import java.util.function.ToLongFunction; import java.util.stream.IntStream; public class TestVectorUtilSupport extends BaseVectorizationTestCase { @@ -142,6 +145,27 @@ static byte[] pack(byte[] unpacked) { return packed; } + public void testIpByteBin() { + var d = new byte[size]; + var q = new byte[size * B_QUERY]; + random().nextBytes(d); + random().nextBytes(q); + assertLongReturningProviders(p -> p.ipByteBinByte(q, d)); + } + + public void testIpByteBinBoundaries() { + var d = new byte[size]; + var q = new byte[size * B_QUERY]; + + Arrays.fill(d, Byte.MAX_VALUE); + Arrays.fill(q, Byte.MAX_VALUE); + assertLongReturningProviders(p -> p.ipByteBinByte(q, d)); + + Arrays.fill(d, Byte.MIN_VALUE); + Arrays.fill(q, Byte.MIN_VALUE); + assertLongReturningProviders(p -> p.ipByteBinByte(q, d)); + } + private void assertFloatReturningProviders(ToDoubleFunction func) { assertEquals( func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()), @@ -154,4 +178,10 @@ private void assertIntReturningProviders(ToIntFunction func) func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport())); } + + private void assertLongReturningProviders(ToLongFunction func) { + assertEquals( + func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()), + func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport())); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java index 6e449a550028..7b56ef5ebedf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java @@ -16,8 +16,14 @@ 
*/ package org.apache.lucene.util; +import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween; +import static org.apache.lucene.util.VectorUtil.B_QUERY; + +import java.util.Arrays; import java.util.Random; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase; +import org.apache.lucene.internal.vectorization.VectorizationProvider; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; @@ -129,6 +135,45 @@ public void testExtremeNumerics() { } } + public void testPopCount() { + assertEquals(0, VectorUtil.popCount(new byte[] {})); + assertEquals(1, VectorUtil.popCount(new byte[] {1})); + assertEquals(2, VectorUtil.popCount(new byte[] {2, 1})); + assertEquals(2, VectorUtil.popCount(new byte[] {8, 0, 1})); + assertEquals(4, VectorUtil.popCount(new byte[] {7, 1})); + + int iterations = atLeast(50); + for (int i = 0; i < iterations; i++) { + int size = random().nextInt(5000); + var a = new byte[size]; + random().nextBytes(a); + assertEquals(popcount(a, 0, a, size), VectorUtil.popCount(a)); + } + } + + public void testNorm() { + assertEquals(3.0f, VectorUtil.l2Norm(new float[] {3}), DELTA); + assertEquals(5.0f, VectorUtil.l2Norm(new float[] {5}), DELTA); + assertEquals(4.0f, VectorUtil.l2Norm(new float[] {2, 2, 2, 2}), DELTA); + assertEquals(9.0f, VectorUtil.l2Norm(new float[] {3, 3, 3, 3, 3, 3, 3, 3, 3}), DELTA); + } + + public void testSubtract() { + var a = new float[] {3}; + VectorUtil.subtract(a, new float[] {2}); + assertArrayEquals(new float[] {1}, a, (float) DELTA); + a = new float[] {3, 3, 3}; + VectorUtil.subtract(a, new float[] {1, 2, 3}); + assertArrayEquals(new float[] {2, 1, 0}, a, (float) DELTA); + } + + public void testL2Norm() { + assertEquals(3.0f, VectorUtil.l2Norm(new float[] {3}), DELTA); + assertEquals(5.0f, VectorUtil.l2Norm(new float[] {5}), DELTA); + assertEquals(4.0f, VectorUtil.l2Norm(new 
float[] {2, 2, 2, 2}), DELTA); + assertEquals(9.0f, VectorUtil.l2Norm(new float[] {3, 3, 3, 3, 3, 3, 3, 3, 3}), DELTA); + } + private static float l2(float[] v) { float l2 = 0; for (float x : v) { @@ -354,6 +399,129 @@ private static int xorBitCount(byte[] a, byte[] b) { return res; } + public void testIpByteBinInvariants() { + int iterations = atLeast(10); + for (int i = 0; i < iterations; i++) { + int size = randomIntBetween(random(), 1, 10); + var d = new byte[size]; + var q = new byte[size * B_QUERY - 1]; + expectThrows(IllegalArgumentException.class, () -> VectorUtil.ipByteBinByte(q, d)); + } + } + + static final VectorizationProvider defaultedProvider = + BaseVectorizationTestCase.defaultProvider(); + static final VectorizationProvider defOrPanamaProvider = + BaseVectorizationTestCase.maybePanamaProvider(); + + public void testBasicIpByteBin() { + testBasicIpByteBinImpl(VectorUtil::ipByteBinByte); + testBasicIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte); + testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); + } + + interface IpByteBin { + long apply(byte[] q, byte[] d); + } + + void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) { + assertEquals(15L, ipByteBinFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1})); + assertEquals(30L, ipByteBinFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2})); + + var d = new byte[] {1, 2, 3}; + var q = new byte[] {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}; + assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32 + assertEquals(60L, ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4}; + q = new byte[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}; + assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40 + assertEquals(75L, ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5}; + q = new byte[] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; + assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56 + assertEquals(105L, 
ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6}; + q = new byte[] {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}; + assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72 + assertEquals(135L, ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 + }; + assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96 + assertEquals(180L, ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, + 7, 8 + }; + assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104 + assertEquals(195L, ipByteBinFunc.apply(q, d)); + + d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}; + q = + new byte[] { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, + 4, 5, 6, 7, 8, 9 + }; + assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120 + assertEquals(225L, ipByteBinFunc.apply(q, d)); + } + + public void testIpByteBin() { + testIpByteBinImpl(VectorUtil::ipByteBinByte); + testIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte); + testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); + } + + void testIpByteBinImpl(IpByteBin ipByteBinFunc) { + int iterations = atLeast(50); + for (int i = 0; i < iterations; i++) { + int size = random().nextInt(5000); + var d = new byte[size]; + var q = new byte[size * B_QUERY]; + random().nextBytes(d); + random().nextBytes(q); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + + Arrays.fill(d, Byte.MAX_VALUE); + Arrays.fill(q, Byte.MAX_VALUE); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + + Arrays.fill(d, Byte.MIN_VALUE); + Arrays.fill(q, Byte.MIN_VALUE); + assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d)); + } + } + + static int 
scalarIpByteBin(byte[] q, byte[] d) { + int res = 0; + for (int i = 0; i < B_QUERY; i++) { + res += (popcount(q, i * d.length, d, d.length) << i); + } + return res; + } + + public static int popcount(byte[] a, int aOffset, byte[] b, int length) { + int res = 0; + for (int j = 0; j < length; j++) { + int value = (a[aOffset + j] & b[j]) & 0xFF; + for (int k = 0; k < Byte.SIZE; k++) { + if ((value & (1 << k)) != 0) { + ++res; + } + } + } + return res; + } + public void testFindNextGEQ() { int padding = TestUtil.nextInt(random(), 0, 5); int[] values = new int[128 + padding]; diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java new file mode 100644 index 000000000000..8dda4146a9a8 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util.quantization; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestBQSpaceUtils extends LuceneTestCase { + + private static float DELTA = Float.MIN_VALUE; + + public void testPadFloat() { + assertArrayEquals( + new float[] {1, 2, 3, 4}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 4), DELTA); + assertArrayEquals( + new float[] {1, 2, 3, 4}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 3), DELTA); + assertArrayEquals( + new float[] {1, 2, 3, 4, 0}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 5), DELTA); + } + + public void testPadByte() { + assertArrayEquals(new byte[] {1, 2, 3, 4}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 4)); + assertArrayEquals(new byte[] {1, 2, 3, 4}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 3)); + assertArrayEquals(new byte[] {1, 2, 3, 4, 0}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 5)); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java new file mode 100644 index 000000000000..5a16e60b5f4f --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.quantization; + +import java.util.Random; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; + +public class TestBinaryQuantization extends LuceneTestCase { + + public void testQuantizeForIndex() { + int dimensions = random().nextInt(1, 4097); + int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64); + + int randIdx = random().nextInt(VectorSimilarityFunction.values().length); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; + + BinaryQuantizer quantizer = new BinaryQuantizer(discretizedDimensions, similarityFunction); + + float[] centroid = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + centroid[i] = random().nextFloat(-50f, 50f); + } + + float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = random().nextFloat(-50f, 50f); + } + if (similarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + + byte[] destination = new byte[discretizedDimensions / 8]; + float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid); + + for (float correction : corrections) { + assertFalse(Float.isNaN(correction)); + } + + if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) { + assertEquals(3, corrections.length); + assertTrue(corrections[0] >= 0); + assertTrue(corrections[1] > 0); + } else { + assertEquals(2, corrections.length); + assertTrue(corrections[0] > 0); + assertTrue(corrections[1] > 0); + } + } + + public void testQuantizeForQuery() { + int dimensions = random().nextInt(1, 4097); + int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64); + + int randIdx = 
random().nextInt(VectorSimilarityFunction.values().length); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; + + BinaryQuantizer quantizer = new BinaryQuantizer(discretizedDimensions, similarityFunction); + + float[] centroid = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + centroid[i] = random().nextFloat(-50f, 50f); + } + + float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = random().nextFloat(-50f, 50f); + } + if (similarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + float cDotC = VectorUtil.dotProduct(centroid, centroid); + byte[] destination = new byte[discretizedDimensions / 8 * BQSpaceUtils.B_QUERY]; + float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid); + + if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) { + float lower = corrections[0]; + float width = corrections[1]; + float normVmC = corrections[2]; + float vDotC = corrections[3]; + float sumQ = corrections[4]; + assertTrue(sumQ >= 0); + assertFalse(Float.isNaN(lower)); + assertTrue(width >= 0); + assertTrue(normVmC >= 0); + assertFalse(Float.isNaN(vDotC)); + assertTrue(cDotC >= 0); + } else { + float distToC = corrections[0]; + float lower = corrections[1]; + float width = corrections[2]; + float sumQ = corrections[3]; + assertTrue(sumQ >= 0); + assertTrue(distToC >= 0); + assertFalse(Float.isNaN(lower)); + assertTrue(width >= 0); + } + } + + public void testQuantizeForIndexEuclidean() { + int dimensions = 128; + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.EUCLIDEAN); + float[] vector = + new float[] { + 0f, 0.0f, 16.0f, 35.0f, 5.0f, 32.0f, 31.0f, 14.0f, 10.0f, 11.0f, 78.0f, 55.0f, 10.0f, + 45.0f, 83.0f, 11.0f, 6.0f, 14.0f, 57.0f, 102.0f, 75.0f, 20.0f, 8.0f, 3.0f, 5.0f, 67.0f, + 17.0f, 19.0f, 26.0f, 5.0f, 0.0f, 1.0f, 22.0f, 60.0f, 26.0f, 
7.0f, 1.0f, 18.0f, 22.0f, + 84.0f, 53.0f, 85.0f, 119.0f, 119.0f, 4.0f, 24.0f, 18.0f, 7.0f, 7.0f, 1.0f, 81.0f, 106.0f, + 102.0f, 72.0f, 30.0f, 6.0f, 0.0f, 9.0f, 1.0f, 9.0f, 119.0f, 72.0f, 1.0f, 4.0f, 33.0f, + 119.0f, 29.0f, 6.0f, 1.0f, 0.0f, 1.0f, 14.0f, 52.0f, 119.0f, 30.0f, 3.0f, 0.0f, 0.0f, + 55.0f, 92.0f, 111.0f, 2.0f, 5.0f, 4.0f, 9.0f, 22.0f, 89.0f, 96.0f, 14.0f, 1.0f, 0.0f, + 1.0f, 82.0f, 59.0f, 16.0f, 20.0f, 5.0f, 25.0f, 14.0f, 11.0f, 4.0f, 0.0f, 0.0f, 1.0f, + 26.0f, 47.0f, 23.0f, 4.0f, 0.0f, 0.0f, 4.0f, 38.0f, 83.0f, 30.0f, 14.0f, 9.0f, 4.0f, 9.0f, + 17.0f, 23.0f, 41.0f, 0.0f, 0.0f, 2.0f, 8.0f, 19.0f, 25.0f, 23.0f + }; + byte[] destination = new byte[dimensions / 8]; + float[] centroid = + new float[] { + 27.054054f, 22.252253f, 25.027027f, 23.55856f, 31.099098f, 28.765766f, 31.64865f, + 30.981981f, 24.675676f, 21.81982f, 26.72973f, 25.486486f, 30.504505f, 35.216217f, + 28.306307f, 24.486486f, 29.675676f, 26.153152f, 31.315315f, 25.225225f, 29.234234f, + 30.855856f, 24.495495f, 29.828829f, 31.54955f, 24.36937f, 25.108109f, 24.873875f, + 22.918919f, 24.918919f, 29.027027f, 25.513514f, 27.64865f, 28.405405f, 23.603603f, + 17.900902f, 22.522522f, 24.855856f, 31.396397f, 32.585587f, 26.297297f, 27.468468f, + 19.675676f, 19.018019f, 24.801802f, 30.27928f, 27.945946f, 25.324324f, 29.918919f, + 27.864864f, 28.081081f, 23.45946f, 28.828829f, 28.387388f, 25.387388f, 27.90991f, + 25.621622f, 21.585585f, 26.378378f, 24.144144f, 21.666666f, 22.72973f, 26.837837f, + 22.747747f, 29.0f, 28.414415f, 24.612612f, 21.594595f, 19.117117f, 24.045046f, + 30.612612f, 27.55856f, 25.117117f, 27.783783f, 21.639639f, 19.36937f, 21.252253f, + 29.153152f, 29.216217f, 24.747747f, 28.252253f, 25.288288f, 25.738739f, 23.44144f, + 24.423424f, 23.693693f, 26.306307f, 29.162163f, 28.684685f, 34.648647f, 25.576576f, + 25.288288f, 29.63063f, 20.225225f, 25.72973f, 29.009008f, 28.666666f, 29.243244f, + 26.36937f, 25.864864f, 21.522522f, 21.414415f, 25.963964f, 26.054054f, 25.099098f, + 
30.477478f, 29.55856f, 24.837837f, 24.801802f, 21.18018f, 24.027027f, 26.360361f, + 33.153152f, 29.135136f, 30.486486f, 28.639639f, 27.576576f, 24.486486f, 26.297297f, + 21.774775f, 25.936937f, 35.36937f, 25.171171f, 30.405405f, 31.522522f, 29.765766f, + 22.324324f, 26.09009f + }; + float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid); + + assertEquals(2, corrections.length); + float distToCentroid = corrections[0]; + float magnitude = corrections[1]; + + assertEquals(387.90204f, distToCentroid, 0.0003f); + assertEquals(0.75916624f, magnitude, 0.0000001f); + assertArrayEquals( + new byte[] {20, 54, 56, 72, 97, -16, 62, 12, -32, -29, -125, 12, 0, -63, -63, -126}, + destination); + } + + public void testQuantizeForQueryEuclidean() { + int dimensions = 128; + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.EUCLIDEAN); + float[] vector = + new float[] { + 0.0f, 8.0f, 69.0f, 45.0f, 2.0f, 0f, 16.0f, 52.0f, 32.0f, 13.0f, 2.0f, 6.0f, 34.0f, 49.0f, + 45.0f, 83.0f, 6.0f, 2.0f, 26.0f, 57.0f, 14.0f, 46.0f, 19.0f, 9.0f, 4.0f, 13.0f, 53.0f, + 104.0f, 33.0f, 11.0f, 25.0f, 19.0f, 30.0f, 10.0f, 7.0f, 2.0f, 8.0f, 7.0f, 25.0f, 1.0f, + 2.0f, 25.0f, 24.0f, 28.0f, 61.0f, 83.0f, 41.0f, 9.0f, 14.0f, 3.0f, 7.0f, 114.0f, 114.0f, + 114.0f, 114.0f, 5.0f, 5.0f, 1.0f, 5.0f, 114.0f, 73.0f, 75.0f, 106.0f, 3.0f, 5.0f, 6.0f, + 6.0f, 8.0f, 15.0f, 45.0f, 2.0f, 15.0f, 7.0f, 114.0f, 103.0f, 6.0f, 5.0f, 4.0f, 9.0f, + 67.0f, 47.0f, 22.0f, 32.0f, 27.0f, 41.0f, 10.0f, 114.0f, 36.0f, 43.0f, 42.0f, 23.0f, 9.0f, + 7.0f, 30.0f, 114.0f, 19.0f, 7.0f, 5.0f, 6.0f, 6.0f, 21.0f, 48.0f, 2.0f, 1.0f, 0.0f, 8.0f, + 114.0f, 13.0f, 0.0f, 1.0f, 53.0f, 83.0f, 14.0f, 8.0f, 16.0f, 12.0f, 16.0f, 20.0f, 27.0f, + 87.0f, 45.0f, 50.0f, 15.0f, 5.0f, 5.0f, 6.0f, 32.0f, 49.0f + }; + byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY]; + float[] centroid = + new float[] { + 26.7f, 16.2f, 10.913f, 10.314f, 12.12f, 14.045f, 15.887f, 16.864f, 32.232f, 
31.567f, + 34.922f, 21.624f, 16.349f, 29.625f, 31.994f, 22.044f, 37.847f, 24.622f, 36.299f, 27.966f, + 14.368f, 19.248f, 30.778f, 35.927f, 27.019f, 16.381f, 17.325f, 16.517f, 13.272f, 9.154f, + 9.242f, 17.995f, 53.777f, 23.011f, 12.929f, 16.128f, 22.16f, 28.643f, 25.861f, 27.197f, + 59.883f, 40.878f, 34.153f, 22.795f, 24.402f, 37.427f, 34.19f, 29.288f, 61.812f, 26.355f, + 39.071f, 37.789f, 23.33f, 22.299f, 28.64f, 47.828f, 52.457f, 21.442f, 24.039f, 29.781f, + 27.707f, 19.484f, 14.642f, 28.757f, 54.567f, 20.936f, 25.112f, 25.521f, 22.077f, 18.272f, + 14.526f, 29.054f, 61.803f, 24.509f, 37.517f, 35.906f, 24.106f, 22.64f, 32.1f, 48.788f, + 60.102f, 39.625f, 34.766f, 22.497f, 24.397f, 41.599f, 38.419f, 30.99f, 55.647f, 25.115f, + 14.96f, 18.882f, 26.918f, 32.442f, 26.231f, 27.107f, 26.828f, 15.968f, 18.668f, 14.071f, + 10.906f, 8.989f, 9.721f, 17.294f, 36.32f, 21.854f, 35.509f, 27.106f, 14.067f, 19.82f, + 33.582f, 35.997f, 33.528f, 30.369f, 36.955f, 21.23f, 15.2f, 30.252f, 34.56f, 22.295f, + 29.413f, 16.576f, 11.226f, 10.754f, 12.936f, 15.525f, 15.868f, 16.43f + }; + float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid); + + float lower = corrections[1]; + float width = corrections[2]; + float sumQ = corrections[3]; + + assertEquals(729f, sumQ, 0); + assertEquals(-57.883f, lower, 0.001f); + assertEquals(9.972266f, width, 0.000001f); + assertArrayEquals( + new byte[] { + -77, -49, 73, -17, -89, 9, -43, -27, 40, 15, 42, 76, -122, 38, -22, -37, -96, 111, -63, + -102, -123, 23, 110, 127, 32, 95, 29, 106, -120, -121, -32, -94, 78, -98, 42, 95, 122, + 114, 30, 18, 91, 97, -5, -9, 123, 122, 31, -66, 49, 1, 20, 48, 0, 12, 30, 30, 4, 96, 2, 2, + 4, 33, 1, 65 + }, + destination); + } + + private float[] generateRandomFloatArray( + Random random, int dimensions, float lowerBoundInclusive, float upperBoundExclusive) { + float[] data = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + data[i] = random.nextFloat(lowerBoundInclusive, 
upperBoundExclusive); + } + return data; + } + + public void testQuantizeForIndexMIP() { + int dimensions = 768; + + // we want fixed values for these arrays so define our own random generation here to track + // quantization changes + Random random = new Random(42); + + float[] mipVectorToIndex = generateRandomFloatArray(random, dimensions, -1f, 1f); + float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f); + + VectorSimilarityFunction[] similarityFunctionsActingLikeEucllidean = + new VectorSimilarityFunction[] { + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT + }; + int randIdx = random().nextInt(similarityFunctionsActingLikeEucllidean.length); + VectorSimilarityFunction similarityFunction = similarityFunctionsActingLikeEucllidean[randIdx]; + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, similarityFunction); + float[] vector = mipVectorToIndex; + byte[] destination = new byte[dimensions / 8]; + float[] centroid = mipCentroid; + float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid); + + assertEquals(3, corrections.length); + float ooq = corrections[0]; + float normOC = corrections[1]; + float oDotC = corrections[2]; + + assertEquals(0.8141399f, ooq, 0.0000001f); + assertEquals(21.847124f, normOC, 0.00001f); + assertEquals(6.4300356f, oDotC, 0.0001f); + assertArrayEquals( + new byte[] { + -83, -91, -71, 97, 32, -96, 89, -80, -19, -108, 3, 113, -111, 12, -86, 32, -43, 76, 122, + -106, -83, -37, -122, 118, 84, -72, 34, 20, 57, -29, 119, -8, -10, -100, -109, 62, -54, + 53, -44, 8, -16, 80, 58, 50, 105, -25, 47, 115, -106, -92, -122, -44, 8, 18, -23, 24, -15, + 62, 58, 111, 99, -116, -111, -5, 101, -69, -32, -74, -105, 113, -89, 44, 100, -93, -80, + 82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 108, 72, 2, 112, -63, -43, 105, -42, 9, + -128 + }, + destination); + } + + public void testQuantizeForQueryMIP() { + int dimensions = 768; + + // we want fixed values for 
these arrays so define our own random generation here to track + // quantization changes + Random random = new Random(42); + + float[] mipVectorToQuery = generateRandomFloatArray(random, dimensions, -1f, 1f); + float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f); + + VectorSimilarityFunction[] similarityFunctionsActingLikeEucllidean = + new VectorSimilarityFunction[] { + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT + }; + int randIdx = random().nextInt(similarityFunctionsActingLikeEucllidean.length); + VectorSimilarityFunction similarityFunction = similarityFunctionsActingLikeEucllidean[randIdx]; + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, similarityFunction); + float[] vector = mipVectorToQuery; + byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY]; + float[] centroid = mipCentroid; + float cDotC = VectorUtil.dotProduct(centroid, centroid); + float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid); + + float lower = corrections[0]; + float width = corrections[1]; + float normVmC = corrections[2]; + float vDotC = corrections[3]; + float sumQ = corrections[4]; + + assertEquals(5272, sumQ, 0); + assertEquals(-0.08603752f, lower, 0.00000001f); + assertEquals(0.011431276f, width, 0.00000001f); + assertEquals(21.847124f, normVmC, 0.00001f); + assertEquals(6.4300356f, vDotC, 0.0001f); + assertEquals(252.37146f, cDotC, 0.0001f); + assertArrayEquals( + new byte[] { + -81, 19, 67, 33, 112, 8, 40, -5, -19, 115, -87, -63, -59, 12, -2, -127, -23, 43, 24, 16, + -69, 112, -22, 75, -81, -50, 100, -41, 3, -120, -93, -4, 4, 125, 34, -57, -109, 89, -63, + -35, -116, 4, 35, 93, -26, -88, -55, -86, 63, -46, -122, -96, -26, 124, -64, 21, 96, 46, + 98, 97, 88, -98, -83, 121, 16, -14, -89, -118, 65, -39, -111, -35, 113, 108, 111, 86, 17, + -69, -47, 72, 1, 36, 17, 113, -87, -5, -46, -37, -2, 93, -123, 118, 4, -12, -33, 95, 32, + -63, -97, -109, 27, 111, 42, -57, 
-87, -41, -73, -106, 27, -31, 32, -1, 9, -88, -35, -11, + -103, 5, 27, -127, 108, 127, -119, 58, 38, 18, -103, -27, -63, 56, 77, -13, 3, -40, -127, + 37, 82, -87, -26, -45, -14, 18, -50, 76, 25, 37, -12, 106, 17, 115, 0, 23, -109, 26, -110, + 17, -35, 111, 4, 60, 58, -64, -104, -125, 23, -58, 89, -117, 104, -71, 3, -89, -26, 46, + 15, 82, -83, -75, -72, -69, 20, -38, -47, 109, -66, -66, -89, 108, -122, -3, -69, -85, 18, + 59, 85, -97, -114, 95, 2, -84, -77, 121, -6, 10, 110, -13, -123, -34, 106, -71, -107, 123, + 67, -111, 58, 52, -53, 87, -113, -21, -44, 26, 10, -62, 56, 111, 36, -126, 26, 94, -88, + -13, -113, -50, -9, -115, 84, 8, -32, -102, -4, 89, 29, 75, -73, -19, 22, -90, 76, -61, 4, + -48, -100, -11, 107, 20, -39, -98, 123, 77, 104, 9, 9, 91, -105, -40, -106, -87, 38, 48, + 60, 29, -68, 124, -78, -63, -101, -115, 67, -17, 101, -53, 121, 44, -78, -12, 110, 91, + -83, -92, -72, 96, 32, -96, 89, 48, 76, -124, 3, 113, -111, 12, -86, 32, -43, 68, 106, + -122, -84, -37, -124, 118, 84, -72, 34, 20, 57, -29, 119, 56, -10, -108, -109, 60, -56, + 37, 84, 8, -16, 80, 24, 50, 41, -25, 47, 115, -122, -92, -126, -44, 8, 18, -23, 24, -15, + 60, 58, 111, 99, -120, -111, -21, 101, 59, -32, -74, -105, 113, -90, 36, 100, -93, -80, + 82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 44, 0, 2, 112, -64, -47, 105, 2, 1, -128 + }, + destination); + } + + public void testQuantizeForIndexCosine() { + int dimensions = 768; + + // we want fixed values for these arrays so define our own random generation here to track + // quantization changes + Random random = new Random(42); + + float[] mipVectorToIndex = generateRandomFloatArray(random, dimensions, -1f, 1f); + float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f); + + mipVectorToIndex = VectorUtil.l2normalize(mipVectorToIndex); + mipCentroid = VectorUtil.l2normalize(mipCentroid); + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.COSINE); + float[] vector = 
mipVectorToIndex; + byte[] destination = new byte[dimensions / 8]; + float[] centroid = mipCentroid; + float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid); + + assertEquals(3, corrections.length); + float ooq = corrections[0]; + float normOC = corrections[1]; + float oDotC = corrections[2]; + + assertEquals(0.8145253f, ooq, 0.000001f); + assertEquals(1.3955297f, normOC, 0.00001f); + assertEquals(0.026248248f, oDotC, 0.0001f); + assertArrayEquals( + new byte[] { + -83, -91, -71, 97, 32, -96, 89, -80, -20, -108, 3, 113, -111, 12, -86, 32, -43, 76, 122, + -106, -83, -37, -122, 118, 84, -72, 34, 20, 57, -29, 119, -72, -10, -100, -109, 62, -54, + 117, -44, 8, -16, 80, 58, 50, 41, -25, 47, 115, -106, -92, -122, -44, 8, 18, -23, 24, -15, + 62, 58, 111, 99, -116, -111, -21, 101, -69, -32, -74, -105, 113, -90, 44, 100, -93, -80, + 82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 44, 72, 2, 112, -63, -43, 105, -42, 9, + -126 + }, + destination); + } + + public void testQuantizeForQueryCosine() { + int dimensions = 768; + + // we want fixed values for these arrays so define our own random generation here to track + // quantization changes + Random random = new Random(42); + + float[] mipVectorToQuery = generateRandomFloatArray(random, dimensions, -1f, 1f); + float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f); + + mipVectorToQuery = VectorUtil.l2normalize(mipVectorToQuery); + mipCentroid = VectorUtil.l2normalize(mipCentroid); + + BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.COSINE); + float[] vector = mipVectorToQuery; + byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY]; + float[] centroid = mipCentroid; + float cDotC = VectorUtil.dotProduct(centroid, centroid); + float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid); + + float lower = corrections[0]; + float width = corrections[1]; + float normVmC = corrections[2]; + float vDotC = 
corrections[3]; + float sumQ = corrections[4]; + + assertEquals(5277, sumQ, 0); + assertEquals(-0.086002514f, lower, 0.00000001f); + assertEquals(0.011431345f, width, 0.00000001f); + assertEquals(1.3955297f, normVmC, 0.00001f); + assertEquals(0.026248248f, vDotC, 0.0001f); + assertEquals(1.0f, cDotC, 0.0001f); + assertArrayEquals( + new byte[] { + -83, 18, 67, 37, 80, 8, 40, -1, -19, 115, -87, -63, -59, 12, -2, -63, -19, 43, -104, 16, + -69, 80, -22, 75, -81, -50, 100, -41, 7, -88, -93, -4, 4, 117, 34, -57, -109, 89, -63, + -35, -116, 4, 35, 93, -26, -88, -56, -82, 63, -46, -122, -96, -26, 124, -64, 21, 96, 46, + 114, 101, 92, -98, -83, 121, 48, -14, -89, -118, 65, -47, -79, -35, 113, 110, 111, 70, 17, + -69, -47, 64, 1, 102, 19, 113, -87, -5, -46, -34, -2, 93, -123, 102, 4, -12, 127, 95, 32, + -64, -97, -105, 59, 111, 42, -57, -87, -41, -73, -106, 27, -31, 32, -65, 9, -88, 93, -11, + -103, 37, 27, -127, 108, 127, -119, 58, 38, 18, -103, -27, -63, 48, 77, -13, 3, -40, -127, + 37, 82, -87, -26, -45, -14, 18, -49, 76, 25, 37, -12, 106, 17, 115, 0, 23, -109, 26, -126, + 21, -35, 111, 4, 60, 58, -64, -104, -125, 23, -58, 121, -117, 104, -69, 3, -89, -26, 46, + 15, 90, -83, -73, -72, -69, 20, -38, -47, 109, -66, -66, -89, 108, -122, -3, 59, -85, 18, + 58, 85, -101, -114, 95, 2, -84, -77, 121, -6, 10, 110, -13, -123, -34, 106, -71, -107, + 123, 67, -111, 58, 52, -53, 87, -113, -21, -44, 26, 10, -62, 56, 103, 36, -126, 26, 94, + -88, -13, -113, -50, -9, -115, 84, 8, -32, -102, -4, 89, 29, 75, -73, -19, 22, -90, 76, + -61, 4, -44, -100, -11, 107, 20, -39, -98, 123, 77, 104, 9, 41, 91, -105, -38, -106, -87, + 38, 48, 60, 29, -68, 126, -78, -63, -101, -115, 67, -17, 101, -53, 121, 44, -78, -12, -18, + 91, -83, -91, -72, 96, 32, -96, 89, 48, 76, -124, 3, 113, -111, 12, -86, 32, -43, 68, 106, + -122, -84, -37, -124, 118, 84, -72, 34, 20, 57, -29, 119, 56, -10, -100, -109, 60, -56, + 37, 84, 8, -16, 80, 24, 50, 41, -25, 47, 115, -122, -92, -126, -44, 8, 18, -23, 24, -15, + 60, 
58, 107, 99, -120, -111, -21, 101, 59, -32, -74, -105, 113, -122, 36, 100, -95, -80, + 82, -64, 91, -87, -95, 115, 4, 76, 110, 101, 39, 44, 0, 2, 112, -64, -47, 105, 2, 1, -128 + }, + destination); + } +}