Skip to content

Commit

Permalink
Fixed heap bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dooyong Kim committed Dec 20, 2024
1 parent 7be04f2 commit f85120c
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 110 deletions.
12 changes: 7 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder;
import org.opensearch.knn.index.store.partial_loading.DistanceMaxHeap;
import org.opensearch.knn.index.store.partial_loading.FaissHNSW;
import org.opensearch.knn.index.store.partial_loading.FlatL2DistanceComputer;
import org.opensearch.knn.index.store.partial_loading.KdyHNSW;
Expand Down Expand Up @@ -377,13 +378,14 @@ private KNNQueryResult[] kdySearch(KdyHNSW kdyHNSW, IndexInput indexInput, float

FlatL2DistanceComputer l2Computer =
new FlatL2DistanceComputer(queryVector, kdyHNSW.indexFlatL2.codes, kdyHNSW.indexFlatL2.oneVectorByteSize);
PriorityQueue<FaissHNSW.IdAndDistance> resultsMaxHeap =
DistanceMaxHeap resultsMaxHeap =
kdyHNSW.hnswFlatIndex.hnsw.hnswSearch(indexInput, searchParametersHNSW, l2Computer);
KNNQueryResult[] results = new KNNQueryResult[resultsMaxHeap.size()];
int i = 0;
while (!resultsMaxHeap.isEmpty()) {
FaissHNSW.IdAndDistance element = resultsMaxHeap.poll();
results[i++] = new KNNQueryResult(element.id, element.distance);
int i = resultsMaxHeap.size() - 1;
while (i >= 0) {
FaissHNSW.IdAndDistance element = resultsMaxHeap.top();
results[i--] = new KNNQueryResult(element.id, element.distance);
resultsMaxHeap.pop();
}
return results;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.store.partial_loading;

import java.util.Iterator;
import java.util.NoSuchElementException;

public class DistanceMaxHeap implements Iterable<FaissHNSW.IdAndDistance> {
private int k = 0; // Pointing the next last leaf element.
private int numValidElems = 0;
private final int maxSize;
private final FaissHNSW.IdAndDistance[] heap;
private int[] invalidIndices;
private int invalidIndicesUpto;

public DistanceMaxHeap(int maxSize) {
final int heapSize;

if (maxSize == 0) {
// We allocate 1 extra to avoid if statement in top()
heapSize = 2;
} else {
// NOTE: we add +1 because all access to heap is
// 1-based not 0-based. heap[0] is unused.
heapSize = maxSize + 1;
}

// T is an unbounded type, so this unchecked cast works always.
this.heap = new FaissHNSW.IdAndDistance[heapSize];
this.maxSize = maxSize;

for (int i = 1; i < heapSize; i++) {
heap[i] = new FaissHNSW.IdAndDistance(0, Float.MAX_VALUE);
}

invalidIndices = new int[2];
invalidIndicesUpto = 0;
}

private FaissHNSW.IdAndDistance add(int id, float distance) {
// don't modify size until we know heap access didn't throw AIOOB.
int index = k + 1;
heap[index].id = id;
heap[index].distance = distance;
k = index;
upHeap(index);
return heap[1];
}

private int findLastValidIndex() {
float minDistance = Float.MAX_VALUE;
int minIdx = -1;
for (int i = k; i > 0; --i) {
if (heap[i].id != -1 && heap[i].distance < minDistance) {
minIdx = i;
minDistance = heap[i].distance;
}
}

return minIdx;
}

public final void popMin(FaissHNSW.IdAndDistance minIad) {
final int minIdx = findLastValidIndex();
if (invalidIndicesUpto >= invalidIndices.length) {
int[] newInvalidIndices = new int[2 * invalidIndices.length];
System.arraycopy(invalidIndices, 0, newInvalidIndices, 0, invalidIndices.length);
invalidIndices = newInvalidIndices;
}
invalidIndices[invalidIndicesUpto++] = minIdx;
minIad.id = heap[minIdx].id;
minIad.distance = heap[minIdx].distance;
// Mark it invalid.
heap[minIdx].id = -1;
heap[minIdx].distance = Float.MIN_VALUE;
--numValidElems;
}

public void insertWithOverflow(int id, float distance) {
if (numValidElems < maxSize) {
if (invalidIndicesUpto <= 0) {
add(id, distance);
} else {
// Find minimum invalid index.
int minIdxIdx = 0;
int minIdx = Integer.MAX_VALUE;
for (int i = 0; i < invalidIndicesUpto; ++i) {
if (invalidIndices[i] < minIdx) {
minIdx = invalidIndices[i];
minIdxIdx = i;
}
}
if (minIdxIdx != (invalidIndicesUpto - 1)) {
for (int i = minIdxIdx + 1; i < invalidIndicesUpto; ++i) {
invalidIndices[i - 1] = invalidIndices[i];
}
}
--invalidIndicesUpto;

// System.out.println("minIdx=" + minIdx + ", invalidIndicesUpto=" + invalidIndicesUpto);

heap[minIdx].id = id;
heap[minIdx].distance = distance;
upHeap(minIdx);
}
++numValidElems;
} else if (greaterThan(heap[1].distance, distance)) {
heap[1].id = id;
heap[1].distance = distance;
updateTop();
}
}

private static boolean greaterThan(float a, float b) {
return a > b;
}

public final FaissHNSW.IdAndDistance top() {
return heap[1];
}

public final FaissHNSW.IdAndDistance pop() {
if (k > 0) {
FaissHNSW.IdAndDistance result = heap[1]; // save first value
heap[1] = heap[k]; // move last to first
k--;
downHeap(1); // adjust heap
--numValidElems;
return result;
} else {
return null;
}
}

private FaissHNSW.IdAndDistance updateTop() {
downHeap(1);
return heap[1];
}

public final int size() {
return numValidElems;
}

public boolean isEmpty() {
return numValidElems <= 0;
}

private boolean upHeap(int origPos) {
int i = origPos;
FaissHNSW.IdAndDistance node = heap[i]; // save bottom node
int j = i >>> 1;
while (j > 0 && greaterThan(node.distance, heap[j].distance)) {
heap[i] = heap[j]; // shift parents down
i = j;
j = j >>> 1;
}
heap[i] = node; // install saved node
return i != origPos;
}

private void downHeap(int i) {
FaissHNSW.IdAndDistance node = heap[i]; // save top node
int j = i << 1; // find smaller child
int k = (i << 1) + 1;
if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) {
j = k;
}
while (j <= this.k && greaterThan(heap[j].distance, node.distance)) {
heap[i] = heap[j]; // shift up child
i = j;
j = i << 1;
k = j + 1;
if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) {
j = k;
}
}
heap[i] = node; // install saved node
}

public FaissHNSW.IdAndDistance[] getHeapArray() {
return heap;
}

@Override public Iterator<FaissHNSW.IdAndDistance> iterator() {
return new Iterator<>() {
int i = 1;

@Override public boolean hasNext() {
return i <= k;
}

@Override public FaissHNSW.IdAndDistance next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
return heap[i++];
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
package org.opensearch.knn.index.store.partial_loading;

import lombok.AllArgsConstructor;
import lombok.ToString;
import org.apache.lucene.store.IndexInput;

import java.io.IOException;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;

public class FaissHNSW {
Expand All @@ -32,12 +32,12 @@ public FaissHNSW(int M) {
// ??? set_default_probas(M, 1.0 / log(M));
}

@AllArgsConstructor public static class IdAndDistance {
@AllArgsConstructor @ToString public static class IdAndDistance {
public int id;
public float distance;
}

public PriorityQueue<IdAndDistance> hnswSearch(
public DistanceMaxHeap hnswSearch(
IndexInput indexInput, SearchParametersHNSW parametersHNSW, DistanceComputer distanceComputer
) throws IOException {
IdAndDistance nearest = new IdAndDistance(entryPoint, distanceComputer.compute(indexInput, entryPoint));
Expand All @@ -48,61 +48,36 @@ public PriorityQueue<IdAndDistance> hnswSearch(
}

final int ef = Math.max(parametersHNSW.efSearch, parametersHNSW.k);
// System.out.println(" +++++++++++++++++ hnswSearch, ef="
// + ef + ", k=" + parametersHNSW.k + ", efSearch=" + parametersHNSW.efSearch
// + ", nearest.id=" + nearest.id
// + ", nearest.distance=" + nearest.distance);
PriorityQueue<IdAndDistance> resultMaxHeap = new PriorityQueue<>((a, b) -> Float.compare(b.distance, a.distance));
PriorityQueue<IdAndDistance> candidates = new PriorityQueue<>((a, b) -> Float.compare(b.distance, a.distance));
candidates.add(nearest);
searchFromCandidates(indexInput, distanceComputer, resultMaxHeap, candidates, parametersHNSW.k, ef, 0);
// // System.out.println(" +++++++++++++++++ hnswSearch, ef="
// // + ef + ", k=" + parametersHNSW.k + ", efSearch=" + parametersHNSW.efSearch
// // + ", nearest.id=" + nearest.id
// // + ", nearest.distance=" + nearest.distance);
DistanceMaxHeap resultMaxHeap = new DistanceMaxHeap(parametersHNSW.k);
DistanceMaxHeap candidates = new DistanceMaxHeap(ef);
candidates.insertWithOverflow(nearest.id, nearest.distance);
searchFromCandidates(indexInput, distanceComputer, resultMaxHeap, candidates, 0);
return resultMaxHeap;
}

private void addToBoundedMaxHeap(IdAndDistance idAndDistance, PriorityQueue<IdAndDistance> maxHeap, int maxLength) {
maxHeap.add(idAndDistance);
while (maxHeap.size() > maxLength) {
maxHeap.poll();
}
}

private float addToResultMaxHeap(IdAndDistance idAndDistance, PriorityQueue<IdAndDistance> resultMaxHeap, int maxLength) {
if (resultMaxHeap.size() < maxLength) {
resultMaxHeap.add(idAndDistance);
} else {
final float threshold = resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance;
if (idAndDistance.distance < threshold) {
resultMaxHeap.add(idAndDistance);
while (resultMaxHeap.size() > maxLength) {
resultMaxHeap.poll();
}
}
}
return resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance;
}

private void addToMaxHeaps(
int id, float distance, PriorityQueue<IdAndDistance> resultMaxHeap, int k, PriorityQueue<IdAndDistance> candidates, int ef
private static float addToMaxHeaps(
int id, float distance, float threshold, DistanceMaxHeap resultMaxHeap, DistanceMaxHeap candidates
) {
final IdAndDistance idAndDistance = new IdAndDistance(id, distance);
addToResultMaxHeap(idAndDistance, resultMaxHeap, k);
addToBoundedMaxHeap(idAndDistance, candidates, ef);
if (distance < threshold) {
resultMaxHeap.insertWithOverflow(id, distance);
}
candidates.insertWithOverflow(id, distance);
return resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.top().distance;
}

private void searchFromCandidates(
IndexInput indexInput,
DistanceComputer distanceComputer,
PriorityQueue<IdAndDistance> resultMaxHeap,
PriorityQueue<IdAndDistance> candidates,
int k,
int ef,
int level
IndexInput indexInput, DistanceComputer distanceComputer, DistanceMaxHeap resultMaxHeap, DistanceMaxHeap candidates, int level
) throws IOException {
final Set<Integer> idSet = new HashSet<>();
float threshold = resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.peek().distance;
float threshold = resultMaxHeap.isEmpty() ? Float.MAX_VALUE : resultMaxHeap.top().distance;
for (final IdAndDistance candidate : candidates) {
if (candidate.distance < threshold) {
threshold = addToResultMaxHeap(candidate, resultMaxHeap, k);
resultMaxHeap.insertWithOverflow(candidate.id, candidate.distance);
threshold = resultMaxHeap.top().distance;
}
idSet.add(candidate.id);
}
Expand All @@ -111,9 +86,10 @@ private void searchFromCandidates(

int[] savedNeighborIds = new int[4];
float[] distances = new float[4];
final IdAndDistance currMin = new IdAndDistance(0, 0);

while (!candidates.isEmpty()) {
final IdAndDistance currMin = candidates.poll();
candidates.popMin(currMin);

long begin, end;
final long o = offsets.readLong(indexInput, currMin.id);
Expand Down Expand Up @@ -147,18 +123,18 @@ private void searchFromCandidates(
if (count == 4) {
distanceComputer.computeBatch4(indexInput, savedNeighborIds, distances);

addToMaxHeaps(savedNeighborIds[0], distances[0], resultMaxHeap, k, candidates, ef);
addToMaxHeaps(savedNeighborIds[1], distances[1], resultMaxHeap, k, candidates, ef);
addToMaxHeaps(savedNeighborIds[2], distances[2], resultMaxHeap, k, candidates, ef);
addToMaxHeaps(savedNeighborIds[3], distances[3], resultMaxHeap, k, candidates, ef);
threshold = addToMaxHeaps(savedNeighborIds[0], distances[0], threshold, resultMaxHeap, candidates);
threshold = addToMaxHeaps(savedNeighborIds[1], distances[1], threshold, resultMaxHeap, candidates);
threshold = addToMaxHeaps(savedNeighborIds[2], distances[2], threshold, resultMaxHeap, candidates);
threshold = addToMaxHeaps(savedNeighborIds[3], distances[3], threshold, resultMaxHeap, candidates);

count = 0;
} // End if
} // End for

for (int i = 0; i < count; ++i) {
final float distance = distanceComputer.compute(indexInput, savedNeighborIds[i]);
addToMaxHeaps(savedNeighborIds[i], distance, resultMaxHeap, k, candidates, ef);
threshold = addToMaxHeaps(savedNeighborIds[i], distance, threshold, resultMaxHeap, candidates);
}
} // End while
}
Expand Down
Loading

0 comments on commit f85120c

Please sign in to comment.