diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index d547d7bdc656..7ba4e1321bc8 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -43,7 +43,12 @@ API Changes
New Features
---------------------
-(No changes)
+
+* GITHUB#13651: New binary quantized vector formats `Lucene101HnswBinaryQuantizedVectorsFormat` and
+ `Lucene101BinaryQuantizedVectorsFormat`. This results in a 32x reduction in memory requirements for fast vector search
+ while achieving nice recall properties only requiring about 5x oversampling with rescoring on larger dimensional vectors.
+ The format is based on the RaBitQ algorithm & paper: https://arxiv.org/abs/2405.12497.
+ (John Wagster, Mayya Sharipova, Chris Hegarty, Tom Veasey, Ben Trent)
Improvements
---------------------
diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java
index 85aff5722498..9a8df45312c6 100644
--- a/lucene/core/src/java/module-info.java
+++ b/lucene/core/src/java/module-info.java
@@ -76,7 +76,9 @@
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
- org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
+ org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat,
+ org.apache.lucene.codecs.lucene101.Lucene101HnswBinaryQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene101.Lucene101PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java
new file mode 100644
index 000000000000..7121432d9dfd
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/BinarizedByteVectorValues.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.util.quantization.BQSpaceUtils.constSqrt;
+
+import java.io.IOException;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/**
+ * A version of {@link ByteVectorValues}, but additionally retrieving score correction values offset
+ * for binarization quantization scores.
+ *
+ * @lucene.experimental
+ */
+public abstract class BinarizedByteVectorValues extends ByteVectorValues {
+
+ /**
+ * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
+ * distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>the dot-product of the normalized, centered vector with its binarized self
+ *   <li>the norm of the centered vector
+ *   <li>the dot-product of the vector with the centroid
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>The euclidean distance to the centroid
+ *   <li>The sum of the dimensions divided by the vector norm
+ * </ul>
+ *
+ * @param vectorOrd the vector ordinal
+ * @return the corrective terms
+ * @throws IOException if an I/O error occurs
+ */
+ public abstract float[] getCorrectiveTerms(int vectorOrd) throws IOException;
+
+ /**
+ * @return the quantizer used to quantize the vectors
+ */
+ public abstract BinaryQuantizer getQuantizer();
+
+ public abstract float[] getCentroid() throws IOException;
+
+ int discretizedDimensions() {
+ return BQSpaceUtils.discretize(dimension(), 64);
+ }
+
+ float sqrtDimensions() {
+ return (float) constSqrt(dimension());
+ }
+
+ float maxX1() {
+ return (float) (1.9 / constSqrt(discretizedDimensions() - 1.0));
+ }
+
+ /**
+ * Return a {@link VectorScorer} for the given query vector.
+ *
+ * @param query the query vector
+ * @return a {@link VectorScorer} instance or null
+ */
+ public abstract VectorScorer scorer(float[] query) throws IOException;
+
+ @Override
+ public abstract BinarizedByteVectorValues copy() throws IOException;
+
+ float getCentroidDP() throws IOException {
+ // this only gets executed on-merge
+ float[] centroid = getCentroid();
+ return VectorUtil.dotProduct(centroid, centroid);
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java
new file mode 100644
index 000000000000..4453e5c3d915
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryFlatVectorsScorer.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/** Vector scorer over binarized vector values */
+public class Lucene101BinaryFlatVectorsScorer implements FlatVectorsScorer {
+ private final FlatVectorsScorer nonQuantizedDelegate;
+
+ public Lucene101BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+ this.nonQuantizedDelegate = nonQuantizedDelegate;
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
+ throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues) {
+ throw new UnsupportedOperationException(
+ "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format");
+ }
+ return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
+ throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+ BinaryQuantizer quantizer = binarizedVectors.getQuantizer();
+ float[] centroid = binarizedVectors.getCentroid();
+ if (similarityFunction == COSINE) {
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ VectorUtil.l2normalize(copy);
+ target = copy;
+ }
+ byte[] quantized =
+ new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
+ float[] queryCorrections = quantizer.quantizeForQuery(target, quantized, centroid);
+ BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections);
+ return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
+ }
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
+ throws IOException {
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
+ BinarizedByteVectorValues targetVectors) {
+ return new BinarizedRandomVectorScorerSupplier(
+ scoringVectors, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
+ }
+
+ /** Vector scorer supplier over binarized vector values */
+ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
+ private final Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues
+ queryVectors;
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ BinarizedRandomVectorScorerSupplier(
+ Lucene101BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction) {
+ this.queryVectors = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ byte[] vector = queryVectors.vectorValue(ord);
+ float[] correctiveTerms = queryVectors.getCorrectiveTerms(ord);
+ BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms);
+ return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return new BinarizedRandomVectorScorerSupplier(
+ queryVectors.copy(), targetVectors.copy(), similarityFunction);
+ }
+ }
+
+ /** A binarized query representing its quantized form along with factors */
+ public record BinaryQueryVector(byte[] vector, float[] factors) {}
+
+ /** Vector scorer over binarized vector values */
+ public static class BinarizedRandomVectorScorer
+ extends RandomVectorScorer.AbstractRandomVectorScorer {
+ private final BinaryQueryVector queryVector;
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ private final float sqrtDimensions;
+ private final float maxX1;
+
+ public BinarizedRandomVectorScorer(
+ BinaryQueryVector queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction) {
+ super(targetVectors);
+ this.queryVector = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ this.sqrtDimensions = targetVectors.sqrtDimensions();
+ this.maxX1 = targetVectors.maxX1();
+ }
+
+ @Override
+ public float score(int targetOrd) throws IOException {
+ byte[] quantizedQuery = queryVector.vector();
+ byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+ float qcDist = VectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+ float xbSum = (float) VectorUtil.popCount(binaryCode);
+ float[] correctiveTerms = targetVectors.getCorrectiveTerms(targetOrd);
+ if (similarityFunction == EUCLIDEAN) {
+ return euclideanScore(xbSum, qcDist, correctiveTerms, queryVector.factors);
+ }
+ return dotProductScore(
+ xbSum, qcDist, targetVectors.getCentroidDP(), correctiveTerms, queryVector.factors);
+ }
+
+ private float dotProductScore(
+ float xbSum,
+ float qcDist,
+ float cDotC,
+ float[] vectorCorrectiveTerms,
+ float[] queryCorrectiveTerms) {
+ assert vectorCorrectiveTerms.length == 3;
+ assert queryCorrectiveTerms.length == 5;
+ float lower = queryCorrectiveTerms[0] / sqrtDimensions;
+ float width = queryCorrectiveTerms[1] / sqrtDimensions;
+ float vmC = queryCorrectiveTerms[2];
+ float vDotC = queryCorrectiveTerms[3];
+ float quantizedSum = queryCorrectiveTerms[4];
+ float ooq = vectorCorrectiveTerms[0];
+ float vmcNormOC = vectorCorrectiveTerms[1] * vmC;
+ float oDotC = vectorCorrectiveTerms[2];
+
+ final float dist;
+ // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away
+ // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector
+ if (vmcNormOC == 0 || ooq == 0) {
+ dist = oDotC + vDotC - cDotC;
+ } else {
+ // If ||o-c|| != 0, we should assume that `ooq` is finite
+ assert Float.isFinite(ooq);
+ float estimatedDot =
+ (2 * width * qcDist
+ + 2 * lower * xbSum
+ - width * quantizedSum
+ - targetVectors.dimension() * lower)
+ / ooq;
+ dist = vmcNormOC * estimatedDot + oDotC + vDotC - cDotC;
+ }
+ assert Float.isFinite(dist);
+
+ float ooqSqr = ooq * ooq;
+ float errorBound = (float) (vmcNormOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr)));
+ float score = Float.isFinite(errorBound) ? dist - errorBound : dist;
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ return Math.max((1f + score) / 2f, 0);
+ }
+
+ private float euclideanScore(
+ float xbSum, float qcDist, float[] vectorCorrectiveTerms, float[] queryCorrectiveTerms) {
+ assert vectorCorrectiveTerms.length == 2;
+ assert queryCorrectiveTerms.length == 4;
+ float distanceToCentroid = queryCorrectiveTerms[0];
+ float lower = queryCorrectiveTerms[1];
+ float width = queryCorrectiveTerms[2];
+ float quantizedSum = queryCorrectiveTerms[3];
+
+ float targetDistToC = vectorCorrectiveTerms[0];
+ float x0 = vectorCorrectiveTerms[1];
+ float sqrX = targetDistToC * targetDistToC;
+ double xX0 = targetDistToC / x0;
+
+ float factorPPC =
+ (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension()));
+ float factorIP = (float) (-2.0 / sqrtDimensions * xX0);
+
+ float score =
+ sqrX
+ + distanceToCentroid
+ + factorPPC * lower
+ + (qcDist * 2 - quantizedSum) * factorIP * width;
+ float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC);
+ float error = 2.0f * maxX1 * projectionDist;
+ float y = (float) Math.sqrt(distanceToCentroid);
+ float errorBound = y * error;
+ if (Float.isFinite(errorBound)) {
+ score = score + errorBound;
+ }
+ return Math.max(1 / (1f + score), 0);
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..33cc988670a9
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+/**
+ * Codec for encoding/decoding binary quantized vectors. The binary quantization format used here
+ * reflects RaBitQ. Also see {@link
+ * org.apache.lucene.util.quantization.BinaryQuantizer}. Some of the key features of RaBitQ are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to occur
+ *       and thus enabling binary quantization.
+ *   <li>Binary quantization of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ *   <li>Utilizing an error bias calculation enabled by the centroid normalization. This allows for
+ *       dynamic rescoring of vectors that fall outside a certain error threshold.
+ * </ul>
+ *
+ * The format is stored in two files:
+ *
+ * <h2>.veb (vector data) file</h2>
+ *
+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li><b>[byte]</b> the binary quantized values, each byte holds 8 bits.
+ *         <li><b>[float]</b> the corrective values. Two floats for Euclidean distance. Three floats
+ *             for the dot-product family of distances.
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *
+ * <h2>.vemb (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ * <ul>
+ *   <li><b>int</b> the field number
+ *   <li><b>int</b> the vector encoding ordinal
+ *   <li><b>int</b> the vector similarity ordinal
+ *   <li><b>vint</b> the vector dimensions
+ *   <li><b>vlong</b> the offset to the vector data in the .veb file
+ *   <li><b>vlong</b> the length of the vector data in the .veb file
+ *   <li><b>vint</b> the number of vectors
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
+ */
+public class Lucene101BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
+
+ public static final String BINARIZED_VECTOR_COMPONENT = "BVEC";
+ public static final String NAME = "Lucene101BinaryQuantizedVectorsFormat";
+
+ static final int VERSION_START = 0;
+ static final int VERSION_CURRENT = VERSION_START;
+ static final String META_CODEC_NAME = "Lucene101BinaryQuantizedVectorsFormatMeta";
+ static final String VECTOR_DATA_CODEC_NAME = "Lucene101BinaryQuantizedVectorsFormatData";
+ static final String META_EXTENSION = "vemb";
+ static final String VECTOR_DATA_EXTENSION = "veb";
+ static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
+
+ private static final FlatVectorsFormat rawVectorFormat =
+ new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ private final Lucene101BinaryFlatVectorsScorer scorer =
+ new Lucene101BinaryFlatVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+
+ /** Creates a new instance with the default number of vectors per cluster. */
+ public Lucene101BinaryQuantizedVectorsFormat() {
+ super(NAME);
+ }
+
+ @Override
+ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new Lucene101BinaryQuantizedVectorsWriter(
+ scorer, rawVectorFormat.fieldsWriter(state), state);
+ }
+
+ @Override
+ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene101BinaryQuantizedVectorsReader(
+ state, rawVectorFormat.fieldsReader(state), scorer);
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene101BinaryQuantizedVectorsFormat(name="
+ + NAME
+ + ", flatVectorScorer="
+ + scorer
+ + ")";
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java
new file mode 100644
index 000000000000..db54b857f202
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsReader.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.CorruptIndexException;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.ReadAdvice;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/**
+ * Reads raw and binarized vectors from the index segments for KNN search.
+ *
+ * @lucene.experimental
+ */
+public class Lucene101BinaryQuantizedVectorsReader extends FlatVectorsReader {
+ private static final long SHALLOW_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(Lucene101BinaryQuantizedVectorsReader.class);
+
+ private final Map<String, FieldEntry> fields = new HashMap<>();
+ private final IndexInput quantizedVectorData;
+ private final FlatVectorsReader rawVectorsReader;
+ private final Lucene101BinaryFlatVectorsScorer vectorScorer;
+
+ public Lucene101BinaryQuantizedVectorsReader(
+ SegmentReadState state,
+ FlatVectorsReader rawVectorsReader,
+ Lucene101BinaryFlatVectorsScorer vectorsScorer)
+ throws IOException {
+ super(vectorsScorer);
+ this.vectorScorer = vectorsScorer;
+ this.rawVectorsReader = rawVectorsReader;
+ int versionMeta = -1;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene101BinaryQuantizedVectorsFormat.META_EXTENSION);
+ boolean success = false;
+ try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
+ Throwable priorE = null;
+ try {
+ versionMeta =
+ CodecUtil.checkIndexHeader(
+ meta,
+ Lucene101BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_START,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ readFields(meta, state.fieldInfos);
+ } catch (Throwable exception) {
+ priorE = exception;
+ } finally {
+ CodecUtil.checkFooter(meta, priorE);
+ }
+ quantizedVectorData =
+ openDataInput(
+ state,
+ versionMeta,
+ Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
+ Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ // Quantized vectors are accessed randomly from their node ID stored in the HNSW
+ // graph.
+ state.context.withReadAdvice(ReadAdvice.RANDOM));
+ success = true;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
+ for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
+ FieldInfo info = infos.fieldInfo(fieldNumber);
+ if (info == null) {
+ throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
+ }
+ FieldEntry fieldEntry = readField(meta, info);
+ validateFieldEntry(info, fieldEntry);
+ fields.put(info.name, fieldEntry);
+ }
+ }
+
+ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ int dimension = info.getVectorDimension();
+ if (dimension != fieldEntry.dimension) {
+ throw new IllegalStateException(
+ "Inconsistent vector dimension for field=\""
+ + info.name
+ + "\"; "
+ + dimension
+ + " != "
+ + fieldEntry.dimension);
+ }
+
+ int binaryDims = BQSpaceUtils.discretize(dimension, 64) / 8;
+ int correctionsCount =
+ fieldEntry.similarityFunction != VectorSimilarityFunction.EUCLIDEAN ? 3 : 2;
+ long numQuantizedVectorBytes =
+ Math.multiplyExact((binaryDims + (Float.BYTES * correctionsCount)), (long) fieldEntry.size);
+ if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
+ throw new IllegalStateException(
+ "Binarized vector data length "
+ + fieldEntry.vectorDataLength
+ + " not matching size = "
+ + fieldEntry.size
+ + " * (binaryBytes="
+ + binaryDims
+ + " + 8"
+ + ") = "
+ + numQuantizedVectorBytes);
+ }
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ return vectorScorer.getRandomVectorScorer(
+ fi.similarityFunction,
+ OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new BinaryQuantizer(fi.dimension, fi.descritizedDimension, fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData),
+ target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+ return rawVectorsReader.getRandomVectorScorer(field, target);
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ rawVectorsReader.checkIntegrity();
+ CodecUtil.checksumEntireFile(quantizedVectorData);
+ }
+
+ @Override
+ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ if (fi.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\""
+ + field
+ + "\" is encoded as: "
+ + fi.vectorEncoding
+ + " expected: "
+ + VectorEncoding.FLOAT32);
+ }
+ OffHeapBinarizedVectorValues bvv =
+ OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new BinaryQuantizer(fi.dimension, fi.descritizedDimension, fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData);
+ return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return rawVectorsReader.getByteVectorValues(field);
+ }
+
+ @Override
+ public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
+ throws IOException {
+ rawVectorsReader.search(field, target, knnCollector, acceptDocs);
+ }
+
+ @Override
+ public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
+ throws IOException {
+ if (knnCollector.k() == 0) return;
+ final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
+ if (scorer == null) return;
+ OrdinalTranslatedKnnCollector collector =
+ new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+ Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
+ for (int i = 0; i < scorer.maxOrd(); i++) {
+ if (acceptedOrds == null || acceptedOrds.get(i)) {
+ collector.collect(i, scorer.score(i));
+ collector.incVisitedCount(1);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(quantizedVectorData, rawVectorsReader);
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size +=
+ RamUsageEstimator.sizeOfMap(
+ fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+ size += rawVectorsReader.ramBytesUsed();
+ return size;
+ }
+
+ public float[] getCentroid(String field) {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry != null) {
+ return fieldEntry.centroid;
+ }
+ return null;
+ }
+
+ private static IndexInput openDataInput(
+ SegmentReadState state,
+ int versionMeta,
+ String fileExtension,
+ String codecName,
+ IOContext context)
+ throws IOException {
+ String fileName =
+ IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+ IndexInput in = state.directory.openInput(fileName, context);
+ boolean success = false;
+ try {
+ int versionVectorData =
+ CodecUtil.checkIndexHeader(
+ in,
+ codecName,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_START,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ if (versionMeta != versionVectorData) {
+ throw new CorruptIndexException(
+ "Format versions mismatch: meta="
+ + versionMeta
+ + ", "
+ + codecName
+ + "="
+ + versionVectorData,
+ in);
+ }
+ CodecUtil.retrieveChecksum(in);
+ success = true;
+ return in;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(in);
+ }
+ }
+ }
+
+ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+ VectorEncoding vectorEncoding = readVectorEncoding(input);
+ VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
+ if (similarityFunction != info.getVectorSimilarityFunction()) {
+ throw new IllegalStateException(
+ "Inconsistent vector similarity function for field=\""
+ + info.name
+ + "\"; "
+ + similarityFunction
+ + " != "
+ + info.getVectorSimilarityFunction());
+ }
+ return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
+ }
+
+ private record FieldEntry(
+ VectorSimilarityFunction similarityFunction,
+ VectorEncoding vectorEncoding,
+ int dimension,
+ int descritizedDimension,
+ long vectorDataOffset,
+ long vectorDataLength,
+ int size,
+ float[] centroid,
+ float centroidDP,
+ OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration) {
+
+ static FieldEntry create(
+ IndexInput input,
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ int dimension = input.readVInt();
+ long vectorDataOffset = input.readVLong();
+ long vectorDataLength = input.readVLong();
+ int size = input.readVInt();
+ final float[] centroid;
+ float centroidDP = 0;
+ if (size > 0) {
+ centroid = new float[dimension];
+ input.readFloats(centroid, 0, dimension);
+ centroidDP = Float.intBitsToFloat(input.readInt());
+ } else {
+ centroid = null;
+ }
+ OrdToDocDISIReaderConfiguration conf =
+ OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
+ return new FieldEntry(
+ similarityFunction,
+ vectorEncoding,
+ dimension,
+ BQSpaceUtils.discretize(dimension, 64),
+ vectorDataOffset,
+ vectorDataLength,
+ size,
+ centroid,
+ centroidDP,
+ conf);
+ }
+ }
+
+ /** Binarized vector values holding row and quantized vector values */
+ protected static final class BinarizedVectorValues extends FloatVectorValues {
+ private final FloatVectorValues rawVectorValues;
+ private final BinarizedByteVectorValues quantizedVectorValues;
+
+ BinarizedVectorValues(
+ FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
+ this.rawVectorValues = rawVectorValues;
+ this.quantizedVectorValues = quantizedVectorValues;
+ }
+
+ @Override
+ public int dimension() {
+ return rawVectorValues.dimension();
+ }
+
+ @Override
+ public int size() {
+ return rawVectorValues.size();
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ return rawVectorValues.vectorValue(ord);
+ }
+
+ @Override
+ public BinarizedVectorValues copy() throws IOException {
+ return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return rawVectorValues.getAcceptOrds(acceptDocs);
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return rawVectorValues.ordToDoc(ord);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return rawVectorValues.iterator();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) throws IOException {
+ return quantizedVectorValues.scorer(query);
+ }
+
+ BinarizedByteVectorValues getQuantizedVectorValues() {
+ return quantizedVectorValues;
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java
new file mode 100644
index 000000000000..396bb78b94f9
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101BinaryQuantizedVectorsWriter.java
@@ -0,0 +1,966 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT;
+import static org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.FloatArrayList;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/**
+ * Writes raw and binarized vector values to index segments for KNN search.
+ *
+ * @lucene.experimental
+ */
+public class Lucene101BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
+ private static final long SHALLOW_RAM_BYTES_USED =
+ shallowSizeOfInstance(Lucene101BinaryQuantizedVectorsWriter.class);
+
+ private final SegmentWriteState segmentWriteState;
+ private final List fields = new ArrayList<>();
+ private final IndexOutput meta, binarizedVectorData;
+ private final FlatVectorsWriter rawVectorDelegate;
+ private final Lucene101BinaryFlatVectorsScorer vectorsScorer;
+ private boolean finished;
+
+ /**
+ * Sole constructor
+ *
+ * @param vectorsScorer the scorer to use for scoring vectors
+ */
+ protected Lucene101BinaryQuantizedVectorsWriter(
+ Lucene101BinaryFlatVectorsScorer vectorsScorer,
+ FlatVectorsWriter rawVectorDelegate,
+ SegmentWriteState state)
+ throws IOException {
+ super(vectorsScorer);
+ this.vectorsScorer = vectorsScorer;
+ this.segmentWriteState = state;
+ String metaFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene101BinaryQuantizedVectorsFormat.META_EXTENSION);
+
+ String binarizedVectorDataFileName =
+ IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
+ this.rawVectorDelegate = rawVectorDelegate;
+ boolean success = false;
+ try {
+ meta = state.directory.createOutput(metaFileName, state.context);
+ binarizedVectorData =
+ state.directory.createOutput(binarizedVectorDataFileName, state.context);
+
+ CodecUtil.writeIndexHeader(
+ meta,
+ Lucene101BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ CodecUtil.writeIndexHeader(
+ binarizedVectorData,
+ Lucene101BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ Lucene101BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix);
+ success = true;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ @Override
+ public FlatFieldVectorsWriter> addField(FieldInfo fieldInfo) throws IOException {
+ FlatFieldVectorsWriter> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ @SuppressWarnings("unchecked")
+ FieldWriter fieldWriter =
+ new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate);
+ fields.add(fieldWriter);
+ return fieldWriter;
+ }
+ return rawVectorDelegate;
+ }
+
+ @Override
+ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+ rawVectorDelegate.flush(maxDoc, sortMap);
+ for (FieldWriter field : fields) {
+ // after raw vectors are written, normalize vectors for clustering and quantization
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ field.normalizeVectors();
+ }
+
+ final float[] clusterCenter;
+ int vectorCount = field.flatFieldVectorsWriter.getVectors().size();
+ clusterCenter = new float[field.dimensionSums.length];
+ if (vectorCount > 0) {
+ for (int i = 0; i < field.dimensionSums.length; i++) {
+ clusterCenter[i] = field.dimensionSums[i] / vectorCount;
+ }
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ VectorUtil.l2normalize(clusterCenter);
+ }
+ }
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ int descritizedDimension = BQSpaceUtils.discretize(field.fieldInfo.getVectorDimension(), 64);
+ BinaryQuantizer quantizer =
+ new BinaryQuantizer(
+ field.fieldInfo.getVectorDimension(),
+ descritizedDimension,
+ field.fieldInfo.getVectorSimilarityFunction());
+ if (sortMap == null) {
+ writeField(field, clusterCenter, maxDoc, quantizer);
+ } else {
+ writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer);
+ }
+ field.finish();
+ }
+ }
+
+ private void writeField(
+ FieldWriter fieldData, float[] clusterCenter, int maxDoc, BinaryQuantizer quantizer)
+ throws IOException {
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeBinarizedVectors(fieldData, clusterCenter, quantizer);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
+
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ vectorDataLength,
+ clusterCenter,
+ centroidDp,
+ fieldData.getDocsWithFieldSet());
+ }
+
+ private void writeBinarizedVectors(
+ FieldWriter fieldData, float[] clusterCenter, BinaryQuantizer scalarQuantizer)
+ throws IOException {
+ byte[] vector =
+ new byte[BQSpaceUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64) / 8];
+ int correctionsCount = scalarQuantizer.getSimilarity() != EUCLIDEAN ? 3 : 2;
+ final ByteBuffer correctionsBuffer =
+ ByteBuffer.allocate(Float.BYTES * correctionsCount).order(ByteOrder.LITTLE_ENDIAN);
+ for (int i = 0; i < fieldData.getVectors().size(); i++) {
+ float[] v = fieldData.getVectors().get(i);
+ float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ for (float correction : corrections) {
+ correctionsBuffer.putFloat(correction);
+ }
+ binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
+ correctionsBuffer.rewind();
+ }
+ }
+
+ private void writeSortingField(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int maxDoc,
+ Sorter.DocMap sortMap,
+ BinaryQuantizer scalarQuantizer)
+ throws IOException {
+ final int[] ordMap =
+ new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
+
+ DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
+ mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);
+
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
+ long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+
+ float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter);
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ quantizedVectorLength,
+ clusterCenter,
+ centroidDp,
+ newDocsWithField);
+ }
+
+ private void writeSortedBinarizedVectors(
+ FieldWriter fieldData, float[] clusterCenter, int[] ordMap, BinaryQuantizer scalarQuantizer)
+ throws IOException {
+ byte[] vector =
+ new byte[BQSpaceUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64) / 8];
+ int correctionsCount = scalarQuantizer.getSimilarity() != EUCLIDEAN ? 3 : 2;
+ final ByteBuffer correctionsBuffer =
+ ByteBuffer.allocate(Float.BYTES * correctionsCount).order(ByteOrder.LITTLE_ENDIAN);
+ for (int ordinal : ordMap) {
+ float[] v = fieldData.getVectors().get(ordinal);
+ float[] corrections = scalarQuantizer.quantizeForIndex(v, vector, clusterCenter);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ for (float correction : corrections) {
+ correctionsBuffer.putFloat(correction);
+ }
+ binarizedVectorData.writeBytes(correctionsBuffer.array(), correctionsBuffer.array().length);
+ correctionsBuffer.rewind();
+ }
+ }
+
+ private void writeMeta(
+ FieldInfo field,
+ int maxDoc,
+ long vectorDataOffset,
+ long vectorDataLength,
+ float[] clusterCenter,
+ float centroidDp,
+ DocsWithFieldSet docsWithField)
+ throws IOException {
+ meta.writeInt(field.number);
+ meta.writeInt(field.getVectorEncoding().ordinal());
+ meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+ meta.writeVInt(field.getVectorDimension());
+ meta.writeVLong(vectorDataOffset);
+ meta.writeVLong(vectorDataLength);
+ int count = docsWithField.cardinality();
+ meta.writeVInt(count);
+ if (count > 0) {
+ final ByteBuffer buffer =
+ ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES)
+ .order(ByteOrder.LITTLE_ENDIAN);
+ buffer.asFloatBuffer().put(clusterCenter);
+ meta.writeBytes(buffer.array(), buffer.array().length);
+ meta.writeInt(Float.floatToIntBits(centroidDp));
+ }
+ OrdToDocDISIReaderConfiguration.writeStoredMeta(
+ DIRECT_MONOTONIC_BLOCK_SHIFT, meta, binarizedVectorData, count, maxDoc, docsWithField);
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ throw new IllegalStateException("already finished");
+ }
+ finished = true;
+ rawVectorDelegate.finish();
+ if (meta != null) {
+ // write end of fields marker
+ meta.writeInt(-1);
+ CodecUtil.writeFooter(meta);
+ }
+ if (binarizedVectorData != null) {
+ CodecUtil.writeFooter(binarizedVectorData);
+ }
+ }
+
+ @Override
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+ // Don't need access to the random vectors, we can just use the merged centroid
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ int descritizedDimension = BQSpaceUtils.discretize(fieldInfo.getVectorDimension(), 64);
+ FloatVectorValues floatVectorValues =
+ KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ BinarizedFloatVectorValues binarizedVectorValues =
+ new BinarizedFloatVectorValues(
+ floatVectorValues,
+ new BinaryQuantizer(
+ fieldInfo.getVectorDimension(),
+ descritizedDimension,
+ fieldInfo.getVectorSimilarityFunction()),
+ centroid);
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ DocsWithFieldSet docsWithField =
+ writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp =
+ docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ centroidDp,
+ docsWithField);
+ } else {
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ }
+ }
+
+ static DocsWithFieldSet writeBinarizedVectorAndQueryData(
+ IndexOutput binarizedVectorData,
+ IndexOutput binarizedQueryData,
+ FloatVectorValues floatVectorValues,
+ float[] centroid,
+ BinaryQuantizer binaryQuantizer)
+ throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ byte[] toIndex = new byte[BQSpaceUtils.discretize(floatVectorValues.dimension(), 64) / 8];
+ byte[] toQuery =
+ new byte
+ [(BQSpaceUtils.discretize(floatVectorValues.dimension(), 64) / 8)
+ * BQSpaceUtils.B_QUERY];
+ int queryCorrectionCount = binaryQuantizer.getSimilarity() != EUCLIDEAN ? 4 : 3;
+ final ByteBuffer queryCorrectionsBuffer =
+ ByteBuffer.allocate(Float.BYTES * queryCorrectionCount + Short.BYTES)
+ .order(ByteOrder.LITTLE_ENDIAN);
+ KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write index vector
+ BinaryQuantizer.QueryAndIndexResults r =
+ binaryQuantizer.quantizeQueryAndIndex(
+ floatVectorValues.vectorValue(iterator.index()), toIndex, toQuery, centroid);
+ binarizedVectorData.writeBytes(toIndex, toIndex.length);
+ float[] corrections = r.indexFeatures();
+ for (float correction : corrections) {
+ binarizedVectorData.writeInt(Float.floatToIntBits(correction));
+ }
+ docsWithField.add(docV);
+
+ // write query vector
+ binarizedQueryData.writeBytes(toQuery, toQuery.length);
+ assert r.queryFeatures().length == queryCorrectionCount + 1;
+ float[] queryCorrections = r.queryFeatures();
+ for (int i = 0; i < queryCorrectionCount; i++) {
+ queryCorrectionsBuffer.putFloat(queryCorrections[i]);
+ }
+ assert queryCorrections[queryCorrectionCount] >= 0
+ && queryCorrections[queryCorrectionCount] <= 0xffff;
+ queryCorrectionsBuffer.putShort((short) queryCorrections[queryCorrectionCount]);
+
+ binarizedQueryData.writeBytes(
+ queryCorrectionsBuffer.array(), queryCorrectionsBuffer.array().length);
+ queryCorrectionsBuffer.rewind();
+ }
+ return docsWithField;
+ }
+
+ static DocsWithFieldSet writeBinarizedVectorData(
+ IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write vector
+ byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index());
+ output.writeBytes(binaryValue, binaryValue.length);
+ float[] corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index());
+ for (float correction : corrections) {
+ output.writeInt(Float.floatToIntBits(correction));
+ }
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
+ @Override
+ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float cDotC;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+
+ // Don't need access to the random vectors, we can just use the merged centroid
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(
+ BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC);
+ }
+ return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
+ }
+
+ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ SegmentWriteState segmentWriteState,
+ FieldInfo fieldInfo,
+ MergeState mergeState,
+ float[] centroid,
+ float cDotC)
+ throws IOException {
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ final IndexOutput tempQuantizedVectorData =
+ segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(), "temp", segmentWriteState.context);
+ final IndexOutput tempScoreQuantizedVectorData =
+ segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(), "score_temp", segmentWriteState.context);
+ IndexInput binarizedDataInput = null;
+ IndexInput binarizedScoreDataInput = null;
+ boolean success = false;
+ int descritizedDimension = BQSpaceUtils.discretize(fieldInfo.getVectorDimension(), 64);
+ BinaryQuantizer quantizer =
+ new BinaryQuantizer(
+ fieldInfo.getVectorDimension(),
+ descritizedDimension,
+ fieldInfo.getVectorSimilarityFunction());
+ try {
+ FloatVectorValues floatVectorValues =
+ KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ DocsWithFieldSet docsWithField =
+ writeBinarizedVectorAndQueryData(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ floatVectorValues,
+ centroid,
+ quantizer);
+ CodecUtil.writeFooter(tempQuantizedVectorData);
+ IOUtils.close(tempQuantizedVectorData);
+ binarizedDataInput =
+ segmentWriteState.directory.openInput(
+ tempQuantizedVectorData.getName(), segmentWriteState.context);
+ binarizedVectorData.copyBytes(
+ binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength());
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ CodecUtil.retrieveChecksum(binarizedDataInput);
+ CodecUtil.writeFooter(tempScoreQuantizedVectorData);
+ IOUtils.close(tempScoreQuantizedVectorData);
+ binarizedScoreDataInput =
+ segmentWriteState.directory.openInput(
+ tempScoreQuantizedVectorData.getName(), segmentWriteState.context);
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ cDotC,
+ docsWithField);
+ success = true;
+ final IndexInput finalBinarizedDataInput = binarizedDataInput;
+ final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput;
+ OffHeapBinarizedVectorValues vectorValues =
+ new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ centroid,
+ cDotC,
+ quantizer,
+ fieldInfo.getVectorSimilarityFunction(),
+ vectorsScorer,
+ finalBinarizedDataInput);
+ RandomVectorScorerSupplier scorerSupplier =
+ vectorsScorer.getRandomVectorScorerSupplier(
+ fieldInfo.getVectorSimilarityFunction(),
+ new OffHeapBinarizedQueryVectorValues(
+ finalBinarizedScoreDataInput,
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ fieldInfo.getVectorSimilarityFunction()),
+ vectorValues);
+ return new BinarizedCloseableRandomVectorScorerSupplier(
+ scorerSupplier,
+ vectorValues,
+ () -> {
+ IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName());
+ });
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ binarizedDataInput,
+ binarizedScoreDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName());
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(meta, binarizedVectorData, rawVectorDelegate);
+ }
+
+ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+ vectorsReader = candidateReader.getFieldReader(fieldName);
+ }
+ if (vectorsReader instanceof Lucene101BinaryQuantizedVectorsReader reader) {
+ return reader.getCentroid(fieldName);
+ }
+ return null;
+ }
+
+ static int mergeAndRecalculateCentroids(
+ MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException {
+ boolean recalculate = false;
+ int totalVectorCount = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null
+ || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) {
+ continue;
+ }
+ float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name);
+ int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size();
+ if (vectorCount == 0) {
+ continue;
+ }
+ totalVectorCount += vectorCount;
+ // If there are no centroids, or previously clustered with more than one cluster
+ // or if there are deleted docs, we must recalculate the centroid
+ if (centroid == null || mergeState.liveDocs[i] != null) {
+ recalculate = true;
+ break;
+ }
+ for (int j = 0; j < centroid.length; j++) {
+ mergedCentroid[j] += centroid[j] * vectorCount;
+ }
+ }
+ if (recalculate) {
+ return calculateCentroid(mergeState, fieldInfo, mergedCentroid);
+ } else {
+ for (int j = 0; j < mergedCentroid.length; j++) {
+ mergedCentroid[j] = mergedCentroid[j] / totalVectorCount;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(mergedCentroid);
+ }
+ return totalVectorCount;
+ }
+ }
+
+ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid)
+ throws IOException {
+ assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+ // clear out the centroid
+ Arrays.fill(centroid, 0);
+ int count = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null) continue;
+ FloatVectorValues vectorValues =
+ mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) {
+ continue;
+ }
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc();
+ doc != DocIdSetIterator.NO_MORE_DOCS;
+ doc = iterator.nextDoc()) {
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ // TODO Panama sum
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ count += vectorValues.size();
+ }
+ if (count == 0) {
+ return count;
+ }
+ // TODO Panama div
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= count;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+ return count;
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+ static class FieldWriter extends FlatFieldVectorsWriter {
+ private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+ private final FieldInfo fieldInfo;
+ private boolean finished;
+ private final FlatFieldVectorsWriter flatFieldVectorsWriter;
+ private final float[] dimensionSums;
+ private final FloatArrayList magnitudes = new FloatArrayList();
+
+ FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+ this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ }
+
+ @Override
+ public List getVectors() {
+ return flatFieldVectorsWriter.getVectors();
+ }
+
+ public void normalizeVectors() {
+ for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) {
+ float[] vector = flatFieldVectorsWriter.getVectors().get(i);
+ float magnitude = magnitudes.get(i);
+ for (int j = 0; j < vector.length; j++) {
+ vector[j] /= magnitude;
+ }
+ }
+ }
+
+ @Override
+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ return;
+ }
+ assert flatFieldVectorsWriter.isFinished();
+ finished = true;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return finished && flatFieldVectorsWriter.isFinished();
+ }
+
+ @Override
+ public void addValue(int docID, float[] vectorValue) throws IOException {
+ flatFieldVectorsWriter.addValue(docID, vectorValue);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
+ float divisor = (float) Math.sqrt(dp);
+ magnitudes.add(divisor);
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += (vectorValue[i] / divisor);
+ }
+ } else {
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += vectorValue[i];
+ }
+ }
+ }
+
+ @Override
+ public float[] copyValue(float[] vectorValue) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size += flatFieldVectorsWriter.ramBytesUsed();
+ size += RamUsageEstimator.sizeOf(dimensionSums);
+ size += magnitudes.ramBytesUsed();
+ return size;
+ }
+ }
+
+ // When accessing vectorValue method, targetOrd here means a row ordinal.
+ static class OffHeapBinarizedQueryVectorValues {
+ private final IndexInput slice;
+ private final int dimension;
+ private final int size;
+ protected final byte[] binaryValue;
+ protected final ByteBuffer byteBuffer;
+ private final int byteSize;
+ protected final float[] correctiveValues;
+ private int lastOrd = -1;
+ private final int correctiveValuesSize;
+ private final VectorSimilarityFunction vectorSimilarityFunction;
+
+ OffHeapBinarizedQueryVectorValues(
+ IndexInput data,
+ int dimension,
+ int size,
+ VectorSimilarityFunction vectorSimilarityFunction) {
+ this.slice = data;
+ this.dimension = dimension;
+ this.size = size;
+ this.vectorSimilarityFunction = vectorSimilarityFunction;
+ this.correctiveValuesSize = vectorSimilarityFunction != EUCLIDEAN ? 4 : 3;
+ // 4x the quantized binary dimensions
+ int binaryDimensions = (BQSpaceUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY;
+ this.byteBuffer = ByteBuffer.allocate(binaryDimensions);
+ this.binaryValue = byteBuffer.array();
+ // + 1 for the quantized sum
+ this.correctiveValues = new float[correctiveValuesSize + 1];
+ this.byteSize = binaryDimensions + Float.BYTES * correctiveValuesSize + Short.BYTES;
+ }
+
+ public float[] getCorrectiveTerms(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return correctiveValues;
+ }
+ vectorValue(targetOrd);
+ return correctiveValues;
+ }
+
+ public int size() {
+ return size;
+ }
+
+ public int dimension() {
+ return dimension;
+ }
+
+ public OffHeapBinarizedQueryVectorValues copy() throws IOException {
+ return new OffHeapBinarizedQueryVectorValues(
+ slice.clone(), dimension, size, vectorSimilarityFunction);
+ }
+
+ public IndexInput getSlice() {
+ return slice;
+ }
+
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(binaryValue, 0, binaryValue.length);
+ slice.readFloats(correctiveValues, 0, correctiveValuesSize);
+ correctiveValues[correctiveValuesSize] = Short.toUnsignedInt(slice.readShort());
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+ }
+
+ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
+ private float[] corrections;
+ private final byte[] binarized;
+ private final float[] centroid;
+ private final FloatVectorValues values;
+ private final BinaryQuantizer quantizer;
+ private int lastOrd = -1;
+
+ BinarizedFloatVectorValues(
+ FloatVectorValues delegate, BinaryQuantizer quantizer, float[] centroid) {
+ this.values = delegate;
+ this.quantizer = quantizer;
+ this.binarized = new byte[BQSpaceUtils.discretize(delegate.dimension(), 64) / 8];
+ this.centroid = centroid;
+ }
+
+ @Override
+ public float[] getCorrectiveTerms(int ord) {
+ if (ord != lastOrd) {
+ throw new IllegalStateException(
+ "attempt to retrieve corrective terms for different ord "
+ + ord
+ + " than the quantization was done for: "
+ + lastOrd);
+ }
+ return corrections;
+ }
+
+ @Override
+ public byte[] vectorValue(int ord) throws IOException {
+ if (ord != lastOrd) {
+ binarize(ord);
+ lastOrd = ord;
+ }
+ return binarized;
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public BinaryQuantizer getQuantizer() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] getCentroid() throws IOException {
+ return centroid;
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() throws IOException {
+ return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid);
+ }
+
+ private void binarize(int ord) throws IOException {
+ corrections = quantizer.quantizeForIndex(values.vectorValue(ord), binarized, centroid);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+ }
+
+ static class BinarizedCloseableRandomVectorScorerSupplier
+ implements CloseableRandomVectorScorerSupplier {
+ private final RandomVectorScorerSupplier supplier;
+ private final KnnVectorValues vectorValues;
+ private final Closeable onClose;
+
+ BinarizedCloseableRandomVectorScorerSupplier(
+ RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) {
+ this.supplier = supplier;
+ this.onClose = onClose;
+ this.vectorValues = vectorValues;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ return supplier.scorer(ord);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ return supplier.copy();
+ }
+
+ @Override
+ public void close() throws IOException {
+ onClose.close();
+ }
+
+ @Override
+ public int totalVectorCount() {
+ return vectorValues.size();
+ }
+ }
+
+ static final class NormalizedFloatVectorValues extends FloatVectorValues {
+ private final FloatVectorValues values;
+ private final float[] normalizedVector;
+
+ NormalizedFloatVectorValues(FloatVectorValues values) {
+ this.values = values;
+ this.normalizedVector = new float[values.dimension()];
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
+ VectorUtil.l2normalize(normalizedVector);
+ return normalizedVector;
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public NormalizedFloatVectorValues copy() throws IOException {
+ return new NormalizedFloatVectorValues(values.copy());
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..debad93e5bad
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101HnswBinaryQuantizedVectorsFormat.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+/**
+ * Lucene 10.1 vector format, which encodes numeric vector values into an associated graph
+ * connecting the documents having values. The graph is used to power HNSW search. The format
+ * consists of two files, and uses {@link Lucene101BinaryQuantizedVectorsFormat} to store the actual
+ * vectors. For details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}.
+ *
+ * @lucene.experimental
+ */
+public class Lucene101HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
+
+ /** Public name of this format, passed to {@link KnnVectorsFormat}. */
+ public static final String NAME = "Lucene101HnswBinaryQuantizedVectorsFormat";
+
+ /**
+ * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
+ * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+ */
+ private final int maxConn;
+
+ /**
+ * The number of candidate neighbors to track while searching the graph for each newly inserted
+ * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+ * for details.
+ */
+ private final int beamWidth;
+
+ /** The format for storing, reading, merging vectors on disk */
+ private final FlatVectorsFormat flatVectorsFormat;
+
+ // Number of merge workers; > 1 requires a non-null mergeExec (validated in the constructor).
+ private final int numMergeWorkers;
+ private final TaskExecutor mergeExec;
+
+ /** Constructs a format using default graph construction parameters */
+ public Lucene101HnswBinaryQuantizedVectorsFormat() {
+ this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ */
+ public Lucene101HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
+ this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction and merge parameters.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
+ * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+ * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
+ * generated by this format to do the merge
+ * @throws IllegalArgumentException if maxConn or beamWidth are out of range, or an executor is
+ * supplied with a single merge worker
+ */
+ public Lucene101HnswBinaryQuantizedVectorsFormat(
+ int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
+ super(NAME);
+ if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
+ throw new IllegalArgumentException(
+ "maxConn must be positive and less than or equal to "
+ + MAXIMUM_MAX_CONN
+ + "; maxConn="
+ + maxConn);
+ }
+ if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
+ throw new IllegalArgumentException(
+ "beamWidth must be positive and less than or equal to "
+ + MAXIMUM_BEAM_WIDTH
+ + "; beamWidth="
+ + beamWidth);
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ if (numMergeWorkers == 1 && mergeExec != null) {
+ throw new IllegalArgumentException(
+ "No executor service is needed as we'll use single thread to merge");
+ }
+ this.numMergeWorkers = numMergeWorkers;
+ if (mergeExec != null) {
+ this.mergeExec = new TaskExecutor(mergeExec);
+ } else {
+ this.mergeExec = null;
+ }
+ // Flat storage delegates to the binary quantized format; HNSW graph layered on top.
+ this.flatVectorsFormat = new Lucene101BinaryQuantizedVectorsFormat();
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ // Reuses the 99 HNSW writer; only the flat (vector storage) layer differs.
+ return new Lucene99HnswVectorsWriter(
+ state,
+ maxConn,
+ beamWidth,
+ flatVectorsFormat.fieldsWriter(state),
+ numMergeWorkers,
+ mergeExec);
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ // NOTE(review): hardcoded; presumably mirrors KnnVectorsFormat's default max dimensions —
+ // consider referencing the named constant instead. TODO confirm.
+ return 1024;
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene101HnswBinaryQuantizedVectorsFormat(name=Lucene101HnswBinaryQuantizedVectorsFormat, maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat
+ + ")";
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java
new file mode 100644
index 000000000000..411136fbf959
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/OffHeapBinarizedVectorValues.java
@@ -0,0 +1,391 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.util.quantization.BQSpaceUtils.constSqrt;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.packed.DirectMonotonicReader;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/**
+ * Binarized vector values loaded from off-heap. Each on-disk record is the packed bit vector
+ * followed by 2 or 3 float corrective terms (3 unless the similarity is Euclidean).
+ */
+public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues {
+
+ protected final int dimension;
+ protected final int size;
+ // packed bits per vector: discretize(dimension, 64) / 8 bytes
+ protected final int numBytes;
+ protected final VectorSimilarityFunction similarityFunction;
+ protected final FlatVectorsScorer vectorsScorer;
+
+ protected final IndexInput slice;
+ protected final byte[] binaryValue;
+ protected final ByteBuffer byteBuffer;
+ // total on-disk record size: packed bits + corrective floats
+ protected final int byteSize;
+ // ord whose packed bits (and, after vectorValue, corrective terms) are currently cached
+ private int lastOrd = -1;
+ protected final float[] correctiveValues;
+ protected final BinaryQuantizer binaryQuantizer;
+ protected final float[] centroid;
+ protected final float centroidDp;
+ private final int discretizedDimensions;
+ private final float maxX1;
+ private final float sqrtDimensions;
+ private final int correctionsCount;
+
+ OffHeapBinarizedVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ BinaryQuantizer quantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ this.dimension = dimension;
+ this.size = size;
+ this.similarityFunction = similarityFunction;
+ this.vectorsScorer = vectorsScorer;
+ this.slice = slice;
+ this.centroid = centroid;
+ this.centroidDp = centroidDp;
+ this.numBytes = BQSpaceUtils.discretize(dimension, 64) / 8;
+ // Euclidean stores 2 corrective floats per vector, other similarities store 3.
+ this.correctionsCount = similarityFunction != EUCLIDEAN ? 3 : 2;
+ this.correctiveValues = new float[this.correctionsCount];
+ this.byteSize = numBytes + (Float.BYTES * correctionsCount);
+ this.byteBuffer = ByteBuffer.allocate(numBytes);
+ this.binaryValue = byteBuffer.array();
+ this.binaryQuantizer = quantizer;
+ this.discretizedDimensions = BQSpaceUtils.discretize(dimension, 64);
+ this.sqrtDimensions = (float) constSqrt(dimension);
+ // 1.9 / sqrt(D - 1): presumably the error-bound constant from the RaBitQ paper — TODO confirm
+ this.maxX1 = (float) (1.9 / constSqrt(discretizedDimensions - 1.0));
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ // Returns the shared cached array; contents are valid until the next call with another ord.
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
+ slice.readFloats(correctiveValues, 0, correctionsCount);
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+
+ @Override
+ public int discretizedDimensions() {
+ return discretizedDimensions;
+ }
+
+ @Override
+ public float sqrtDimensions() {
+ return sqrtDimensions;
+ }
+
+ @Override
+ public float maxX1() {
+ return maxX1;
+ }
+
+ @Override
+ public float getCentroidDP() {
+ return centroidDp;
+ }
+
+ @Override
+ public float[] getCorrectiveTerms(int targetOrd) throws IOException {
+ // NOTE(review): this path overwrites correctiveValues WITHOUT updating lastOrd. If
+ // getCorrectiveTerms(A) runs after vectorValue(B), a later getCorrectiveTerms(B) hits the
+ // lastOrd==B cache and returns A's terms. Looks safe only if callers always pair
+ // vectorValue/getCorrectiveTerms per ord — confirm against callers.
+ if (lastOrd == targetOrd) {
+ return correctiveValues;
+ }
+ slice.seek(((long) targetOrd * byteSize) + numBytes);
+ slice.readFloats(correctiveValues, 0, correctionsCount);
+ return correctiveValues;
+ }
+
+ @Override
+ public BinaryQuantizer getQuantizer() {
+ return binaryQuantizer;
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return numBytes;
+ }
+
+ /**
+ * Loads the appropriate (empty, dense, or sparse) off-heap values based on the
+ * ord-to-doc configuration, slicing the quantized vector data out of {@code vectorData}.
+ */
+ public static OffHeapBinarizedVectorValues load(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ BinaryQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ float[] centroid,
+ float centroidDp,
+ long quantizedVectorDataOffset,
+ long quantizedVectorDataLength,
+ IndexInput vectorData)
+ throws IOException {
+ if (configuration.isEmpty()) {
+ return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
+ }
+ assert centroid != null;
+ IndexInput bytesSlice =
+ vectorData.slice(
+ "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
+ if (configuration.isDense()) {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ } else {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ vectorData,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice);
+ }
+ }
+
+ /** Dense off-heap binarized vector values */
+ public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ public DenseOffHeapVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ BinaryQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice) {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() throws IOException {
+ // Clones the slice so the copy has independent read position/state.
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ // Dense: ord == doc, so accepted docs map 1:1 onto accepted ords.
+ return acceptDocs;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ // Score against a private copy so this instance's cache/position is untouched.
+ DenseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+ }
+
+ /** Sparse off-heap binarized vector values */
+ private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ private final DirectMonotonicReader ordToDoc;
+ private final IndexedDISI disi;
+ // dataIn was used to init a new IndexedDISI for #randomAccess()
+ private final IndexInput dataIn;
+ private final OrdToDocDISIReaderConfiguration configuration;
+
+ SparseOffHeapVectorValues(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ BinaryQuantizer binaryQuantizer,
+ IndexInput dataIn,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice)
+ throws IOException {
+ super(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice);
+ this.configuration = configuration;
+ this.dataIn = dataIn;
+ this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
+ this.disi = configuration.getIndexedDISI(dataIn);
+ }
+
+ @Override
+ public SparseOffHeapVectorValues copy() throws IOException {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ dataIn,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone());
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return (int) ordToDoc.get(ord);
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ // Sparse: translate each ord to its doc before consulting acceptDocs.
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(ordToDoc(index));
+ }
+
+ @Override
+ public int length() {
+ return size;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return IndexedDISI.asDocIndexIterator(disi);
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ SparseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer =
+ vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+ }
+
+ /** Placeholder for fields with no vectors: size 0, no centroid, no scorer. */
+ private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ EmptyOffHeapVectorValues(
+ int dimension,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer) {
+ super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return null;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) {
+ return null;
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
index e582f12c3185..f2bec47e18e5 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/package-info.java
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
/**
* Lucene 10.1 file format.
*
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
index 184403cf48b7..585b8dbb47aa 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
@@ -17,6 +17,9 @@
package org.apache.lucene.internal.vectorization;
+import static org.apache.lucene.util.VectorUtil.B_QUERY;
+
+import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;
@@ -198,6 +201,31 @@ public int squareDistance(byte[] a, byte[] b) {
return squareSum;
}
+ @Override
+ public long ipByteBinByte(byte[] q, byte[] d) {
+ // Scalar (non-SIMD) fallback; see ipByteBinByteImpl.
+ return ipByteBinByteImpl(q, d);
+ }
+
+ // Bitwise dot product between an int4-transposed query q (B_QUERY bit planes of d.length
+ // bytes each) and a packed bit vector d: for each plane i, popcount(q_plane_i AND d),
+ // weighted by 2^i.
+ public static long ipByteBinByteImpl(byte[] q, byte[] d) {
+ long ret = 0;
+ int size = d.length;
+ for (int i = 0; i < B_QUERY; i++) {
+ int r = 0;
+ long subRet = 0;
+ // Main loop: AND + popcount 4 bytes at a time via native-order VarHandle reads.
+ for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
+ subRet +=
+ Integer.bitCount(
+ (int) BitUtil.VH_NATIVE_INT.get(q, i * size + r)
+ & (int) BitUtil.VH_NATIVE_INT.get(d, r));
+ }
+ // Tail: remaining bytes one at a time.
+ for (; r < d.length; r++) {
+ subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF);
+ }
+ // Plane i contributes with weight 2^i.
+ ret += subRet << i;
+ }
+ return ret;
+ }
+
@Override
public int findNextGEQ(int[] buffer, int target, int from, int to) {
for (int i = from; i < to; ++i) {
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
index fb94b0e31736..613c1d8c5be4 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
@@ -45,6 +45,9 @@ public interface VectorUtilSupport {
/** Returns the sum of squared differences of the two byte vectors. */
int squareDistance(byte[] a, byte[] b);
+ /** This does a bit-wise dot-product between two particularly formatted byte arrays. */
+ long ipByteBinByte(byte[] q, byte[] d);
+
/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 250c65448703..e08d3a9f0317 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -124,6 +124,16 @@ public static float[] l2normalize(float[] v) {
return v;
}
+ /**
+ * Return the l2Norm of the vector.
+ *
+ * <p>Computed as {@code sqrt(v . v)}, then narrowed to single precision.
+ *
+ * @param v the vector
+ * @return the l2Norm of the vector
+ */
+ public static float l2Norm(float[] v) {
+ return (float) Math.sqrt(IMPL.dotProduct(v, v));
+ }
+
public static boolean isUnitVector(float[] v) {
double l1norm = IMPL.dotProduct(v, v);
return Math.abs(l1norm - 1.0d) <= EPSILON;
@@ -169,6 +179,30 @@ public static void add(float[] u, float[] v) {
}
}
+ /**
+ * Subtracts the second argument from the first, element-wise, in place.
+ *
+ * <p>No length check is performed; {@code v} must have at least {@code u.length} elements.
+ *
+ * @param u the destination
+ * @param v the vector to subtract from the destination
+ */
+ public static void subtract(float[] u, float[] v) {
+ for (int i = 0; i < u.length; i++) {
+ u[i] -= v[i];
+ }
+ }
+
+ /**
+ * Divides the first argument by the second, element-wise, in place.
+ *
+ * <p>No zero check is performed; dividing by 0 yields IEEE-754 infinities/NaNs.
+ *
+ * @param u the destination
+ * @param v to divide the destination by
+ */
+ public static void divide(float[] u, float v) {
+ for (int i = 0; i < u.length; i++) {
+ u[i] /= v;
+ }
+ }
+
/**
* Dot product computed over signed bytes.
*
@@ -269,6 +303,44 @@ static int xorBitCountLong(byte[] a, byte[] b) {
return distance;
}
+ /**
+ * The popCount for the given byte array.
+ *
+ * <p>Dispatches between an int-stride and a long-stride implementation based on
+ * XOR_BIT_COUNT_STRIDE_AS_INT; both produce the same result.
+ *
+ * @param v the byte array
+ * @return the number of set bits in the byte array
+ */
+ public static int popCount(byte[] v) {
+ if (XOR_BIT_COUNT_STRIDE_AS_INT) {
+ return popCountInt(v);
+ } else {
+ return popCountLong(v);
+ }
+ }
+
+ // popCount reading 4 bytes at a time via a native-order VarHandle, then byte-wise tail.
+ static int popCountInt(byte[] d) {
+ int r = 0;
+ int cnt = 0;
+ for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
+ cnt += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(d, r));
+ }
+ for (; r < d.length; r++) {
+ cnt += Integer.bitCount(d[r] & 0xFF);
+ }
+ return cnt;
+ }
+
+ // popCount reading 8 bytes at a time via a native-order VarHandle, then byte-wise tail.
+ static int popCountLong(byte[] d) {
+ int r = 0;
+ int cnt = 0;
+ for (final int upperBound = d.length & -Long.BYTES; r < upperBound; r += Long.BYTES) {
+ cnt += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(d, r));
+ }
+ for (; r < d.length; r++) {
+ cnt += Integer.bitCount(d[r] & 0xFF);
+ }
+ return cnt;
+ }
+
/**
* Dot product score computed over signed bytes, scaled to be in [0, 1].
*
@@ -309,6 +381,28 @@ public static float[] checkFinite(float[] v) {
return v;
}
+ /** Number of query bit planes expected by {@link #ipByteBinByte}: q must be 4x d's length. */
+ public static final short B_QUERY = 4;
+
+ /**
+ * This does a dot-product between two particularly formatted byte arrays. It is assumed that q is
+ * 4 times the size of d and bits for each individual dimension are packed by their order. An
+ * example encoding for q would be for values 0, 12, 7, 5 which have the binary values of 0000,
+ * 1100, 0111, 0101 the bits would actually be packed as 0011, 0010, 0111, 0100 or the values 3,
+ * 2, 7, 4.
+ *
+ * @param q an int4 encoded byte array where the lower-order bits of each dimension are packed
+ * first, with higher-order bits following later
+ * @param d a bit encoded byte array
+ * @return the dot product
+ */
+ public static long ipByteBinByte(byte[] q, byte[] d) {
+ if (q.length != d.length * B_QUERY) {
+ throw new IllegalArgumentException(
+ "vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
+ }
+ return IMPL.ipByteBinByte(q, d);
+ }
+
/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java b/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java
new file mode 100644
index 000000000000..4f89801684e3
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/BQSpaceUtils.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.quantization;
+
+import org.apache.lucene.util.ArrayUtil;
+
+/** Utility class for quantization calculations */
+/** Utility class for quantization calculations */
+public class BQSpaceUtils {
+
+ /** Number of query bit planes used by the int4 query transposition. */
+ public static final short B_QUERY = 4;
+ // the first four bits masked
+ private static final int B_QUERY_MASK = 15;
+
+ // Static-only utility class; prevent instantiation.
+ private BQSpaceUtils() {}
+
+ /**
+ * Newton-Raphson square-root iteration; recurses until the estimate stops changing.
+ *
+ * @param x the value to take the square root of
+ * @param curr the current estimate
+ * @param prev the previous estimate
+ * @return sqrt(x) once curr == prev
+ */
+ public static double sqrtNewtonRaphson(double x, double curr, double prev) {
+ return (curr == prev) ? curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr);
+ }
+
+ /**
+ * Square root via {@link #sqrtNewtonRaphson}; returns NaN for negative or infinite input.
+ *
+ * @param x the value to take the square root of
+ * @return sqrt(x), or NaN if x is negative or infinite
+ */
+ public static double constSqrt(double x) {
+ return x >= 0 && Double.isInfinite(x) == false ? sqrtNewtonRaphson(x, x, 0) : Double.NaN;
+ }
+
+ /**
+ * Rounds {@code value} up to the nearest multiple of {@code bucket}.
+ *
+ * @param value the value to discretize
+ * @param bucket the bucket size (must be positive)
+ * @return the smallest multiple of bucket that is &gt;= value
+ */
+ public static int discretize(int value, int bucket) {
+ return ((value + (bucket - 1)) / bucket) * bucket;
+ }
+
+ /**
+ * Zero-pads the vector to {@code dimensions}; returns the input unchanged if already long enough.
+ *
+ * @param vector the vector to pad
+ * @param dimensions the target length
+ * @return the original array, or a zero-padded copy of length dimensions
+ */
+ public static float[] pad(float[] vector, int dimensions) {
+ if (vector.length >= dimensions) {
+ return vector;
+ }
+ return ArrayUtil.growExact(vector, dimensions);
+ }
+
+ /**
+ * Zero-pads the vector to {@code dimensions}; returns the input unchanged if already long enough.
+ *
+ * @param vector the vector to pad
+ * @param dimensions the target length
+ * @return the original array, or a zero-padded copy of length dimensions
+ */
+ public static byte[] pad(byte[] vector, int dimensions) {
+ if (vector.length >= dimensions) {
+ return vector;
+ }
+ return ArrayUtil.growExact(vector, dimensions);
+ }
+
+ /**
+ * Transpose the query vector into a byte array allowing for efficient bitwise operations with the
+ * index bit vectors. The idea here is to organize the query vector bits such that the first bit
+ * of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
+ * third, and fourth bits are in the second, third, and fourth set of dimensions bits,
+ * respectively. This allows for direct bitwise comparisons with the stored index vectors through
+ * summing the bitwise results with the relative required bit shifts.
+ *
+ * <p>NOTE: the loop processes 32 dimensions per iteration, so q must have at least
+ * discretize(dimensions, 32) elements — presumably guaranteed by padding upstream; confirm
+ * against callers.
+ *
+ * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
+ * @param dimensions the number of dimensions in the query vector
+ * @param quantQueryByte the byte array to store the transposed query vector
+ */
+ public static void transposeBin(byte[] q, int dimensions, byte[] quantQueryByte) {
+ int qOffset = 0;
+ final byte[] v1 = new byte[4];
+ final byte[] v = new byte[32];
+ for (int i = 0; i < dimensions; i += 32) {
+ // for every four bytes we shift left (with remainder across those bytes)
+ for (int j = 0; j < v.length; j += 4) {
+ v[j] = (byte) (q[qOffset + j] << B_QUERY | ((q[qOffset + j] >>> B_QUERY) & B_QUERY_MASK));
+ v[j + 1] =
+ (byte)
+ (q[qOffset + j + 1] << B_QUERY | ((q[qOffset + j + 1] >>> B_QUERY) & B_QUERY_MASK));
+ v[j + 2] =
+ (byte)
+ (q[qOffset + j + 2] << B_QUERY | ((q[qOffset + j + 2] >>> B_QUERY) & B_QUERY_MASK));
+ v[j + 3] =
+ (byte)
+ (q[qOffset + j + 3] << B_QUERY | ((q[qOffset + j + 3] >>> B_QUERY) & B_QUERY_MASK));
+ }
+ for (int j = 0; j < B_QUERY; j++) {
+ // Collect the current top bit of each of the 32 lanes into 4 output bytes.
+ moveMaskEpi8Byte(v, v1);
+ for (int k = 0; k < 4; k++) {
+ quantQueryByte[(B_QUERY - j - 1) * (dimensions / 8) + i / 8 + k] = v1[k];
+ v1[k] = 0;
+ }
+ // Shift every lane left by one (x + x) to expose the next bit plane.
+ for (int k = 0; k < v.length; k += 4) {
+ v[k] = (byte) (v[k] + v[k]);
+ v[k + 1] = (byte) (v[k + 1] + v[k + 1]);
+ v[k + 2] = (byte) (v[k + 2] + v[k + 2]);
+ v[k + 3] = (byte) (v[k + 3] + v[k + 3]);
+ }
+ }
+ qOffset += 32;
+ }
+ }
+
+ // Software analogue of SSE movemask: packs the sign bit of each of the 32 bytes in v
+ // into the 4 bytes of v1b, 8 lanes per output byte.
+ private static void moveMaskEpi8Byte(byte[] v, byte[] v1b) {
+ int m = 0;
+ for (int k = 0; k < v.length; k++) {
+ if ((v[k] & 0b10000000) == 0b10000000) {
+ v1b[m] |= 0b00000001;
+ }
+ if (k % 8 == 7) {
+ m++;
+ } else {
+ v1b[m] <<= 1;
+ }
+ }
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java
new file mode 100644
index 000000000000..35856104c266
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/quantization/BinaryQuantizer.java
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.quantization;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+
+/**
+ * Quantizer that quantizes raw vector values to binary. The algorithm is based on the paper RaBitQ.
+ */
+public class BinaryQuantizer {
+ private final int discretizedDimensions;
+
+ private final VectorSimilarityFunction similarityFunction;
+ private final float sqrtDimensions;
+
+ public BinaryQuantizer(
+ int dimensions, int discretizedDimensions, VectorSimilarityFunction similarityFunction) {
+ if (dimensions <= 0) {
+ throw new IllegalArgumentException("dimensions must be > 0 but was: " + dimensions);
+ }
+ assert discretizedDimensions % 64 == 0
+ : "discretizedDimensions must be a multiple of 64 but was: " + discretizedDimensions;
+ this.discretizedDimensions = discretizedDimensions;
+ this.similarityFunction = similarityFunction;
+ this.sqrtDimensions = (float) Math.sqrt(dimensions);
+ }
+
+ BinaryQuantizer(int dimensions, VectorSimilarityFunction similarityFunction) {
+ this(dimensions, dimensions, similarityFunction);
+ }
+
+ private static void removeSignAndDivide(float[] a, float divisor) {
+ for (int i = 0; i < a.length; i++) {
+ a[i] = Math.abs(a[i]) / divisor;
+ }
+ }
+
+ private static float sumAndNormalize(float[] a, float norm) {
+ float aDivided = 0f;
+
+ for (int i = 0; i < a.length; i++) {
+ aDivided += a[i];
+ }
+
+ aDivided = aDivided / norm;
+ if (!Float.isFinite(aDivided)) {
+ aDivided = 0.8f; // can be anything
+ }
+
+ return aDivided;
+ }
+
+ private static void packAsBinary(float[] vector, byte[] packedVector) {
+ for (int h = 0; h < vector.length; h += 8) {
+ byte result = 0;
+ int q = 0;
+ for (int i = 7; i >= 0; i--) {
+ if (vector[h + i] > 0) {
+ result |= (byte) (1 << q);
+ }
+ q++;
+ }
+ packedVector[h / 8] = result;
+ }
+ }
+
+ public VectorSimilarityFunction getSimilarity() {
+ return this.similarityFunction;
+ }
+
+ private record SubspaceOutput(float projection) {}
+
+ private SubspaceOutput generateSubSpace(
+ float[] vector, float[] centroid, byte[] quantizedVector) {
+ // typically no-op if dimensions/64
+ float[] paddedCentroid = BQSpaceUtils.pad(centroid, discretizedDimensions);
+ float[] paddedVector = BQSpaceUtils.pad(vector, discretizedDimensions);
+
+ VectorUtil.subtract(paddedVector, paddedCentroid);
+
+ // The inner product between the data vector and the quantized data vector
+ float norm = VectorUtil.l2Norm(paddedVector);
+
+ packAsBinary(paddedVector, quantizedVector);
+
+ removeSignAndDivide(paddedVector, sqrtDimensions);
+ float projection = sumAndNormalize(paddedVector, norm);
+
+ return new SubspaceOutput(projection);
+ }
+
+ record SubspaceOutputMIP(float OOQ, float normOC, float oDotC) {}
+
+ private SubspaceOutputMIP generateSubSpaceMIP(
+ float[] vector, float[] centroid, byte[] quantizedVector) {
+
+ // typically no-op if dimensions/64
+ float[] paddedCentroid = BQSpaceUtils.pad(centroid, discretizedDimensions);
+ float[] paddedVector = BQSpaceUtils.pad(vector, discretizedDimensions);
+
+ float oDotC = VectorUtil.dotProduct(paddedVector, paddedCentroid);
+ VectorUtil.subtract(paddedVector, paddedCentroid);
+
+ float normOC = VectorUtil.l2Norm(paddedVector);
+ packAsBinary(paddedVector, quantizedVector);
+ VectorUtil.divide(paddedVector, normOC); // OmC / norm(OmC)
+
+ float OOQ = computerOOQ(vector.length, paddedVector, quantizedVector);
+
+ return new SubspaceOutputMIP(OOQ, normOC, oDotC);
+ }
+
+ private float computerOOQ(int originalLength, float[] normOMinusC, byte[] packedBinaryVector) {
+ float OOQ = 0f;
+ for (int j = 0; j < originalLength / 8; j++) {
+ for (int r = 0; r < 8; r++) {
+ int sign = ((packedBinaryVector[j] >> (7 - r)) & 0b00000001);
+ OOQ += (normOMinusC[j * 8 + r] * (2 * sign - 1));
+ }
+ }
+ OOQ = OOQ / sqrtDimensions;
+ return OOQ;
+ }
+
+ private static float[] range(float[] q) {
+ float vl = 1e20f;
+ float vr = -1e20f;
+ for (int i = 0; i < q.length; i++) {
+ if (q[i] < vl) {
+ vl = q[i];
+ }
+ if (q[i] > vr) {
+ vr = q[i];
+ }
+ }
+
+ return new float[] {vl, vr};
+ }
+
+ /** Results of quantizing a vector for both querying and indexing */
+ public record QueryAndIndexResults(float[] indexFeatures, float[] queryFeatures) {}
+
+ /**
+ * Quantizes the given vector to both single bits and int4 precision. Also contains two distinct
+ * float arrays of corrective factors. For details see the individual methods {@link
+ * #quantizeForIndex(float[], byte[], float[])} and {@link #quantizeForQuery(float[], byte[],
+ * float[])}.
+ *
+ * @param vector the vector to quantize
+ * @param indexDestination the destination byte array to store the quantized bit vector
+ * @param queryDestination the destination byte array to store the quantized int4 vector
+ * @param centroid the centroid to use for quantization
+ * @return the corrective factors used for scoring error correction.
+ */
+ public QueryAndIndexResults quantizeQueryAndIndex(
+ float[] vector, byte[] indexDestination, byte[] queryDestination, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64);
+
+ if (this.discretizedDimensions != indexDestination.length * 8) {
+ throw new IllegalArgumentException(
+ "vector and quantized vector destination must be compatible dimensions: "
+ + BQSpaceUtils.discretize(vector.length, 64)
+ + " [ "
+ + this.discretizedDimensions
+ + " ]"
+ + "!= "
+ + indexDestination.length
+ + " * 8");
+ }
+
+ if (this.discretizedDimensions != (queryDestination.length * 8) / BQSpaceUtils.B_QUERY) {
+ throw new IllegalArgumentException(
+ "vector and quantized vector destination must be compatible dimensions: "
+ + vector.length
+ + " [ "
+ + this.discretizedDimensions
+ + " ]"
+ + "!= ("
+ + queryDestination.length
+ + " * 8) / "
+ + BQSpaceUtils.B_QUERY);
+ }
+
+ if (vector.length != centroid.length) {
+ throw new IllegalArgumentException(
+ "vector and centroid dimensions must be the same: "
+ + vector.length
+ + "!= "
+ + centroid.length);
+ }
+ vector = ArrayUtil.copyArray(vector);
+ // only need distToC for euclidean
+ float distToC =
+ similarityFunction == EUCLIDEAN ? VectorUtil.squareDistance(vector, centroid) : 0f;
+ // only need vdotc for dot-products similarity, but not for euclidean
+ float vDotC = similarityFunction != EUCLIDEAN ? VectorUtil.dotProduct(vector, centroid) : 0f;
+ VectorUtil.subtract(vector, centroid);
+ // both euclidean and dot-product need the norm of the vector, just at different times
+ float normVmC = VectorUtil.l2Norm(vector);
+ // quantize for index
+ packAsBinary(BQSpaceUtils.pad(vector, discretizedDimensions), indexDestination);
+ if (similarityFunction != EUCLIDEAN) {
+ VectorUtil.divide(vector, normVmC);
+ }
+
+ // Quantize for query
+ float[] range = range(vector);
+ float lower = range[0];
+ float upper = range[1];
+ // Δ := (𝑣𝑟 − 𝑣𝑙)/(2𝐵𝑞 − 1)
+ float width = (upper - lower) / ((1 << BQSpaceUtils.B_QUERY) - 1);
+
+ QuantResult quantResult = quantize(vector, lower, width);
+ byte[] byteQuery = quantResult.result();
+
+ // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
+ // q¯ is an approximation of q′ (scalar quantized approximation)
+ // FIXME: vectors need to be padded but that's expensive; update transposeBin to deal
+ byteQuery = BQSpaceUtils.pad(byteQuery, discretizedDimensions);
+ BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, queryDestination);
+ final float[] indexCorrections;
+ final float[] queryCorrections;
+ if (similarityFunction == EUCLIDEAN) {
+ indexCorrections = new float[2];
+ indexCorrections[0] = (float) Math.sqrt(distToC);
+ removeSignAndDivide(vector, sqrtDimensions);
+ indexCorrections[1] = sumAndNormalize(vector, normVmC);
+ queryCorrections = new float[] {distToC, lower, width, quantResult.quantizedSum};
+ } else {
+ indexCorrections = new float[3];
+ indexCorrections[0] = computerOOQ(vector.length, vector, indexDestination);
+ indexCorrections[1] = normVmC;
+ indexCorrections[2] = vDotC;
+ queryCorrections = new float[] {lower, width, normVmC, vDotC, quantResult.quantizedSum};
+ }
+ return new QueryAndIndexResults(indexCorrections, queryCorrections);
+ }
+
+ /**
+ * Quantizes the given vector to single bits and returns an array of corrective factors. For the
+ * dot-product family of distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>the dot-product of the normalized, centered vector with its binarized self
+ *   <li>the norm of the centered vector
+ *   <li>the dot-product of the vector with the centroid
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>The euclidean distance to the centroid
+ *   <li>The sum of the dimensions divided by the vector norm
+ * </ul>
+ *
+ * @param vector the vector to quantize
+ * @param destination the destination byte array to store the quantized vector
+ * @param centroid the centroid to use for quantization
+ * @return the corrective factors used for scoring error correction.
+ */
+ public float[] quantizeForIndex(float[] vector, byte[] destination, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64);
+
+ if (this.discretizedDimensions != destination.length * 8) {
+ throw new IllegalArgumentException(
+ "vector and quantized vector destination must be compatible dimensions: "
+ + BQSpaceUtils.discretize(vector.length, 64)
+ + " [ "
+ + this.discretizedDimensions
+ + " ]"
+ + "!= "
+ + destination.length
+ + " * 8");
+ }
+
+ if (vector.length != centroid.length) {
+ throw new IllegalArgumentException(
+ "vector and centroid dimensions must be the same: "
+ + vector.length
+ + "!= "
+ + centroid.length);
+ }
+
+ float[] corrections;
+
+ vector = ArrayUtil.copyArray(vector);
+
+ switch (similarityFunction) {
+ case EUCLIDEAN:
+ float distToCentroid = (float) Math.sqrt(VectorUtil.squareDistance(vector, centroid));
+
+ SubspaceOutput subspaceOutput = generateSubSpace(vector, centroid, destination);
+ corrections = new float[2];
+ corrections[0] = distToCentroid;
+ corrections[1] = subspaceOutput.projection();
+ break;
+ case MAXIMUM_INNER_PRODUCT:
+ case COSINE:
+ case DOT_PRODUCT:
+ SubspaceOutputMIP subspaceOutputMIP = generateSubSpaceMIP(vector, centroid, destination);
+ corrections = new float[3];
+ corrections[0] = subspaceOutputMIP.OOQ();
+ corrections[1] = subspaceOutputMIP.normOC();
+ corrections[2] = subspaceOutputMIP.oDotC();
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported similarity function: " + similarityFunction);
+ }
+
+ return corrections;
+ }
+
+ private record QuantResult(byte[] result, int quantizedSum) {}
+
+ private static QuantResult quantize(float[] vector, float lower, float width) {
+ // FIXME: speed up with panama? and/or use existing scalar quantization utils in Lucene?
+ byte[] result = new byte[vector.length];
+ float oneOverWidth = 1.0f / width;
+ int sumQ = 0;
+ for (int i = 0; i < vector.length; i++) {
+ byte res = (byte) ((vector[i] - lower) * oneOverWidth);
+ result[i] = res;
+ sumQ += res;
+ }
+
+ return new QuantResult(result, sumQ);
+ }
+
+ /**
+ * Quantizes the given vector to int4 precision and returns an array of corrective factors.
+ *
+ * <p>Corrective factors are used for scoring error correction. For the dot-product family of
+ * distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>The lower bound for the int4 quantized vector
+ *   <li>The width for int4 quantized vector
+ *   <li>The norm of the centroid centered vector
+ *   <li>The dot-product of the vector with the centroid
+ *   <li>The sum of the quantized dimensions
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>The euclidean distance to the centroid
+ *   <li>The lower bound for the int4 quantized vector
+ *   <li>The width for int4 quantized vector
+ *   <li>The sum of the quantized dimensions
+ * </ul>
+ *
+ *
+ * @param vector the vector to quantize
+ * @param destination the destination byte array to store the quantized vector
+ * @param centroid the centroid to use for quantization
+ * @return the corrective factors used for scoring error correction.
+ */
+ public float[] quantizeForQuery(float[] vector, byte[] destination, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert this.discretizedDimensions == BQSpaceUtils.discretize(vector.length, 64);
+
+ if (this.discretizedDimensions != (destination.length * 8) / BQSpaceUtils.B_QUERY) {
+ throw new IllegalArgumentException(
+ "vector and quantized vector destination must be compatible dimensions: "
+ + vector.length
+ + " [ "
+ + this.discretizedDimensions
+ + " ]"
+ + "!= ("
+ + destination.length
+ + " * 8) / "
+ + BQSpaceUtils.B_QUERY);
+ }
+
+ if (vector.length != centroid.length) {
+ throw new IllegalArgumentException(
+ "vector and centroid dimensions must be the same: "
+ + vector.length
+ + "!= "
+ + centroid.length);
+ }
+
+ // FIXME: make a copy of vector so we don't overwrite it here?
+ // ... (could subtractInPlace but the passed vector is modified) <<---
+ float[] vmC = ArrayUtil.copyArray(vector);
+ VectorUtil.subtract(vmC, centroid);
+
+ // FIXME: should other similarity functions behave like MIP on query like COSINE
+ float normVmC = 0f;
+ if (similarityFunction != EUCLIDEAN) {
+ normVmC = VectorUtil.l2Norm(vmC);
+ VectorUtil.divide(vmC, normVmC);
+ }
+ float[] range = range(vmC);
+ float lower = range[0];
+ float upper = range[1];
+ // Δ := (𝑣𝑟 − 𝑣𝑙)/(2𝐵𝑞 − 1)
+ float width = (upper - lower) / ((1 << BQSpaceUtils.B_QUERY) - 1);
+
+ QuantResult quantResult = quantize(vmC, lower, width);
+ byte[] byteQuery = quantResult.result();
+
+ // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
+ // q¯ is an approximation of q′ (scalar quantized approximation)
+ // FIXME: vectors need to be padded but that's expensive; update transposeBin to deal
+ byteQuery = BQSpaceUtils.pad(byteQuery, discretizedDimensions);
+ BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, destination);
+
+ final float[] corrections;
+ if (similarityFunction != EUCLIDEAN) {
+ float vDotC = VectorUtil.dotProduct(vector, centroid);
+ corrections = new float[] {lower, width, normVmC, vDotC, quantResult.quantizedSum};
+ } else {
+ float distToCentroid = (float) Math.sqrt(VectorUtil.squareDistance(vector, centroid));
+ corrections = new float[] {distToCentroid, lower, width, quantResult.quantizedSum};
+ }
+
+ return corrections;
+ }
+}
diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index 9273f7c5a813..8f7fa6048b80 100644
--- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -29,6 +29,7 @@
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
@@ -764,6 +765,122 @@ private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int l
return acc1.add(acc2).reduceLanes(ADD);
}
+ @Override
+ public long ipByteBinByte(byte[] q, byte[] d) {
+ // 128 / 8 == 16
+ if (d.length >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+ if (VECTOR_BITSIZE >= 256) {
+ return ipByteBin256(q, d);
+ } else if (VECTOR_BITSIZE == 128) {
+ return ipByteBin128(q, d);
+ }
+ }
+ return DefaultVectorUtilSupport.ipByteBinByteImpl(q, d);
+ }
+
+ private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+ private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256;
+
+ static long ipByteBin256(byte[] q, byte[] d) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+
+ if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+ int limit = ByteVector.SPECIES_256.loopBound(d.length);
+ var sum0 = LongVector.zero(LongVector.SPECIES_256);
+ var sum1 = LongVector.zero(LongVector.SPECIES_256);
+ var sum2 = LongVector.zero(LongVector.SPECIES_256);
+ var sum3 = LongVector.zero(LongVector.SPECIES_256);
+ for (; i < limit; i += ByteVector.SPECIES_256.length()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs();
+ var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs();
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+
+ if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+ var sum0 = LongVector.zero(LongVector.SPECIES_128);
+ var sum1 = LongVector.zero(LongVector.SPECIES_128);
+ var sum2 = LongVector.zero(LongVector.SPECIES_128);
+ var sum3 = LongVector.zero(LongVector.SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(d.length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs();
+ var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs();
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+ // tail as bytes
+ for (; i < d.length; i++) {
+ subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF);
+ subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF);
+ subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF);
+ subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
+ public static long ipByteBin128(byte[] q, byte[] d) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+
+ var sum0 = IntVector.zero(IntVector.SPECIES_128);
+ var sum1 = IntVector.zero(IntVector.SPECIES_128);
+ var sum2 = IntVector.zero(IntVector.SPECIES_128);
+ var sum3 = IntVector.zero(IntVector.SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(d.length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length()) {
+ var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts();
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts();
+ sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ // tail as bytes
+ for (; i < d.length; i++) {
+ int dValue = d[i];
+ subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+ subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF);
+ subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF);
+ subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
// Experiments suggest that we need at least 8 lanes so that the overhead of going with the vector
// approach and counting trues on vector masks pays off.
private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8;
diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index cb5fee62aeec..bd0c97688f54 100644
--- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -16,3 +16,5 @@
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene101.Lucene101HnswBinaryQuantizedVectorsFormat
+org.apache.lucene.codecs.lucene101.Lucene101BinaryQuantizedVectorsFormat
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java
new file mode 100644
index 000000000000..3c0495102fd9
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryFlatVectorsScorer.java
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.codecs.lucene101;
+
+import java.io.IOException;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+public class TestLucene101BinaryFlatVectorsScorer extends LuceneTestCase {
+
+ public void testScore() throws IOException {
+ int dimensions = random().nextInt(1, 4097);
+ int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64);
+
+ int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
+ VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
+
+ float[] centroid = new float[dimensions];
+ for (int j = 0; j < dimensions; j++) {
+ centroid[j] = random().nextFloat(-50f, 50f);
+ }
+ if (similarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+
+ byte[] vector = new byte[discretizedDimensions / 8 * BQSpaceUtils.B_QUERY];
+ random().nextBytes(vector);
+ float distanceToCentroid = random().nextFloat(0f, 10_000.0f);
+ float vl = random().nextFloat(-1000f, 1000f);
+ float width = random().nextFloat(0f, 1000f);
+ short quantizedSum = (short) random().nextInt(0, 4097);
+ float normVmC = random().nextFloat(-1000f, 1000f);
+ float vDotC = random().nextFloat(-1000f, 1000f);
+ final float[] queryCorrections =
+ similarityFunction != VectorSimilarityFunction.EUCLIDEAN
+ ? new float[] {vl, width, normVmC, vDotC, quantizedSum}
+ : new float[] {distanceToCentroid, vl, width, quantizedSum};
+ Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector =
+ new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections);
+
+ BinarizedByteVectorValues targetVectors =
+ new BinarizedByteVectorValues() {
+
+ @Override
+ public BinaryQuantizer getQuantizer() {
+ int dimensions = 128;
+ return new BinaryQuantizer(dimensions, dimensions, VectorSimilarityFunction.EUCLIDEAN);
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() {
+ return null;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) {
+ byte[] vectorBytes = new byte[discretizedDimensions / 8];
+ random().nextBytes(vectorBytes);
+ return vectorBytes;
+ }
+
+ @Override
+ public int size() {
+ return 1;
+ }
+
+ @Override
+ public int dimension() {
+ return dimensions;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) {
+ return null;
+ }
+
+ @Override
+ public float[] getCorrectiveTerms(int vectorOrd) {
+ if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
+ return new float[] {random().nextFloat(0f, 1000f), random().nextFloat(0f, 100f)};
+ }
+ return new float[] {
+ random().nextFloat(-1000f, 1000f),
+ random().nextFloat(-1000f, 1000f),
+ random().nextFloat(-1000f, 1000f)
+ };
+ }
+ };
+
+ Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer =
+ new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ queryVector, targetVectors, similarityFunction);
+
+ float score = scorer.score(0);
+
+ assertTrue(score >= 0f);
+ }
+
+ public void testScoreEuclidean() throws IOException {
+ int dimensions = 128;
+
+ byte[] vector =
+ new byte[] {
+ -8, 10, -27, 112, -83, 36, -36, -122, -114, 82, 55, 33, -33, 120, 55, -99, -93, -86, -55,
+ 21, -121, 30, 111, 30, 0, 82, 21, 38, -120, -127, 40, -32, 78, -37, 42, -43, 122, 115, 30,
+ 115, 123, 108, -13, -65, 123, 124, -33, -68, 49, 5, 20, 58, 0, 12, 30, 30, 4, 97, 10, 66,
+ 4, 35, 1, 67
+ };
+ float distanceToCentroid = 157799.12f;
+ float vl = -57.883f;
+ float width = 9.972266f;
+ short quantizedSum = 795;
+ float[] queryCorrections = new float[] {distanceToCentroid, vl, width, quantizedSum};
+ Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector =
+ new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections);
+
+ BinarizedByteVectorValues targetVectors =
+ new BinarizedByteVectorValues() {
+ @Override
+ public BinaryQuantizer getQuantizer() {
+ int dimensions = 128;
+ return new BinaryQuantizer(dimensions, dimensions, VectorSimilarityFunction.EUCLIDEAN);
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return new float[] {
+ 26.7f, 16.2f, 10.913f, 10.314f, 12.12f, 14.045f, 15.887f, 16.864f, 32.232f, 31.567f,
+ 34.922f, 21.624f, 16.349f, 29.625f, 31.994f, 22.044f, 37.847f, 24.622f, 36.299f,
+ 27.966f, 14.368f, 19.248f, 30.778f, 35.927f, 27.019f, 16.381f, 17.325f, 16.517f,
+ 13.272f, 9.154f, 9.242f, 17.995f, 53.777f, 23.011f, 12.929f, 16.128f, 22.16f, 28.643f,
+ 25.861f, 27.197f, 59.883f, 40.878f, 34.153f, 22.795f, 24.402f, 37.427f, 34.19f,
+ 29.288f, 61.812f, 26.355f, 39.071f, 37.789f, 23.33f, 22.299f, 28.64f, 47.828f,
+ 52.457f, 21.442f, 24.039f, 29.781f, 27.707f, 19.484f, 14.642f, 28.757f, 54.567f,
+ 20.936f, 25.112f, 25.521f, 22.077f, 18.272f, 14.526f, 29.054f, 61.803f, 24.509f,
+ 37.517f, 35.906f, 24.106f, 22.64f, 32.1f, 48.788f, 60.102f, 39.625f, 34.766f, 22.497f,
+ 24.397f, 41.599f, 38.419f, 30.99f, 55.647f, 25.115f, 14.96f, 18.882f, 26.918f,
+ 32.442f, 26.231f, 27.107f, 26.828f, 15.968f, 18.668f, 14.071f, 10.906f, 8.989f,
+ 9.721f, 17.294f, 36.32f, 21.854f, 35.509f, 27.106f, 14.067f, 19.82f, 33.582f, 35.997f,
+ 33.528f, 30.369f, 36.955f, 21.23f, 15.2f, 30.252f, 34.56f, 22.295f, 29.413f, 16.576f,
+ 11.226f, 10.754f, 12.936f, 15.525f, 15.868f, 16.43f
+ };
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() {
+ return null;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) {
+ return new byte[] {
+ 44, 108, 120, -15, -61, -32, 124, 25, -63, -57, 6, 24, 1, -61, 1, 14
+ };
+ }
+
+ @Override
+ public int size() {
+ return 1;
+ }
+
+ @Override
+ public int dimension() {
+ return dimensions;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) {
+ return null;
+ }
+
+ @Override
+ public float[] getCorrectiveTerms(int vectorOrd) {
+ return new float[] {355.78073f, 0.7636705f};
+ }
+ };
+
+ VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+
+ Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer =
+ new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ queryVector, targetVectors, similarityFunction);
+
+ assertEquals(1f / (1f + 245482.47f), scorer.score(0), 0.1f);
+ }
+
+ public void testScoreMIP() throws IOException {
+ int dimensions = 768;
+
+ byte[] vector =
+ new byte[] {
+ -76, 44, 81, 31, 30, -59, 56, -118, -36, 45, -11, 8, -61, 95, -100, 18, -91, -98, -46, 31,
+ -8, 82, -42, 121, 75, -61, 125, -21, -82, 16, 21, 40, -1, 12, -92, -22, -49, -92, -19,
+ -32, -56, -34, 60, -100, 69, 13, 60, -51, 90, 4, -77, 63, 124, 69, 88, 73, -72, 29, -96,
+ 44, 69, -123, -59, -94, 84, 80, -61, 27, -37, -92, -51, -86, 19, -55, -36, -2, 68, -37,
+ -128, 59, -47, 119, -53, 56, -12, 37, 27, 119, -37, 125, 78, 19, 15, -9, 94, 100, -72, 55,
+ 86, -48, 26, 10, -112, 28, -15, -64, -34, 55, -42, -31, -96, -18, 60, -44, 69, 106, -20,
+ 15, 47, 49, -122, -45, 119, 101, 22, 77, 108, -15, -71, -28, -43, -68, -127, -86, -118,
+ -51, 121, -65, -10, -49, 115, -6, -61, -98, 21, 41, 56, 29, -16, -82, 4, 72, -77, 23, 23,
+ -32, -98, 112, 27, -4, 91, -69, 102, -114, 16, -20, -76, -124, 43, 12, 3, -30, 42, -44,
+ -88, -72, -76, -94, -73, 46, -17, 4, -74, -44, 53, -11, -117, -105, -113, -37, -43, -128,
+ -70, 56, -68, -100, 56, -20, 77, 12, 17, -119, -17, 59, -10, -26, 29, 42, -59, -28, -28,
+ 60, -34, 60, -24, 80, -81, 24, 122, 127, 62, 124, -5, -11, 59, -52, 74, -29, -116, 3, -40,
+ -99, -24, 11, -10, 95, 21, -38, 59, -52, 29, 58, 112, 100, -106, -90, 71, 72, 57, 95, 98,
+ 96, -41, -16, 50, -18, 123, -36, 74, -101, 17, 50, 48, 96, 57, 7, 81, -16, -32, -102, -24,
+ -71, -10, 37, -22, 94, -36, -52, -71, -47, 47, -1, -31, -10, -126, -15, -123, -59, 71,
+ -49, 67, 99, -57, 21, -93, -13, -18, 54, -112, -60, 9, 25, -30, -47, 26, 27, 26, -63, 1,
+ -63, 18, -114, 80, 110, -123, 0, -63, -126, -128, 10, -60, 51, -71, 28, 114, -4, 53, 10,
+ 23, -96, 9, 32, -22, 5, -108, 33, 98, -59, -106, -126, 73, 72, -72, -73, -60, -96, -99,
+ 31, 40, 15, -19, 17, -128, 33, -75, 96, -18, -47, 75, 27, -60, -16, -82, 13, 21, 37, 23,
+ 70, 9, -39, 16, -127, 35, -78, 64, 99, -46, 1, 28, 65, 125, 14, 42, 26
+ };
+ float vl = -0.10079563f;
+ float width = 0.014609014f;
+ short quantizedSum = 5306;
+ float normVmC = 9.766797f;
+ float vDotC = 133.56123f;
+ float cDotC = 132.20227f;
+ float[] queryCorrections = new float[] {vl, width, normVmC, vDotC, quantizedSum};
+ Lucene101BinaryFlatVectorsScorer.BinaryQueryVector queryVector =
+ new Lucene101BinaryFlatVectorsScorer.BinaryQueryVector(vector, queryCorrections);
+
+ BinarizedByteVectorValues targetVectors =
+ new BinarizedByteVectorValues() {
+
+ @Override
+ public float getCentroidDP() {
+ // Precomputed centroid self dot-product (cDotC) for this fixture.
+ return cDotC;
+ }
+
+ @Override
+ public BinaryQuantizer getQuantizer() {
+ int dimensions = 768;
+ // Discretized dimension equals dims here: 768 is already a multiple of 64.
+ return new BinaryQuantizer(
+ dimensions, dimensions, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return new float[] {
+ 0.16672021f, 0.11700719f, 0.013227397f, 0.09305186f, -0.029422699f, 0.17622353f,
+ 0.4267106f, -0.297038f, 0.13915674f, 0.38441318f, -0.486725f, -0.15987667f,
+ -0.19712289f, 0.1349074f, -0.19016947f, -0.026179956f, 0.4129807f, 0.14325741f,
+ -0.09106042f, 0.06876218f, -0.19389102f, 0.4467732f, 0.03169017f, -0.066950575f,
+ -0.044301506f, -0.0059755715f, -0.33196586f, 0.18213534f, -0.25065416f, 0.30251458f,
+ 0.3448419f, -0.14900115f, -0.07782894f, 0.3568707f, -0.46595258f, 0.37295088f,
+ -0.088741764f, 0.17248306f, -0.0072736046f, 0.32928637f, 0.13216197f, 0.032092985f,
+ 0.21553043f, 0.016091486f, 0.31958902f, 0.0133126f, 0.1579258f, 0.018537233f,
+ 0.046248164f, -0.0048194043f, -0.2184672f, -0.26273906f, -0.110678785f, -0.04542999f,
+ -0.41625032f, 0.46025568f, -0.16116948f, 0.4091706f, 0.18427321f, 0.004736977f,
+ 0.16289745f, -0.05330932f, -0.2694863f, -0.14762327f, 0.17744702f, 0.2445075f,
+ 0.14377175f, 0.37390858f, 0.16165806f, 0.17177118f, 0.097307935f, 0.36326465f,
+ 0.23221572f, 0.15579978f, -0.065486655f, -0.29006517f, -0.009194494f, 0.009019374f,
+ 0.32154799f, -0.23186184f, 0.46485493f, -0.110756285f, -0.18604982f, 0.35027295f,
+ 0.19815539f, 0.47386464f, -0.031379268f, 0.124035835f, 0.11556784f, 0.4304302f,
+ -0.24455063f, 0.1816723f, 0.034300473f, -0.034347706f, 0.040140998f, 0.1389901f,
+ 0.22840638f, -0.19911191f, 0.07563166f, -0.2744902f, 0.13114859f, -0.23862572f,
+ -0.31404558f, 0.41355187f, 0.12970817f, -0.35403475f, -0.2714075f, 0.07231573f,
+ 0.043893218f, 0.30324167f, 0.38928393f, -0.1567055f, -0.0083288215f, 0.0487653f,
+ 0.12073729f, -0.01582117f, 0.13381198f, -0.084824145f, -0.15329859f, -1.120622f,
+ 0.3972598f, 0.36022213f, -0.29826534f, -0.09468781f, 0.03550699f, -0.21630692f,
+ 0.55655843f, -0.14842057f, 0.5924833f, 0.38791573f, 0.1502777f, 0.111737385f,
+ 0.1926823f, 0.66021144f, 0.25601995f, 0.28220543f, 0.10194068f, 0.013066262f,
+ -0.09348819f, -0.24085014f, -0.17843121f, -0.012598432f, 0.18757571f, 0.48543528f,
+ -0.059388146f, 0.1548026f, 0.041945867f, 0.3322589f, 0.012830887f, 0.16621992f,
+ 0.22606649f, 0.13959105f, -0.16688728f, 0.47194278f, -0.12767595f, 0.037815034f,
+ 0.441938f, 0.07875027f, 0.08625042f, 0.053454693f, 0.74093896f, 0.34662113f,
+ 0.009829135f, -0.033400282f, 0.030965377f, 0.17645596f, 0.083803624f, 0.32578796f,
+ 0.49538168f, -0.13212465f, -0.39596975f, 0.109529115f, 0.2815771f, -0.051440604f,
+ 0.21889819f, 0.25598505f, 0.012208843f, -0.012405662f, 0.3248759f, 0.00997502f,
+ 0.05999008f, 0.03562817f, 0.19007418f, 0.24805716f, 0.5926766f, 0.26937613f,
+ 0.25856f, -0.05798439f, -0.29168302f, 0.14050555f, 0.084851265f, -0.03763504f,
+ 0.8265359f, -0.23383066f, -0.042164285f, 0.19120507f, -0.12189065f, 0.3864055f,
+ -0.19823311f, 0.30280992f, 0.10814344f, -0.164514f, -0.22905481f, 0.13680641f,
+ 0.4513772f, -0.514546f, -0.061746247f, 0.11598224f, -0.23093395f, -0.09735358f,
+ 0.02767051f, 0.11594536f, 0.17106244f, 0.21301728f, -0.048222974f, 0.2212131f,
+ -0.018857865f, -0.09783516f, 0.42156664f, -0.14032331f, -0.103861615f, 0.4190284f,
+ 0.068923555f, -0.015083771f, 0.083590426f, -0.15759592f, -0.19096768f, -0.4275228f,
+ 0.12626286f, 0.12192557f, 0.4157616f, 0.048780657f, 0.008426048f, -0.0869124f,
+ 0.054927208f, 0.28417027f, 0.29765493f, 0.09203619f, -0.14446871f, -0.117514975f,
+ 0.30662632f, 0.24904715f, -0.19551662f, -0.0045785015f, 0.4217626f, -0.31457824f,
+ 0.23381722f, 0.089111514f, -0.27170828f, -0.06662652f, 0.10011391f, -0.090274535f,
+ 0.101849966f, 0.26554734f, -0.1722843f, 0.23296228f, 0.25112453f, -0.16790418f,
+ 0.010348314f, 0.05061285f, 0.38003662f, 0.0804625f, 0.3450673f, 0.364368f,
+ -0.2529952f, -0.034065288f, 0.22796603f, 0.5457553f, 0.11120353f, 0.24596325f,
+ 0.42822433f, -0.19215727f, -0.06974534f, 0.19388479f, -0.17598474f, -0.08769705f,
+ 0.12769659f, 0.1371616f, -0.4636819f, 0.16870509f, 0.14217548f, 0.04412187f,
+ -0.20930687f, 0.0075530168f, 0.10065227f, 0.45334083f, -0.1097471f, -0.11139921f,
+ -0.31835595f, -0.057386875f, 0.16285825f, 0.5088513f, -0.06318843f, -0.34759882f,
+ 0.21132466f, 0.33609292f, 0.04858872f, -0.058759f, 0.22845529f, -0.07641319f,
+ 0.5452827f, -0.5050389f, 0.1788054f, 0.37428045f, 0.066334985f, -0.28162515f,
+ -0.15629752f, 0.33783385f, -0.0832242f, 0.29144394f, 0.47892854f, -0.47006592f,
+ -0.07867588f, 0.3872869f, 0.28053126f, 0.52399015f, 0.21979983f, 0.076880336f,
+ 0.47866163f, 0.252952f, -0.1323851f, -0.22225754f, -0.38585815f, 0.12967427f,
+ 0.20340872f, -0.326928f, 0.09636557f, -0.35929212f, 0.5413311f, 0.019960884f,
+ 0.33512768f, 0.15133342f, -0.14124066f, -0.1868793f, -0.07862198f, 0.22739467f,
+ 0.19598985f, 0.34314656f, -0.05071516f, -0.21107961f, 0.19934991f, 0.04822684f,
+ 0.15060754f, 0.26586458f, -0.15528078f, 0.123646654f, 0.14450715f, -0.12574252f,
+ 0.30608323f, 0.018549249f, 0.36323825f, 0.06762097f, 0.08562406f, -0.07863075f,
+ 0.15975896f, 0.008347004f, 0.37931192f, 0.22957338f, 0.33606857f, -0.25204057f,
+ 0.18126069f, 0.41903302f, 0.20244692f, -0.053850617f, 0.23088565f, 0.16085246f,
+ 0.1077502f, -0.12445943f, 0.115779735f, 0.124704875f, 0.13076028f, -0.11628619f,
+ -0.12580182f, 0.065204754f, -0.26290357f, -0.23539798f, -0.1855292f, 0.39872098f,
+ 0.44495568f, 0.05491784f, 0.05135692f, 0.624011f, 0.22839564f, 0.0022447354f,
+ -0.27169296f, -0.1694988f, -0.19106841f, 0.0110123325f, 0.15464798f, -0.16269256f,
+ 0.04033836f, -0.11792753f, 0.17172396f, -0.08912173f, -0.30929542f, -0.03446989f,
+ -0.21738084f, 0.39657044f, 0.33550346f, -0.06839139f, 0.053675443f, 0.33783767f,
+ 0.22576828f, 0.38280004f, 4.1448855f, 0.14225426f, 0.24038498f, 0.072373435f,
+ -0.09465926f, -0.016144043f, 0.40864578f, -0.2583055f, 0.031816103f, 0.062555805f,
+ 0.06068663f, 0.25858644f, -0.10598804f, 0.18201788f, -0.00090025424f, 0.085680895f,
+ 0.4304161f, 0.028686283f, 0.027298616f, 0.27473378f, -0.3888415f, 0.44825438f,
+ 0.3600378f, 0.038944595f, 0.49292335f, 0.18556066f, 0.15779617f, 0.29989767f,
+ 0.39233804f, 0.39759228f, 0.3850708f, -0.0526475f, 0.18572918f, 0.09667526f,
+ -0.36111078f, 0.3439669f, 0.1724522f, 0.14074509f, 0.26097745f, 0.16626832f,
+ -0.3062964f, -0.054877423f, 0.21702516f, 0.4736452f, 0.2298038f, -0.2983771f,
+ 0.118479654f, 0.35940516f, 0.12212727f, 0.17234904f, 0.30632678f, 0.09207966f,
+ -0.14084268f, -0.19737118f, 0.12442629f, 0.52454203f, 0.1266684f, 0.3062802f,
+ 0.121598125f, -0.09156268f, 0.11491686f, -0.105715364f, 0.19831072f, 0.061421417f,
+ -0.41778997f, 0.14488487f, 0.023310646f, 0.27257463f, 0.16821945f, -0.16702746f,
+ 0.263203f, 0.33512688f, 0.35117313f, -0.31740817f, -0.14203706f, 0.061256267f,
+ -0.19764185f, 0.04822579f, -0.0016218472f, -0.025792575f, 0.4885193f, -0.16942391f,
+ -0.04156327f, 0.15908112f, -0.06998626f, 0.53907114f, 0.10317832f, -0.365468f,
+ 0.4729886f, 0.14291425f, 0.32812154f, -0.0273262f, 0.31760117f, 0.16925456f,
+ 0.21820979f, 0.085142255f, 0.16118735f, -3.7089362f, 0.251577f, 0.18394576f,
+ 0.027926167f, 0.15720351f, 0.13084261f, 0.16240814f, 0.23045056f, -0.3966458f,
+ 0.22822891f, -0.061541352f, 0.028320132f, -0.14736478f, 0.184569f, 0.084853746f,
+ 0.15172474f, 0.08277542f, 0.27751622f, 0.23450488f, -0.15349835f, 0.29665688f,
+ 0.32045734f, 0.20012043f, -0.2749372f, 0.011832386f, 0.05976605f, 0.018300122f,
+ -0.07855043f, -0.075900674f, 0.0384252f, -0.15101928f, 0.10922137f, 0.47396383f,
+ -0.1771141f, 0.2203417f, 0.33174303f, 0.36640546f, 0.10906258f, 0.13765177f,
+ 0.2488032f, -0.061588854f, 0.20347528f, 0.2574979f, 0.22369152f, 0.18777567f,
+ -0.0772263f, -0.1353299f, 0.087077625f, -0.05409276f, 0.027534787f, 0.08053508f,
+ 0.3403908f, -0.15362988f, 0.07499862f, 0.54367846f, -0.045938436f, 0.12206868f,
+ 0.031069376f, 0.2972343f, 0.3235321f, -0.053970363f, -0.0042564687f, 0.21447177f,
+ 0.023565233f, -0.1286087f, -0.047359955f, 0.23021339f, 0.059837278f, 0.19709614f,
+ -0.17340347f, 0.11572943f, 0.21720429f, 0.29375625f, -0.045433592f, 0.033339307f,
+ 0.24594454f, -0.021661613f, -0.12823369f, 0.41809165f, 0.093840264f, -0.007481906f,
+ 0.22441079f, -0.45719734f, 0.2292629f, 2.675806f, 0.3690025f, 2.1311781f,
+ 0.07818368f, -0.17055893f, 0.3162922f, -0.2983149f, 0.21211359f, 0.037087034f,
+ 0.021580033f, 0.086415835f, 0.13541797f, -0.12453424f, 0.04563163f, -0.082379065f,
+ -0.15938349f, 0.38595748f, -0.8796574f, -0.080991246f, 0.078572094f, 0.20274459f,
+ 0.009252143f, -0.12719384f, 0.105845824f, 0.1592398f, -0.08656061f, -0.053054806f,
+ 0.090986334f, -0.02223379f, -0.18215932f, -0.018316114f, 0.1806707f, 0.24788831f,
+ -0.041049056f, 0.01839475f, 0.19160001f, -0.04827654f, 4.4070687f, 0.12640671f,
+ -0.11171499f, -0.015480781f, 0.14313947f, 0.10024215f, 0.4129662f, 0.038836367f,
+ -0.030228542f, 0.2948598f, 0.32946473f, 0.2237934f, 0.14260699f, -0.044821896f,
+ 0.23791742f, 0.079720296f, 0.27059034f, 0.32129505f, 0.2725177f, 0.06883333f,
+ 0.1478041f, 0.07598411f, 0.27230525f, -0.04704308f, 0.045167264f, 0.215413f,
+ 0.20359069f, -0.092178136f, -0.09523752f, 0.21427691f, 0.10512272f, 5.1295033f,
+ 0.040909242f, 0.007160441f, -0.192866f, -0.102640584f, 0.21103396f, -0.006780398f,
+ -0.049653083f, -0.29426834f, -0.0038102255f, -0.13842082f, 0.06620181f, -0.3196518f,
+ 0.33279592f, 0.13845938f, 0.16162738f, -0.24798508f, -0.06672485f, 0.195944f,
+ -0.11957207f, 0.44237947f, -0.07617347f, 0.13575341f, -0.35074243f, -0.093798876f,
+ 0.072853446f, -0.20490398f, 0.26504788f, -0.046076056f, 0.16488416f, 0.36007464f,
+ 0.20955376f, -0.3082038f, 0.46533757f, -0.27326992f, -0.14167665f, 0.25017953f,
+ 0.062622115f, 0.14057694f, -0.102370486f, 0.33898357f, 0.36456722f, -0.10120469f,
+ -0.27838466f, -0.11779602f, 0.18517569f, -0.05942488f, 0.076405466f, 0.007960496f,
+ 0.0443746f, 0.098998964f, -0.01897129f, 0.8059487f, 0.06991939f, 0.26562217f,
+ 0.26942885f, 0.11432197f, -0.0055776504f, 0.054493718f, -0.13086213f, 0.6841702f,
+ 0.121975765f, 0.02787146f, 0.29039973f, 0.30943078f, 0.21762547f, 0.28751117f,
+ 0.027524523f, 0.5315654f, -0.22451901f, -0.13782433f, 0.08228316f, 0.07808882f,
+ 0.17445615f, -0.042489477f, 0.13232234f, 0.2756272f, -0.18824948f, 0.14326479f,
+ -0.119312495f, 0.011788091f, -0.22103515f, -0.2477118f, -0.10513839f, 0.034028634f,
+ 0.10693818f, 0.03057979f, 0.04634646f, 0.2289361f, 0.09981585f, 0.26901972f,
+ 0.1561221f, -0.10639886f, 0.36466748f, 0.06350991f, 0.027927283f, 0.11919768f,
+ 0.23290513f, -0.03417105f, 0.16698854f, -0.19243467f, 0.28430334f, 0.03754995f,
+ -0.08697018f, 0.20413163f, -0.27218238f, 0.13707504f, -0.082289375f, 0.03479585f,
+ 0.2298305f, 0.4983682f, 0.34522808f, -0.05711886f, -0.10568684f, -0.07771385f
+ };
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() {
+ // Copying is not exercised by this test; stubbed out.
+ return null;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) {
+ // Single hand-binarized 768-bit target vector (96 bytes); the ordinal is
+ // ignored because size() == 1.
+ return new byte[] {
+ -88, -3, 60, -75, -38, 79, 84, -53, -116, -126, 19, -19, -21, -80, 69, 101, -71, 53,
+ 101, -124, -24, -76, 92, -45, 108, -107, -18, 102, 23, -80, -47, 116, 87, -50, 27,
+ -31, -10, -13, 117, -88, -27, -93, -98, -39, 30, -109, -114, 5, -15, 98, -82, 81, 83,
+ 118, 30, -118, -12, -95, 121, 125, -13, -88, 75, -85, -56, -126, 82, -59, 48, -81, 67,
+ -63, 81, 24, -83, 95, -44, 103, 3, -40, -13, -41, -29, -60, 1, 65, -4, -110, -40, 34,
+ 118, 51, -76, 75, 70, -51
+ };
+ }
+
+ @Override
+ public int size() {
+ // Exactly one target vector in this fixture.
+ return 1;
+ }
+
+ @Override
+ public int dimension() {
+ return dimensions;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) {
+ // Not exercised by this test; stubbed out.
+ return null;
+ }
+
+ @Override
+ public float[] getCorrectiveTerms(int vectorOrd) {
+ // Precomputed corrective terms matching the expected MIP score asserted
+ // below. NOTE(review): confirm value ordering against the
+ // BinarizedByteVectorValues#getCorrectiveTerms contract.
+ return new float[] {0.7882396f, 5.0889387f, 131.485660f};
+ }
+ };
+
+ VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+ Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer =
+ new Lucene101BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ queryVector, targetVectors, similarityFunction);
+
+ assertEquals(129.64046f, scorer.score(0), 0.0001f);
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..011b42d9c664
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.io.IOException;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.quantization.BQSpaceUtils;
+import org.apache.lucene.util.quantization.BinaryQuantizer;
+
+/**
+ * Tests for {@code Lucene101BinaryQuantizedVectorsFormat}: the flat (non-HNSW,
+ * brute-force search) binary quantized vector format.
+ */
+public class TestLucene101BinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ @Override
+ protected Codec getCodec() {
+ // Route every vector field in the inherited test suite to the format under test.
+ return new Lucene101Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene101BinaryQuantizedVectorsFormat();
+ }
+ };
+ }
+
+ // Indexes random vectors and verifies a k-NN query returns exactly k hits.
+ public void testSearch() throws Exception {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ IndexWriterConfig iwc = newIndexWriterConfig();
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, iwc)) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ }
+ w.commit();
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ final int k = random().nextInt(5, 50);
+ float[] queryVector = randomVector(dims);
+ Query q = new KnnFloatVectorQuery(fieldName, queryVector, k);
+ TopDocs collectedDocs = searcher.search(q, k);
+ assertEquals(k, collectedDocs.totalHits.value());
+ assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation());
+ }
+ }
+ }
+ }
+
+ // Pins the format's toString(); the flat-vector scorer delegate depends on whether
+ // the Panama vectorization provider is enabled, so either rendering is accepted.
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene101BinaryQuantizedVectorsFormat();
+ }
+ };
+ String expectedPattern =
+ "Lucene101BinaryQuantizedVectorsFormat(name=Lucene101BinaryQuantizedVectorsFormat, flatVectorScorer=Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))";
+ var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ @Override
+ public void testRandomWithUpdatesAndGraph() {
+ // graph not supported
+ }
+
+ @Override
+ public void testSearchWithVisitedLimit() {
+ // visited limit is not respected, as it is brute force search
+ }
+
+ // Round-trips vectors through the codec and checks the stored binarized bytes and
+ // corrective terms match an on-the-fly re-quantization of the raw vectors.
+ public void testQuantizedVectorsWriteAndRead() throws IOException {
+ String fieldName = "field";
+ int numVectors = random().nextInt(99, 500);
+ int dims = random().nextInt(4, 65);
+
+ float[] vector = randomVector(dims);
+ VectorSimilarityFunction similarityFunction = randomSimilarity();
+ KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+ try (Directory dir = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ for (int i = 0; i < numVectors; i++) {
+ Document doc = new Document();
+ knnField.setVectorValue(randomVector(dims));
+ doc.add(knnField);
+ w.addDocument(doc);
+ if (i % 101 == 0) {
+ w.commit(); // periodic commits create several segments before the merge
+ }
+ }
+ w.commit();
+ w.forceMerge(1);
+
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
+ assertEquals(numVectors, vectorValues.size());
+ BinarizedByteVectorValues qvectorValues =
+ ((Lucene101BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues)
+ .getQuantizedVectorValues();
+ float[] centroid = qvectorValues.getCentroid();
+ assertEquals(dims, centroid.length);
+
+ int discretizedDimension = BQSpaceUtils.discretize(dims, 64);
+ BinaryQuantizer quantizer =
+ new BinaryQuantizer(dims, discretizedDimension, similarityFunction);
+ byte[] expectedVector = new byte[discretizedDimension / 8];
+ if (similarityFunction == VectorSimilarityFunction.COSINE) {
+ // COSINE vectors are normalized at index time; mirror that before re-quantizing.
+ vectorValues =
+ new Lucene101BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);
+ }
+
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ int ord = docIndexIterator.index();
+ float[] corrections =
+ quantizer.quantizeForIndex(vectorValues.vectorValue(ord), expectedVector, centroid);
+ assertArrayEquals(expectedVector, qvectorValues.vectorValue(ord));
+ assertArrayEquals(corrections, qvectorValues.getCorrectiveTerms(ord), 0.00001f);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java
new file mode 100644
index 000000000000..a3f7f1827520
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene101/TestLucene101HnswBinaryQuantizedVectorsFormat.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene101;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+import java.util.Arrays;
+import java.util.Locale;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.SameThreadExecutorService;
+
+/**
+ * Tests for {@code Lucene101HnswBinaryQuantizedVectorsFormat}: HNSW graph search
+ * over binary quantized vectors.
+ */
+public class TestLucene101HnswBinaryQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
+
+ @Override
+ protected Codec getCodec() {
+ // Route every vector field in the inherited test suite to the format under test.
+ return new Lucene101Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene101HnswBinaryQuantizedVectorsFormat();
+ }
+ };
+ }
+
+ // Pins the format's toString(); the flat-vector scorer delegate depends on whether
+ // the Panama vectorization provider is enabled, so either rendering is accepted.
+ public void testToString() {
+ FilterCodec customCodec =
+ new FilterCodec("foo", Codec.getDefault()) {
+ @Override
+ public KnnVectorsFormat knnVectorsFormat() {
+ return new Lucene101HnswBinaryQuantizedVectorsFormat(10, 20, 1, null);
+ }
+ };
+ String expectedPattern =
+ "Lucene101HnswBinaryQuantizedVectorsFormat(name=Lucene101HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+ + " flatVectorFormat=Lucene101BinaryQuantizedVectorsFormat(name=Lucene101BinaryQuantizedVectorsFormat,"
+ + " flatVectorScorer=Lucene101BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))";
+
+ var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+ var memSegScorer =
+ format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+ assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+ }
+
+ // Indexes a single vector per similarity function and checks that it round-trips
+ // exactly and is returned by nearest-vector search with a non-negative score.
+ public void testSingleVectorCase() throws Exception {
+ float[] vector = randomVector(random().nextInt(12, 500));
+ for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+ try (Directory dir = newDirectory();
+ IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, similarityFunction));
+ w.addDocument(doc);
+ w.commit();
+ try (IndexReader reader = DirectoryReader.open(w)) {
+ LeafReader r = getOnlyLeafReader(reader);
+ FloatVectorValues vectorValues = r.getFloatVectorValues("f");
+ KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+ assertEquals(1, vectorValues.size());
+ while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+ assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f);
+ }
+ TopDocs td =
+ r.searchNearestVectors("f", randomVector(vector.length), 1, null, Integer.MAX_VALUE);
+ assertEquals(1, td.totalHits.value());
+ assertTrue(td.scoreDocs[0].score >= 0);
+ }
+ }
+ }
+ }
+
+ // Constructor argument validation. The final case passes an executor together with a
+ // single merge worker, which the constructor rejects.
+ public void testLimits() {
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene101HnswBinaryQuantizedVectorsFormat(-1, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene101HnswBinaryQuantizedVectorsFormat(0, 20));
+ expectThrows(
+ IllegalArgumentException.class, () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, 0));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, -1));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene101HnswBinaryQuantizedVectorsFormat(512 + 1, 20));
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> new Lucene101HnswBinaryQuantizedVectorsFormat(20, 3201));
+ expectThrows(
+ IllegalArgumentException.class,
+ () ->
+ new Lucene101HnswBinaryQuantizedVectorsFormat(
+ 20, 100, 1, new SameThreadExecutorService()));
+ }
+
+ // Ensures that all expected vector similarity functions are translatable in the format.
+ public void testVectorSimilarityFuncs() {
+ // This does not necessarily have to be all similarity functions, but
+ // differences should be considered carefully.
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+ assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues);
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
index 34a0e5230022..ea6c183c7e7a 100644
--- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/BaseVectorizationTestCase.java
@@ -21,8 +21,8 @@
public abstract class BaseVectorizationTestCase extends LuceneTestCase {
- protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider();
- protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true);
+ protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider();
+ protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider();
@BeforeClass
public static void beforeClass() throws Exception {
@@ -30,4 +30,12 @@ public static void beforeClass() throws Exception {
"Test only works when JDK's vector incubator module is enabled.",
PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass());
}
+
+ /** Returns the scalar (non-vectorized) baseline provider. */
+ public static VectorizationProvider defaultProvider() {
+ return new DefaultVectorizationProvider();
+ }
+
+ /**
+ * Returns the provider resolved via {@code VectorizationProvider.lookup(true)} —
+ * presumably the Panama implementation when the vector incubator module is enabled
+ * (TODO confirm fallback behavior against VectorizationProvider).
+ */
+ public static VectorizationProvider maybePanamaProvider() {
+ return VectorizationProvider.lookup(true);
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
index 7064955cb5f3..3355962d7f1a 100644
--- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
+++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java
@@ -16,10 +16,13 @@
*/
package org.apache.lucene.internal.vectorization;
+import static org.apache.lucene.util.VectorUtil.B_QUERY;
+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import java.util.Arrays;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
+import java.util.function.ToLongFunction;
import java.util.stream.IntStream;
public class TestVectorUtilSupport extends BaseVectorizationTestCase {
@@ -142,6 +145,27 @@ static byte[] pack(byte[] unpacked) {
return packed;
}
+ // Cross-checks ipByteBinByte between the scalar and Panama providers on random input.
+ public void testIpByteBin() {
+ var d = new byte[size];
+ var q = new byte[size * B_QUERY];
+ random().nextBytes(d);
+ random().nextBytes(q);
+ assertLongReturningProviders(p -> p.ipByteBinByte(q, d));
+ }
+
+ // Exercises extreme byte values to surface overflow/sign-extension differences
+ // between the scalar and vectorized implementations.
+ public void testIpByteBinBoundaries() {
+ var d = new byte[size];
+ var q = new byte[size * B_QUERY];
+
+ Arrays.fill(d, Byte.MAX_VALUE);
+ Arrays.fill(q, Byte.MAX_VALUE);
+ assertLongReturningProviders(p -> p.ipByteBinByte(q, d));
+
+ Arrays.fill(d, Byte.MIN_VALUE);
+ Arrays.fill(q, Byte.MIN_VALUE);
+ assertLongReturningProviders(p -> p.ipByteBinByte(q, d));
+ }
+
private void assertFloatReturningProviders(ToDoubleFunction func) {
assertEquals(
func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()),
@@ -154,4 +178,10 @@ private void assertIntReturningProviders(ToIntFunction func)
func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()),
func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport()));
}
+
+ // Asserts that the scalar and Panama providers return the same long value.
+ private void assertLongReturningProviders(ToLongFunction func) {
+ assertEquals(
+ func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()),
+ func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport()));
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index 6e449a550028..7b56ef5ebedf 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -16,8 +16,14 @@
*/
package org.apache.lucene.util;
+import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween;
+import static org.apache.lucene.util.VectorUtil.B_QUERY;
+
+import java.util.Arrays;
import java.util.Random;
import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase;
+import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
@@ -129,6 +135,45 @@ public void testExtremeNumerics() {
}
}
+ // popCount over a packed bit vector: fixed small cases, then random arrays
+ // cross-checked against the scalar reference helper.
+ public void testPopCount() {
+ assertEquals(0, VectorUtil.popCount(new byte[] {}));
+ assertEquals(1, VectorUtil.popCount(new byte[] {1}));
+ assertEquals(2, VectorUtil.popCount(new byte[] {2, 1}));
+ assertEquals(2, VectorUtil.popCount(new byte[] {8, 0, 1}));
+ assertEquals(4, VectorUtil.popCount(new byte[] {7, 1}));
+
+ int iterations = atLeast(50);
+ for (int i = 0; i < iterations; i++) {
+ int size = random().nextInt(5000);
+ var a = new byte[size];
+ random().nextBytes(a);
+ // NOTE(review): reference helper `popcount` is defined elsewhere in this
+ // class; passing `a` for two arguments looks intentional but worth confirming.
+ assertEquals(popcount(a, 0, a, size), VectorUtil.popCount(a));
+ }
+ }
+
+ // NOTE(review): byte-identical to testL2Norm below — one of the two is redundant
+ // and could be removed (or repurposed, e.g. for an l1 norm).
+ public void testNorm() {
+ assertEquals(3.0f, VectorUtil.l2Norm(new float[] {3}), DELTA);
+ assertEquals(5.0f, VectorUtil.l2Norm(new float[] {5}), DELTA);
+ assertEquals(4.0f, VectorUtil.l2Norm(new float[] {2, 2, 2, 2}), DELTA);
+ assertEquals(9.0f, VectorUtil.l2Norm(new float[] {3, 3, 3, 3, 3, 3, 3, 3, 3}), DELTA);
+ }
+
+ // In-place element-wise subtraction: a[i] -= b[i].
+ public void testSubtract() {
+ var a = new float[] {3};
+ VectorUtil.subtract(a, new float[] {2});
+ assertArrayEquals(new float[] {1}, a, (float) DELTA);
+ a = new float[] {3, 3, 3};
+ VectorUtil.subtract(a, new float[] {1, 2, 3});
+ assertArrayEquals(new float[] {2, 1, 0}, a, (float) DELTA);
+ }
+
+ public void testL2Norm() {
+ assertEquals(3.0f, VectorUtil.l2Norm(new float[] {3}), DELTA);
+ assertEquals(5.0f, VectorUtil.l2Norm(new float[] {5}), DELTA);
+ assertEquals(4.0f, VectorUtil.l2Norm(new float[] {2, 2, 2, 2}), DELTA);
+ assertEquals(9.0f, VectorUtil.l2Norm(new float[] {3, 3, 3, 3, 3, 3, 3, 3, 3}), DELTA);
+ }
+
private static float l2(float[] v) {
float l2 = 0;
for (float x : v) {
@@ -354,6 +399,129 @@ private static int xorBitCount(byte[] a, byte[] b) {
return res;
}
+ public void testIpByteBinInvariants() {
+ int iterations = atLeast(10);
+ for (int i = 0; i < iterations; i++) {
+ int size = randomIntBetween(random(), 1, 10);
+ var d = new byte[size];
+ var q = new byte[size * B_QUERY - 1];
+ expectThrows(IllegalArgumentException.class, () -> VectorUtil.ipByteBinByte(q, d));
+ }
+ }
+
  // Vectorization providers under test: the scalar/default implementation and, when the runtime
  // supports it, the Panama-vector implementation (see maybePanamaProvider). The ipByteBin tests
  // below are run against both so the vectorized code path is checked against the scalar one.
  static final VectorizationProvider defaultedProvider =
      BaseVectorizationTestCase.defaultProvider();
  static final VectorizationProvider defOrPanamaProvider =
      BaseVectorizationTestCase.maybePanamaProvider();
+
  /** Runs the fixed ipByteBin fixtures against the public API and both provider implementations. */
  public void testBasicIpByteBin() {
    testBasicIpByteBinImpl(VectorUtil::ipByteBinByte);
    testBasicIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte);
    testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
  }

  /** Functional handle to one ipByteBin implementation, so all impls share the same fixtures. */
  interface IpByteBin {
    long apply(byte[] q, byte[] d);
  }
+
  /**
   * Fixed fixtures for ipByteBin. Each query {@code q} is {@code B_QUERY} (4) concatenated copies
   * of the document pattern {@code d}, so the expected result is the AND-popcount of one copy
   * weighted by {@code 1 + 2 + 4 + 8 = 15}; the inline comments spell out the per-plane sums.
   * The bare {@code assert} statements (enabled via -ea) sanity-check the scalar reference itself
   * before it is compared with the implementation under test.
   */
  void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
    assertEquals(15L, ipByteBinFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1}));
    assertEquals(30L, ipByteBinFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2}));

    var d = new byte[] {1, 2, 3};
    var q = new byte[] {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
    assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
    assertEquals(60L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4};
    q = new byte[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
    assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
    assertEquals(75L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4, 5};
    q = new byte[] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
    assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
    assertEquals(105L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4, 5, 6};
    q = new byte[] {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
    assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
    assertEquals(135L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4, 5, 6, 7};
    q =
        new byte[] {
          1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7
        };
    assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
    assertEquals(180L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8};
    q =
        new byte[] {
          1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
          7, 8
        };
    assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
    assertEquals(195L, ipByteBinFunc.apply(q, d));

    d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
    q =
        new byte[] {
          1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3,
          4, 5, 6, 7, 8, 9
        };
    assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
    assertEquals(225L, ipByteBinFunc.apply(q, d));
  }
+
+ public void testIpByteBin() {
+ testIpByteBinImpl(VectorUtil::ipByteBinByte);
+ testIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte);
+ testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
+ }
+
+ void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
+ int iterations = atLeast(50);
+ for (int i = 0; i < iterations; i++) {
+ int size = random().nextInt(5000);
+ var d = new byte[size];
+ var q = new byte[size * B_QUERY];
+ random().nextBytes(d);
+ random().nextBytes(q);
+ assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+
+ Arrays.fill(d, Byte.MAX_VALUE);
+ Arrays.fill(q, Byte.MAX_VALUE);
+ assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+
+ Arrays.fill(d, Byte.MIN_VALUE);
+ Arrays.fill(q, Byte.MIN_VALUE);
+ assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
+ }
+ }
+
+ static int scalarIpByteBin(byte[] q, byte[] d) {
+ int res = 0;
+ for (int i = 0; i < B_QUERY; i++) {
+ res += (popcount(q, i * d.length, d, d.length) << i);
+ }
+ return res;
+ }
+
+ public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
+ int res = 0;
+ for (int j = 0; j < length; j++) {
+ int value = (a[aOffset + j] & b[j]) & 0xFF;
+ for (int k = 0; k < Byte.SIZE; k++) {
+ if ((value & (1 << k)) != 0) {
+ ++res;
+ }
+ }
+ }
+ return res;
+ }
+
public void testFindNextGEQ() {
int padding = TestUtil.nextInt(random(), 0, 5);
int[] values = new int[128 + padding];
diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java
new file mode 100644
index 000000000000..8dda4146a9a8
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBQSpaceUtils.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.quantization;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestBQSpaceUtils extends LuceneTestCase {
+
+ private static float DELTA = Float.MIN_VALUE;
+
+ public void testPadFloat() {
+ assertArrayEquals(
+ new float[] {1, 2, 3, 4}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 4), DELTA);
+ assertArrayEquals(
+ new float[] {1, 2, 3, 4}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 3), DELTA);
+ assertArrayEquals(
+ new float[] {1, 2, 3, 4, 0}, BQSpaceUtils.pad(new float[] {1, 2, 3, 4}, 5), DELTA);
+ }
+
+ public void testPadByte() {
+ assertArrayEquals(new byte[] {1, 2, 3, 4}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 4));
+ assertArrayEquals(new byte[] {1, 2, 3, 4}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 3));
+ assertArrayEquals(new byte[] {1, 2, 3, 4, 0}, BQSpaceUtils.pad(new byte[] {1, 2, 3, 4}, 5));
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java
new file mode 100644
index 000000000000..5a16e60b5f4f
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestBinaryQuantization.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.quantization;
+
+import java.util.Random;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+
+public class TestBinaryQuantization extends LuceneTestCase {
+
+ public void testQuantizeForIndex() {
+ int dimensions = random().nextInt(1, 4097);
+ int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64);
+
+ int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
+ VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
+
+ BinaryQuantizer quantizer = new BinaryQuantizer(discretizedDimensions, similarityFunction);
+
+ float[] centroid = new float[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ centroid[i] = random().nextFloat(-50f, 50f);
+ }
+
+ float[] vector = new float[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ vector[i] = random().nextFloat(-50f, 50f);
+ }
+ if (similarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(vector);
+ VectorUtil.l2normalize(centroid);
+ }
+
+ byte[] destination = new byte[discretizedDimensions / 8];
+ float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid);
+
+ for (float correction : corrections) {
+ assertFalse(Float.isNaN(correction));
+ }
+
+ if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+ assertEquals(3, corrections.length);
+ assertTrue(corrections[0] >= 0);
+ assertTrue(corrections[1] > 0);
+ } else {
+ assertEquals(2, corrections.length);
+ assertTrue(corrections[0] > 0);
+ assertTrue(corrections[1] > 0);
+ }
+ }
+
  /**
   * quantizeForQuery must produce sane correction factors for a random dimension count and a
   * randomly chosen similarity function: no NaNs in the checked factors, and non-negative
   * widths/norms/sums.
   */
  public void testQuantizeForQuery() {
    int dimensions = random().nextInt(1, 4097);
    // Quantization operates on dimension counts rounded up to a multiple of 64.
    int discretizedDimensions = BQSpaceUtils.discretize(dimensions, 64);

    int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
    VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];

    BinaryQuantizer quantizer = new BinaryQuantizer(discretizedDimensions, similarityFunction);

    float[] centroid = new float[dimensions];
    for (int i = 0; i < dimensions; i++) {
      centroid[i] = random().nextFloat(-50f, 50f);
    }

    float[] vector = new float[dimensions];
    for (int i = 0; i < dimensions; i++) {
      vector[i] = random().nextFloat(-50f, 50f);
    }
    // Cosine expects unit-length inputs.
    if (similarityFunction == VectorSimilarityFunction.COSINE) {
      VectorUtil.l2normalize(vector);
      VectorUtil.l2normalize(centroid);
    }
    float cDotC = VectorUtil.dotProduct(centroid, centroid);
    // Queries are quantized to B_QUERY bits per dimension, hence the larger destination.
    byte[] destination = new byte[discretizedDimensions / 8 * BQSpaceUtils.B_QUERY];
    float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid);

    if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
      // Dot-product-like similarities: {lower, width, |v - c|, v.c, sum of quantized query}.
      float lower = corrections[0];
      float width = corrections[1];
      float normVmC = corrections[2];
      float vDotC = corrections[3];
      float sumQ = corrections[4];
      assertTrue(sumQ >= 0);
      assertFalse(Float.isNaN(lower));
      assertTrue(width >= 0);
      assertTrue(normVmC >= 0);
      assertFalse(Float.isNaN(vDotC));
      assertTrue(cDotC >= 0);
    } else {
      // Euclidean: {distance to centroid, lower, width, sum of quantized query}.
      float distToC = corrections[0];
      float lower = corrections[1];
      float width = corrections[2];
      float sumQ = corrections[3];
      assertTrue(sumQ >= 0);
      assertTrue(distToC >= 0);
      assertFalse(Float.isNaN(lower));
      assertTrue(width >= 0);
    }
  }
+
  /**
   * Fixed-input regression test for quantizeForIndex with EUCLIDEAN similarity: pins the exact
   * distance-to-centroid and magnitude corrections plus the quantized bit pattern, so any change
   * to the quantization math is caught.
   */
  public void testQuantizeForIndexEuclidean() {
    int dimensions = 128;

    BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.EUCLIDEAN);
    float[] vector =
        new float[] {
          0f, 0.0f, 16.0f, 35.0f, 5.0f, 32.0f, 31.0f, 14.0f, 10.0f, 11.0f, 78.0f, 55.0f, 10.0f,
          45.0f, 83.0f, 11.0f, 6.0f, 14.0f, 57.0f, 102.0f, 75.0f, 20.0f, 8.0f, 3.0f, 5.0f, 67.0f,
          17.0f, 19.0f, 26.0f, 5.0f, 0.0f, 1.0f, 22.0f, 60.0f, 26.0f, 7.0f, 1.0f, 18.0f, 22.0f,
          84.0f, 53.0f, 85.0f, 119.0f, 119.0f, 4.0f, 24.0f, 18.0f, 7.0f, 7.0f, 1.0f, 81.0f, 106.0f,
          102.0f, 72.0f, 30.0f, 6.0f, 0.0f, 9.0f, 1.0f, 9.0f, 119.0f, 72.0f, 1.0f, 4.0f, 33.0f,
          119.0f, 29.0f, 6.0f, 1.0f, 0.0f, 1.0f, 14.0f, 52.0f, 119.0f, 30.0f, 3.0f, 0.0f, 0.0f,
          55.0f, 92.0f, 111.0f, 2.0f, 5.0f, 4.0f, 9.0f, 22.0f, 89.0f, 96.0f, 14.0f, 1.0f, 0.0f,
          1.0f, 82.0f, 59.0f, 16.0f, 20.0f, 5.0f, 25.0f, 14.0f, 11.0f, 4.0f, 0.0f, 0.0f, 1.0f,
          26.0f, 47.0f, 23.0f, 4.0f, 0.0f, 0.0f, 4.0f, 38.0f, 83.0f, 30.0f, 14.0f, 9.0f, 4.0f, 9.0f,
          17.0f, 23.0f, 41.0f, 0.0f, 0.0f, 2.0f, 8.0f, 19.0f, 25.0f, 23.0f
        };
    byte[] destination = new byte[dimensions / 8];
    float[] centroid =
        new float[] {
          27.054054f, 22.252253f, 25.027027f, 23.55856f, 31.099098f, 28.765766f, 31.64865f,
          30.981981f, 24.675676f, 21.81982f, 26.72973f, 25.486486f, 30.504505f, 35.216217f,
          28.306307f, 24.486486f, 29.675676f, 26.153152f, 31.315315f, 25.225225f, 29.234234f,
          30.855856f, 24.495495f, 29.828829f, 31.54955f, 24.36937f, 25.108109f, 24.873875f,
          22.918919f, 24.918919f, 29.027027f, 25.513514f, 27.64865f, 28.405405f, 23.603603f,
          17.900902f, 22.522522f, 24.855856f, 31.396397f, 32.585587f, 26.297297f, 27.468468f,
          19.675676f, 19.018019f, 24.801802f, 30.27928f, 27.945946f, 25.324324f, 29.918919f,
          27.864864f, 28.081081f, 23.45946f, 28.828829f, 28.387388f, 25.387388f, 27.90991f,
          25.621622f, 21.585585f, 26.378378f, 24.144144f, 21.666666f, 22.72973f, 26.837837f,
          22.747747f, 29.0f, 28.414415f, 24.612612f, 21.594595f, 19.117117f, 24.045046f,
          30.612612f, 27.55856f, 25.117117f, 27.783783f, 21.639639f, 19.36937f, 21.252253f,
          29.153152f, 29.216217f, 24.747747f, 28.252253f, 25.288288f, 25.738739f, 23.44144f,
          24.423424f, 23.693693f, 26.306307f, 29.162163f, 28.684685f, 34.648647f, 25.576576f,
          25.288288f, 29.63063f, 20.225225f, 25.72973f, 29.009008f, 28.666666f, 29.243244f,
          26.36937f, 25.864864f, 21.522522f, 21.414415f, 25.963964f, 26.054054f, 25.099098f,
          30.477478f, 29.55856f, 24.837837f, 24.801802f, 21.18018f, 24.027027f, 26.360361f,
          33.153152f, 29.135136f, 30.486486f, 28.639639f, 27.576576f, 24.486486f, 26.297297f,
          21.774775f, 25.936937f, 35.36937f, 25.171171f, 30.405405f, 31.522522f, 29.765766f,
          22.324324f, 26.09009f
        };
    float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid);

    // EUCLIDEAN produces exactly two corrections: {distance to centroid, magnitude}.
    assertEquals(2, corrections.length);
    float distToCentroid = corrections[0];
    float magnitude = corrections[1];

    assertEquals(387.90204f, distToCentroid, 0.0003f);
    assertEquals(0.75916624f, magnitude, 0.0000001f);
    assertArrayEquals(
        new byte[] {20, 54, 56, 72, 97, -16, 62, 12, -32, -29, -125, 12, 0, -63, -63, -126},
        destination);
  }
+
  /**
   * Fixed-input regression test for quantizeForQuery with EUCLIDEAN similarity: pins lower bound,
   * width, quantized sum, and the packed query bits. The corrections layout here is
   * {distToCentroid, lower, width, sumQ}; distToCentroid is not asserted by this test.
   */
  public void testQuantizeForQueryEuclidean() {
    int dimensions = 128;

    BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.EUCLIDEAN);
    float[] vector =
        new float[] {
          0.0f, 8.0f, 69.0f, 45.0f, 2.0f, 0f, 16.0f, 52.0f, 32.0f, 13.0f, 2.0f, 6.0f, 34.0f, 49.0f,
          45.0f, 83.0f, 6.0f, 2.0f, 26.0f, 57.0f, 14.0f, 46.0f, 19.0f, 9.0f, 4.0f, 13.0f, 53.0f,
          104.0f, 33.0f, 11.0f, 25.0f, 19.0f, 30.0f, 10.0f, 7.0f, 2.0f, 8.0f, 7.0f, 25.0f, 1.0f,
          2.0f, 25.0f, 24.0f, 28.0f, 61.0f, 83.0f, 41.0f, 9.0f, 14.0f, 3.0f, 7.0f, 114.0f, 114.0f,
          114.0f, 114.0f, 5.0f, 5.0f, 1.0f, 5.0f, 114.0f, 73.0f, 75.0f, 106.0f, 3.0f, 5.0f, 6.0f,
          6.0f, 8.0f, 15.0f, 45.0f, 2.0f, 15.0f, 7.0f, 114.0f, 103.0f, 6.0f, 5.0f, 4.0f, 9.0f,
          67.0f, 47.0f, 22.0f, 32.0f, 27.0f, 41.0f, 10.0f, 114.0f, 36.0f, 43.0f, 42.0f, 23.0f, 9.0f,
          7.0f, 30.0f, 114.0f, 19.0f, 7.0f, 5.0f, 6.0f, 6.0f, 21.0f, 48.0f, 2.0f, 1.0f, 0.0f, 8.0f,
          114.0f, 13.0f, 0.0f, 1.0f, 53.0f, 83.0f, 14.0f, 8.0f, 16.0f, 12.0f, 16.0f, 20.0f, 27.0f,
          87.0f, 45.0f, 50.0f, 15.0f, 5.0f, 5.0f, 6.0f, 32.0f, 49.0f
        };
    // Queries use B_QUERY bits per dimension.
    byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY];
    float[] centroid =
        new float[] {
          26.7f, 16.2f, 10.913f, 10.314f, 12.12f, 14.045f, 15.887f, 16.864f, 32.232f, 31.567f,
          34.922f, 21.624f, 16.349f, 29.625f, 31.994f, 22.044f, 37.847f, 24.622f, 36.299f, 27.966f,
          14.368f, 19.248f, 30.778f, 35.927f, 27.019f, 16.381f, 17.325f, 16.517f, 13.272f, 9.154f,
          9.242f, 17.995f, 53.777f, 23.011f, 12.929f, 16.128f, 22.16f, 28.643f, 25.861f, 27.197f,
          59.883f, 40.878f, 34.153f, 22.795f, 24.402f, 37.427f, 34.19f, 29.288f, 61.812f, 26.355f,
          39.071f, 37.789f, 23.33f, 22.299f, 28.64f, 47.828f, 52.457f, 21.442f, 24.039f, 29.781f,
          27.707f, 19.484f, 14.642f, 28.757f, 54.567f, 20.936f, 25.112f, 25.521f, 22.077f, 18.272f,
          14.526f, 29.054f, 61.803f, 24.509f, 37.517f, 35.906f, 24.106f, 22.64f, 32.1f, 48.788f,
          60.102f, 39.625f, 34.766f, 22.497f, 24.397f, 41.599f, 38.419f, 30.99f, 55.647f, 25.115f,
          14.96f, 18.882f, 26.918f, 32.442f, 26.231f, 27.107f, 26.828f, 15.968f, 18.668f, 14.071f,
          10.906f, 8.989f, 9.721f, 17.294f, 36.32f, 21.854f, 35.509f, 27.106f, 14.067f, 19.82f,
          33.582f, 35.997f, 33.528f, 30.369f, 36.955f, 21.23f, 15.2f, 30.252f, 34.56f, 22.295f,
          29.413f, 16.576f, 11.226f, 10.754f, 12.936f, 15.525f, 15.868f, 16.43f
        };
    float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid);

    float lower = corrections[1];
    float width = corrections[2];
    float sumQ = corrections[3];

    assertEquals(729f, sumQ, 0);
    assertEquals(-57.883f, lower, 0.001f);
    assertEquals(9.972266f, width, 0.000001f);
    assertArrayEquals(
        new byte[] {
          -77, -49, 73, -17, -89, 9, -43, -27, 40, 15, 42, 76, -122, 38, -22, -37, -96, 111, -63,
          -102, -123, 23, 110, 127, 32, 95, 29, 106, -120, -121, -32, -94, 78, -98, 42, 95, 122,
          114, 30, 18, 91, 97, -5, -9, 123, 122, 31, -66, 49, 1, 20, 48, 0, 12, 30, 30, 4, 96, 2, 2,
          4, 33, 1, 65
        },
        destination);
  }
+
+ private float[] generateRandomFloatArray(
+ Random random, int dimensions, float lowerBoundInclusive, float upperBoundExclusive) {
+ float[] data = new float[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ data[i] = random.nextFloat(lowerBoundInclusive, upperBoundExclusive);
+ }
+ return data;
+ }
+
+ public void testQuantizeForIndexMIP() {
+ int dimensions = 768;
+
+ // we want fixed values for these arrays so define our own random generation here to track
+ // quantization changes
+ Random random = new Random(42);
+
+ float[] mipVectorToIndex = generateRandomFloatArray(random, dimensions, -1f, 1f);
+ float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f);
+
+ VectorSimilarityFunction[] similarityFunctionsActingLikeEucllidean =
+ new VectorSimilarityFunction[] {
+ VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT
+ };
+ int randIdx = random().nextInt(similarityFunctionsActingLikeEucllidean.length);
+ VectorSimilarityFunction similarityFunction = similarityFunctionsActingLikeEucllidean[randIdx];
+
+ BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, similarityFunction);
+ float[] vector = mipVectorToIndex;
+ byte[] destination = new byte[dimensions / 8];
+ float[] centroid = mipCentroid;
+ float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid);
+
+ assertEquals(3, corrections.length);
+ float ooq = corrections[0];
+ float normOC = corrections[1];
+ float oDotC = corrections[2];
+
+ assertEquals(0.8141399f, ooq, 0.0000001f);
+ assertEquals(21.847124f, normOC, 0.00001f);
+ assertEquals(6.4300356f, oDotC, 0.0001f);
+ assertArrayEquals(
+ new byte[] {
+ -83, -91, -71, 97, 32, -96, 89, -80, -19, -108, 3, 113, -111, 12, -86, 32, -43, 76, 122,
+ -106, -83, -37, -122, 118, 84, -72, 34, 20, 57, -29, 119, -8, -10, -100, -109, 62, -54,
+ 53, -44, 8, -16, 80, 58, 50, 105, -25, 47, 115, -106, -92, -122, -44, 8, 18, -23, 24, -15,
+ 62, 58, 111, 99, -116, -111, -5, 101, -69, -32, -74, -105, 113, -89, 44, 100, -93, -80,
+ 82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 108, 72, 2, 112, -63, -43, 105, -42, 9,
+ -128
+ },
+ destination);
+ }
+
+ public void testQuantizeForQueryMIP() {
+ int dimensions = 768;
+
+ // we want fixed values for these arrays so define our own random generation here to track
+ // quantization changes
+ Random random = new Random(42);
+
+ float[] mipVectorToQuery = generateRandomFloatArray(random, dimensions, -1f, 1f);
+ float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f);
+
+ VectorSimilarityFunction[] similarityFunctionsActingLikeEucllidean =
+ new VectorSimilarityFunction[] {
+ VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT
+ };
+ int randIdx = random().nextInt(similarityFunctionsActingLikeEucllidean.length);
+ VectorSimilarityFunction similarityFunction = similarityFunctionsActingLikeEucllidean[randIdx];
+
+ BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, similarityFunction);
+ float[] vector = mipVectorToQuery;
+ byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY];
+ float[] centroid = mipCentroid;
+ float cDotC = VectorUtil.dotProduct(centroid, centroid);
+ float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid);
+
+ float lower = corrections[0];
+ float width = corrections[1];
+ float normVmC = corrections[2];
+ float vDotC = corrections[3];
+ float sumQ = corrections[4];
+
+ assertEquals(5272, sumQ, 0);
+ assertEquals(-0.08603752f, lower, 0.00000001f);
+ assertEquals(0.011431276f, width, 0.00000001f);
+ assertEquals(21.847124f, normVmC, 0.00001f);
+ assertEquals(6.4300356f, vDotC, 0.0001f);
+ assertEquals(252.37146f, cDotC, 0.0001f);
+ assertArrayEquals(
+ new byte[] {
+ -81, 19, 67, 33, 112, 8, 40, -5, -19, 115, -87, -63, -59, 12, -2, -127, -23, 43, 24, 16,
+ -69, 112, -22, 75, -81, -50, 100, -41, 3, -120, -93, -4, 4, 125, 34, -57, -109, 89, -63,
+ -35, -116, 4, 35, 93, -26, -88, -55, -86, 63, -46, -122, -96, -26, 124, -64, 21, 96, 46,
+ 98, 97, 88, -98, -83, 121, 16, -14, -89, -118, 65, -39, -111, -35, 113, 108, 111, 86, 17,
+ -69, -47, 72, 1, 36, 17, 113, -87, -5, -46, -37, -2, 93, -123, 118, 4, -12, -33, 95, 32,
+ -63, -97, -109, 27, 111, 42, -57, -87, -41, -73, -106, 27, -31, 32, -1, 9, -88, -35, -11,
+ -103, 5, 27, -127, 108, 127, -119, 58, 38, 18, -103, -27, -63, 56, 77, -13, 3, -40, -127,
+ 37, 82, -87, -26, -45, -14, 18, -50, 76, 25, 37, -12, 106, 17, 115, 0, 23, -109, 26, -110,
+ 17, -35, 111, 4, 60, 58, -64, -104, -125, 23, -58, 89, -117, 104, -71, 3, -89, -26, 46,
+ 15, 82, -83, -75, -72, -69, 20, -38, -47, 109, -66, -66, -89, 108, -122, -3, -69, -85, 18,
+ 59, 85, -97, -114, 95, 2, -84, -77, 121, -6, 10, 110, -13, -123, -34, 106, -71, -107, 123,
+ 67, -111, 58, 52, -53, 87, -113, -21, -44, 26, 10, -62, 56, 111, 36, -126, 26, 94, -88,
+ -13, -113, -50, -9, -115, 84, 8, -32, -102, -4, 89, 29, 75, -73, -19, 22, -90, 76, -61, 4,
+ -48, -100, -11, 107, 20, -39, -98, 123, 77, 104, 9, 9, 91, -105, -40, -106, -87, 38, 48,
+ 60, 29, -68, 124, -78, -63, -101, -115, 67, -17, 101, -53, 121, 44, -78, -12, 110, 91,
+ -83, -92, -72, 96, 32, -96, 89, 48, 76, -124, 3, 113, -111, 12, -86, 32, -43, 68, 106,
+ -122, -84, -37, -124, 118, 84, -72, 34, 20, 57, -29, 119, 56, -10, -108, -109, 60, -56,
+ 37, 84, 8, -16, 80, 24, 50, 41, -25, 47, 115, -122, -92, -126, -44, 8, 18, -23, 24, -15,
+ 60, 58, 111, 99, -120, -111, -21, 101, 59, -32, -74, -105, 113, -90, 36, 100, -93, -80,
+ 82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 44, 0, 2, 112, -64, -47, 105, 2, 1, -128
+ },
+ destination);
+ }
+
  /**
   * Fixed-seed regression test for quantizeForIndex with COSINE similarity. Both inputs are
   * l2-normalized first (cosine operates on unit vectors); the exact correction factors and
   * quantized bits are pinned to catch quantization changes.
   */
  public void testQuantizeForIndexCosine() {
    int dimensions = 768;

    // we want fixed values for these arrays so define our own random generation here to track
    // quantization changes
    Random random = new Random(42);

    float[] mipVectorToIndex = generateRandomFloatArray(random, dimensions, -1f, 1f);
    float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f);

    mipVectorToIndex = VectorUtil.l2normalize(mipVectorToIndex);
    mipCentroid = VectorUtil.l2normalize(mipCentroid);

    BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.COSINE);
    float[] vector = mipVectorToIndex;
    byte[] destination = new byte[dimensions / 8];
    float[] centroid = mipCentroid;
    float[] corrections = quantizer.quantizeForIndex(vector, destination, centroid);

    // Cosine yields three corrections: {ooq, |o - c|, o.c}.
    assertEquals(3, corrections.length);
    float ooq = corrections[0];
    float normOC = corrections[1];
    float oDotC = corrections[2];

    assertEquals(0.8145253f, ooq, 0.000001f);
    assertEquals(1.3955297f, normOC, 0.00001f);
    assertEquals(0.026248248f, oDotC, 0.0001f);
    assertArrayEquals(
        new byte[] {
          -83, -91, -71, 97, 32, -96, 89, -80, -20, -108, 3, 113, -111, 12, -86, 32, -43, 76, 122,
          -106, -83, -37, -122, 118, 84, -72, 34, 20, 57, -29, 119, -72, -10, -100, -109, 62, -54,
          117, -44, 8, -16, 80, 58, 50, 41, -25, 47, 115, -106, -92, -122, -44, 8, 18, -23, 24, -15,
          62, 58, 111, 99, -116, -111, -21, 101, -69, -32, -74, -105, 113, -90, 44, 100, -93, -80,
          82, -64, 91, -87, -95, 115, 6, 76, 110, 101, 39, 44, 72, 2, 112, -63, -43, 105, -42, 9,
          -126
        },
        destination);
  }
+
  /**
   * Fixed-seed regression test for quantizeForQuery with COSINE similarity. Both inputs are
   * l2-normalized first, so the centroid self-dot-product must be 1; the exact correction factors
   * and packed query bits are pinned to catch quantization changes.
   */
  public void testQuantizeForQueryCosine() {
    int dimensions = 768;

    // we want fixed values for these arrays so define our own random generation here to track
    // quantization changes
    Random random = new Random(42);

    float[] mipVectorToQuery = generateRandomFloatArray(random, dimensions, -1f, 1f);
    float[] mipCentroid = generateRandomFloatArray(random, dimensions, -1f, 1f);

    mipVectorToQuery = VectorUtil.l2normalize(mipVectorToQuery);
    mipCentroid = VectorUtil.l2normalize(mipCentroid);

    BinaryQuantizer quantizer = new BinaryQuantizer(dimensions, VectorSimilarityFunction.COSINE);
    float[] vector = mipVectorToQuery;
    // Queries use B_QUERY bits per dimension.
    byte[] destination = new byte[dimensions / 8 * BQSpaceUtils.B_QUERY];
    float[] centroid = mipCentroid;
    float cDotC = VectorUtil.dotProduct(centroid, centroid);
    float[] corrections = quantizer.quantizeForQuery(vector, destination, centroid);

    // Corrections layout for cosine: {lower, width, |v - c|, v.c, sumQ}.
    float lower = corrections[0];
    float width = corrections[1];
    float normVmC = corrections[2];
    float vDotC = corrections[3];
    float sumQ = corrections[4];

    assertEquals(5277, sumQ, 0);
    assertEquals(-0.086002514f, lower, 0.00000001f);
    assertEquals(0.011431345f, width, 0.00000001f);
    assertEquals(1.3955297f, normVmC, 0.00001f);
    assertEquals(0.026248248f, vDotC, 0.0001f);
    assertEquals(1.0f, cDotC, 0.0001f);
    assertArrayEquals(
        new byte[] {
          -83, 18, 67, 37, 80, 8, 40, -1, -19, 115, -87, -63, -59, 12, -2, -63, -19, 43, -104, 16,
          -69, 80, -22, 75, -81, -50, 100, -41, 7, -88, -93, -4, 4, 117, 34, -57, -109, 89, -63,
          -35, -116, 4, 35, 93, -26, -88, -56, -82, 63, -46, -122, -96, -26, 124, -64, 21, 96, 46,
          114, 101, 92, -98, -83, 121, 48, -14, -89, -118, 65, -47, -79, -35, 113, 110, 111, 70, 17,
          -69, -47, 64, 1, 102, 19, 113, -87, -5, -46, -34, -2, 93, -123, 102, 4, -12, 127, 95, 32,
          -64, -97, -105, 59, 111, 42, -57, -87, -41, -73, -106, 27, -31, 32, -65, 9, -88, 93, -11,
          -103, 37, 27, -127, 108, 127, -119, 58, 38, 18, -103, -27, -63, 48, 77, -13, 3, -40, -127,
          37, 82, -87, -26, -45, -14, 18, -49, 76, 25, 37, -12, 106, 17, 115, 0, 23, -109, 26, -126,
          21, -35, 111, 4, 60, 58, -64, -104, -125, 23, -58, 121, -117, 104, -69, 3, -89, -26, 46,
          15, 90, -83, -73, -72, -69, 20, -38, -47, 109, -66, -66, -89, 108, -122, -3, 59, -85, 18,
          58, 85, -101, -114, 95, 2, -84, -77, 121, -6, 10, 110, -13, -123, -34, 106, -71, -107,
          123, 67, -111, 58, 52, -53, 87, -113, -21, -44, 26, 10, -62, 56, 103, 36, -126, 26, 94,
          -88, -13, -113, -50, -9, -115, 84, 8, -32, -102, -4, 89, 29, 75, -73, -19, 22, -90, 76,
          -61, 4, -44, -100, -11, 107, 20, -39, -98, 123, 77, 104, 9, 41, 91, -105, -38, -106, -87,
          38, 48, 60, 29, -68, 126, -78, -63, -101, -115, 67, -17, 101, -53, 121, 44, -78, -12, -18,
          91, -83, -91, -72, 96, 32, -96, 89, 48, 76, -124, 3, 113, -111, 12, -86, 32, -43, 68, 106,
          -122, -84, -37, -124, 118, 84, -72, 34, 20, 57, -29, 119, 56, -10, -100, -109, 60, -56,
          37, 84, 8, -16, 80, 24, 50, 41, -25, 47, 115, -122, -92, -126, -44, 8, 18, -23, 24, -15,
          60, 58, 107, 99, -120, -111, -21, 101, 59, -32, -74, -105, 113, -122, 36, 100, -95, -80,
          82, -64, 91, -87, -95, 115, 4, 76, 110, 101, 39, 44, 0, 2, 112, -64, -47, 105, 2, 1, -128
        },
        destination);
  }
+}