diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04e30b17a..5782a3871 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -35,6 +35,7 @@ import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; +import org.opensearch.knn.index.store.partial_loading.DistanceMaxHeap; import org.opensearch.knn.index.store.partial_loading.FaissHNSW; import org.opensearch.knn.index.store.partial_loading.FlatL2DistanceComputer; import org.opensearch.knn.index.store.partial_loading.KdyHNSW; @@ -377,13 +378,14 @@ private KNNQueryResult[] kdySearch(KdyHNSW kdyHNSW, IndexInput indexInput, float FlatL2DistanceComputer l2Computer = new FlatL2DistanceComputer(queryVector, kdyHNSW.indexFlatL2.codes, kdyHNSW.indexFlatL2.oneVectorByteSize); - PriorityQueue resultsMaxHeap = + DistanceMaxHeap resultsMaxHeap = kdyHNSW.hnswFlatIndex.hnsw.hnswSearch(indexInput, searchParametersHNSW, l2Computer); KNNQueryResult[] results = new KNNQueryResult[resultsMaxHeap.size()]; - int i = 0; - while (!resultsMaxHeap.isEmpty()) { - FaissHNSW.IdAndDistance element = resultsMaxHeap.poll(); - results[i++] = new KNNQueryResult(element.id, element.distance); + int i = resultsMaxHeap.size() - 1; + while (i >= 0) { + FaissHNSW.IdAndDistance element = resultsMaxHeap.top(); + results[i--] = new KNNQueryResult(element.id, element.distance); + resultsMaxHeap.pop(); } return results; } diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java new file mode 100644 index 000000000..55de2a249 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.store.partial_loading; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +public class DistanceMaxHeap implements Iterable { + private int k = 0; // Index of the last occupied heap slot; the next leaf goes at k + 1. + private int numValidElems = 0; + private final int maxSize; + private final FaissHNSW.IdAndDistance[] heap; + private int[] invalidIndices; + private int invalidIndicesUpto; + + public DistanceMaxHeap(int maxSize) { + final int heapSize; + + if (maxSize == 0) { + // We allocate 1 extra to avoid if statement in top() + heapSize = 2; + } else { + // NOTE: we add +1 because all access to heap is + // 1-based not 0-based. heap[0] is unused. + heapSize = maxSize + 1; + } + + // The element type is concrete, so no unchecked generic-array cast is needed. + this.heap = new FaissHNSW.IdAndDistance[heapSize]; + this.maxSize = maxSize; + + for (int i = 1; i < heapSize; i++) { + heap[i] = new FaissHNSW.IdAndDistance(0, Float.MAX_VALUE); + } + + invalidIndices = new int[2]; + invalidIndicesUpto = 0; + } + + private FaissHNSW.IdAndDistance add(int id, float distance) { + // Don't advance k until we know the heap access didn't throw ArrayIndexOutOfBoundsException.
+ int index = k + 1; + heap[index].id = id; + heap[index].distance = distance; + k = index; + upHeap(index); + return heap[1]; + } + + private int findLastValidIndex() { + float minDistance = Float.MAX_VALUE; + int minIdx = -1; + for (int i = k; i > 0; --i) { + if (heap[i].id != -1 && heap[i].distance < minDistance) { + minIdx = i; + minDistance = heap[i].distance; + } + } + + return minIdx; + } + + public final void popMin(FaissHNSW.IdAndDistance minIad) { + final int minIdx = findLastValidIndex(); + if (invalidIndicesUpto >= invalidIndices.length) { + int[] newInvalidIndices = new int[2 * invalidIndices.length]; + System.arraycopy(invalidIndices, 0, newInvalidIndices, 0, invalidIndices.length); + invalidIndices = newInvalidIndices; + } + invalidIndices[invalidIndicesUpto++] = minIdx; + minIad.id = heap[minIdx].id; + minIad.distance = heap[minIdx].distance; + // Mark it invalid. + heap[minIdx].id = -1; + heap[minIdx].distance = Float.MIN_VALUE; + --numValidElems; + } + + public void insertWithOverflow(int id, float distance) { + if (numValidElems < maxSize) { + if (invalidIndicesUpto <= 0) { + add(id, distance); + } else { + // Find minimum invalid index. 
+ int minIdxIdx = 0; + int minIdx = Integer.MAX_VALUE; + for (int i = 0; i < invalidIndicesUpto; ++i) { + if (invalidIndices[i] < minIdx) { + minIdx = invalidIndices[i]; + minIdxIdx = i; + } + } + if (minIdxIdx != (invalidIndicesUpto - 1)) { + for (int i = minIdxIdx + 1; i < invalidIndicesUpto; ++i) { + invalidIndices[i - 1] = invalidIndices[i]; + } + } + --invalidIndicesUpto; + + // System.out.println("minIdx=" + minIdx + ", invalidIndicesUpto=" + invalidIndicesUpto); + + heap[minIdx].id = id; + heap[minIdx].distance = distance; + upHeap(minIdx); + } + ++numValidElems; + } else if (greaterThan(heap[1].distance, distance)) { + heap[1].id = id; + heap[1].distance = distance; + updateTop(); + } + } + + private static boolean greaterThan(float a, float b) { + return a > b; + } + + public final FaissHNSW.IdAndDistance top() { + return heap[1]; + } + + public final FaissHNSW.IdAndDistance pop() { + if (k > 0) { + FaissHNSW.IdAndDistance result = heap[1]; // save first value + heap[1] = heap[k]; // move last to first + k--; + downHeap(1); // adjust heap + --numValidElems; + return result; + } else { + return null; + } + } + + private FaissHNSW.IdAndDistance updateTop() { + downHeap(1); + return heap[1]; + } + + public final int size() { + return numValidElems; + } + + public boolean isEmpty() { + return numValidElems <= 0; + } + + private boolean upHeap(int origPos) { + int i = origPos; + FaissHNSW.IdAndDistance node = heap[i]; // save bottom node + int j = i >>> 1; + while (j > 0 && greaterThan(node.distance, heap[j].distance)) { + heap[i] = heap[j]; // shift parents down + i = j; + j = j >>> 1; + } + heap[i] = node; // install saved node + return i != origPos; + } + + private void downHeap(int i) { + FaissHNSW.IdAndDistance node = heap[i]; // save top node + int j = i << 1; // find larger child + int k = (i << 1) + 1; + if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) { + j = k; + } + while (j <= this.k && greaterThan(heap[j].distance, 
node.distance)) { + heap[i] = heap[j]; // shift up child + i = j; + j = i << 1; + k = j + 1; + if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) { + j = k; + } + } + heap[i] = node; // install saved node + } + + public FaissHNSW.IdAndDistance[] getHeapArray() { + return heap; + } + + @Override public Iterator iterator() { + return new Iterator<>() { + int i = 1; + + @Override public boolean hasNext() { + return i <= k; + } + + @Override public FaissHNSW.IdAndDistance next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return heap[i++]; + } + }; + } +} diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/FaissHNSW.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/FaissHNSW.java index f30961bf9..a9f34b884 100644 --- a/src/main/java/org/opensearch/knn/index/store/partial_loading/FaissHNSW.java +++ b/src/main/java/org/opensearch/knn/index/store/partial_loading/FaissHNSW.java @@ -6,11 +6,11 @@ package org.opensearch.knn.index.store.partial_loading; import lombok.AllArgsConstructor; +import lombok.ToString; import org.apache.lucene.store.IndexInput; import java.io.IOException; import java.util.HashSet; -import java.util.PriorityQueue; import java.util.Set; public class FaissHNSW { @@ -32,12 +32,12 @@ public FaissHNSW(int M) { // ??? 
set_default_probas(M, 1.0 / log(M)); } - @AllArgsConstructor public static class IdAndDistance { + @AllArgsConstructor @ToString public static class IdAndDistance { public int id; public float distance; } - public PriorityQueue hnswSearch( + public DistanceMaxHeap hnswSearch( IndexInput indexInput, SearchParametersHNSW parametersHNSW, DistanceComputer distanceComputer ) throws IOException { IdAndDistance nearest = new IdAndDistance(entryPoint, distanceComputer.compute(indexInput, entryPoint)); @@ -48,61 +48,36 @@ public PriorityQueue hnswSearch( } final int ef = Math.max(parametersHNSW.efSearch, parametersHNSW.k); - // System.out.println(" +++++++++++++++++ hnswSearch, ef=" - // + ef + ", k=" + parametersHNSW.k + ", efSearch=" + parametersHNSW.efSearch - // + ", nearest.id=" + nearest.id - // + ", nearest.distance=" + nearest.distance); - PriorityQueue resultMaxHeap = new PriorityQueue<>((a, b) -> Float.compare(b.distance, a.distance)); - PriorityQueue candidates = new PriorityQueue<>((a, b) -> Float.compare(b.distance, a.distance)); - candidates.add(nearest); - searchFromCandidates(indexInput, distanceComputer, resultMaxHeap, candidates, parametersHNSW.k, ef, 0); + // // System.out.println(" +++++++++++++++++ hnswSearch, ef=" + // // + ef + ", k=" + parametersHNSW.k + ", efSearch=" + parametersHNSW.efSearch + // // + ", nearest.id=" + nearest.id + // // + ", nearest.distance=" + nearest.distance); + DistanceMaxHeap resultMaxHeap = new DistanceMaxHeap(parametersHNSW.k); + DistanceMaxHeap candidates = new DistanceMaxHeap(ef); + candidates.insertWithOverflow(nearest.id, nearest.distance); + searchFromCandidates(indexInput, distanceComputer, resultMaxHeap, candidates, 0); return resultMaxHeap; } - private void addToBoundedMaxHeap(IdAndDistance idAndDistance, PriorityQueue maxHeap, int maxLength) { - maxHeap.add(idAndDistance); - while (maxHeap.size() > maxLength) { - maxHeap.poll(); - } - } - - private float addToResultMaxHeap(IdAndDistance idAndDistance, 
PriorityQueue resultMaxHeap, int maxLength) { - if (resultMaxHeap.size() < maxLength) { - resultMaxHeap.add(idAndDistance); - } else { - final float threshold = resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance; - if (idAndDistance.distance < threshold) { - resultMaxHeap.add(idAndDistance); - while (resultMaxHeap.size() > maxLength) { - resultMaxHeap.poll(); - } - } - } - return resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance; - } - - private void addToMaxHeaps( - int id, float distance, PriorityQueue resultMaxHeap, int k, PriorityQueue candidates, int ef + private static float addToMaxHeaps( + int id, float distance, float threshold, DistanceMaxHeap resultMaxHeap, DistanceMaxHeap candidates ) { - final IdAndDistance idAndDistance = new IdAndDistance(id, distance); - addToResultMaxHeap(idAndDistance, resultMaxHeap, k); - addToBoundedMaxHeap(idAndDistance, candidates, ef); + if (distance < threshold) { + resultMaxHeap.insertWithOverflow(id, distance); + } + candidates.insertWithOverflow(id, distance); + return resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.top().distance; } private void searchFromCandidates( - IndexInput indexInput, - DistanceComputer distanceComputer, - PriorityQueue resultMaxHeap, - PriorityQueue candidates, - int k, - int ef, - int level + IndexInput indexInput, DistanceComputer distanceComputer, DistanceMaxHeap resultMaxHeap, DistanceMaxHeap candidates, int level ) throws IOException { final Set idSet = new HashSet<>(); - float threshold = resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance; + float threshold = resultMaxHeap.isEmpty() ? 
Float.MAX_VALUE : resultMaxHeap.top().distance; for (final IdAndDistance candidate : candidates) { if (candidate.distance < threshold) { - threshold = addToResultMaxHeap(candidate, resultMaxHeap, k); + resultMaxHeap.insertWithOverflow(candidate.id, candidate.distance); + threshold = resultMaxHeap.top().distance; } idSet.add(candidate.id); } @@ -111,9 +86,10 @@ private void searchFromCandidates( int[] savedNeighborIds = new int[4]; float[] distances = new float[4]; + final IdAndDistance currMin = new IdAndDistance(0, 0); while (!candidates.isEmpty()) { - final IdAndDistance currMin = candidates.poll(); + candidates.popMin(currMin); long begin, end; final long o = offsets.readLong(indexInput, currMin.id); @@ -147,10 +123,10 @@ private void searchFromCandidates( if (count == 4) { distanceComputer.computeBatch4(indexInput, savedNeighborIds, distances); - addToMaxHeaps(savedNeighborIds[0], distances[0], resultMaxHeap, k, candidates, ef); - addToMaxHeaps(savedNeighborIds[1], distances[1], resultMaxHeap, k, candidates, ef); - addToMaxHeaps(savedNeighborIds[2], distances[2], resultMaxHeap, k, candidates, ef); - addToMaxHeaps(savedNeighborIds[3], distances[3], resultMaxHeap, k, candidates, ef); + threshold = addToMaxHeaps(savedNeighborIds[0], distances[0], threshold, resultMaxHeap, candidates); + threshold = addToMaxHeaps(savedNeighborIds[1], distances[1], threshold, resultMaxHeap, candidates); + threshold = addToMaxHeaps(savedNeighborIds[2], distances[2], threshold, resultMaxHeap, candidates); + threshold = addToMaxHeaps(savedNeighborIds[3], distances[3], threshold, resultMaxHeap, candidates); count = 0; } // End if @@ -158,7 +134,7 @@ private void searchFromCandidates( for (int i = 0; i < count; ++i) { final float distance = distanceComputer.compute(indexInput, savedNeighborIds[i]); - addToMaxHeaps(savedNeighborIds[i], distance, resultMaxHeap, k, candidates, ef); + threshold = addToMaxHeaps(savedNeighborIds[i], distance, threshold, resultMaxHeap, candidates); } } // End 
while } diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/FlatL2DistanceComputer.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/FlatL2DistanceComputer.java index 54e287880..89959bee3 100644 --- a/src/main/java/org/opensearch/knn/index/store/partial_loading/FlatL2DistanceComputer.java +++ b/src/main/java/org/opensearch/knn/index/store/partial_loading/FlatL2DistanceComputer.java @@ -13,77 +13,36 @@ public class FlatL2DistanceComputer extends DistanceComputer { private float[] queryVector; private int dimension; private float[] floatBuffer1; - private float[] floatBuffer2; - private float[] floatBuffer3; - private float[] floatBuffer4; private Storage codes; private long oneVectorByteSize; - private byte[] bytesBuffer; public FlatL2DistanceComputer(float[] queryVector, Storage codes, long oneVectorByteSize) { this.queryVector = queryVector; this.dimension = queryVector.length; this.floatBuffer1 = new float[dimension]; - this.floatBuffer2 = new float[dimension]; - this.floatBuffer3 = new float[dimension]; - this.floatBuffer4 = new float[dimension]; - this.bytesBuffer = new byte[4 * dimension]; this.codes = codes; this.oneVectorByteSize = oneVectorByteSize; } @Override public float compute(IndexInput indexInput, long index) throws IOException { populateFloats(indexInput, index, floatBuffer1); - float result = 0; - for (int i = 0; i < dimension; ++i) { - final float delta = queryVector[i] - floatBuffer1[i]; - // System.out.println("queryVector[i]=" + queryVector[i] - // + ", floatBuffer1[i]=" + floatBuffer1[i] - // + ", delta=" + delta); - result += delta * delta; - } - return result; + return LuceneVectorUtilSupportProxy.squareDistance(queryVector, floatBuffer1); } private void populateFloats(IndexInput indexInput, long index, float[] floats) throws IOException { // System.out.println("populateFloats, index=" + index + ", oneVectorByteSize=" + oneVectorByteSize); - codes.readBytes(indexInput, index * oneVectorByteSize, 
bytesBuffer); - for (int i = 0, j = 0; i < bytesBuffer.length ; i += 4, ++j) { - final int intBits = ((255 & bytesBuffer[i])) | ((255 & bytesBuffer[i + 1]) << 8) - | ((255 & bytesBuffer[i + 2]) << 16) | ((255 & bytesBuffer[i + 3]) << 24); - // System.out.println("intBits=" + intBits); - // System.out.println("++++++++ populateFloats, " - // + "b[0]=" + bytesBuffer[0] + ", b[1]=" + bytesBuffer[1] - // + ", b[2]=" + bytesBuffer[2] + ", b[3]=" + bytesBuffer[3]); - floats[j] = Float.intBitsToFloat(intBits); - } + indexInput.seek(codes.baseOffset + index * oneVectorByteSize); + indexInput.readFloats(floats, 0, floats.length); } - @Override - public void computeBatch4(IndexInput indexInput, int[] ids, float[] distances) throws IOException { + @Override public void computeBatch4(IndexInput indexInput, int[] ids, float[] distances) throws IOException { populateFloats(indexInput, ids[0], floatBuffer1); - populateFloats(indexInput, ids[1], floatBuffer2); - populateFloats(indexInput, ids[2], floatBuffer3); - populateFloats(indexInput, ids[3], floatBuffer4); - - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - for (int i = 0; i < dimension; i++) { - float q0 = queryVector[i] - floatBuffer1[i]; - float q1 = queryVector[i] - floatBuffer2[i]; - float q2 = queryVector[i] - floatBuffer3[i]; - float q3 = queryVector[i] - floatBuffer4[i]; - d0 += q0 * q0; - d1 += q1 * q1; - d2 += q2 * q2; - d3 += q3 * q3; - } - - distances[0] = d0; - distances[1] = d1; - distances[2] = d2; - distances[3] = d3; + distances[0] = LuceneVectorUtilSupportProxy.squareDistance(queryVector, floatBuffer1); + populateFloats(indexInput, ids[1], floatBuffer1); + distances[1] = LuceneVectorUtilSupportProxy.squareDistance(queryVector, floatBuffer1); + populateFloats(indexInput, ids[2], floatBuffer1); + distances[2] = LuceneVectorUtilSupportProxy.squareDistance(queryVector, floatBuffer1); + populateFloats(indexInput, ids[3], floatBuffer1); + distances[3] = 
LuceneVectorUtilSupportProxy.squareDistance(queryVector, floatBuffer1); } } diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/LuceneVectorUtilSupportProxy.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/LuceneVectorUtilSupportProxy.java new file mode 100644 index 000000000..bd690ff99 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/store/partial_loading/LuceneVectorUtilSupportProxy.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.store.partial_loading; + +import org.apache.lucene.util.Constants; + +public class LuceneVectorUtilSupportProxy { + + private static float fma(float a, float b, float c) { + return Constants.HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c; + } + + public static float squareDistance(float[] a, float[] b) { + float res = 0.0F; + int i = 0; + float acc1; + if (a.length > 32) { + acc1 = 0.0F; + float acc2 = 0.0F; + float acc3 = 0.0F; + float acc4 = 0.0F; + + for (int upperBound = a.length & -4; i < upperBound; i += 4) { + float diff1 = a[i] - b[i]; + acc1 = fma(diff1, diff1, acc1); + float diff2 = a[i + 1] - b[i + 1]; + acc2 = fma(diff2, diff2, acc2); + float diff3 = a[i + 2] - b[i + 2]; + acc3 = fma(diff3, diff3, acc3); + float diff4 = a[i + 3] - b[i + 3]; + acc4 = fma(diff4, diff4, acc4); + } + + res += acc1 + acc2 + acc3 + acc4; + } + + while (i < a.length) { + acc1 = a[i] - b[i]; + res = fma(acc1, acc1, res); + ++i; + } + + return res; + } +}