diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java index aebed5642383..c23f56bcdc62 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; @@ -56,7 +57,7 @@ public HnswConcurrentMergeBuilder( this.taskExecutor = taskExecutor; AtomicInteger workProgress = new AtomicInteger(0); workers = new ConcurrentMergeWorker[numWorker]; - hnswLock = new HnswLock(hnsw); + hnswLock = new HnswLock(); for (int i = 0; i < numWorker; i++) { workers[i] = new ConcurrentMergeWorker( @@ -221,13 +222,16 @@ private MergeSearcher(NeighborQueue candidates, HnswLock hnswLock, BitSet visite @Override void graphSeek(HnswGraph graph, int level, int targetNode) { - try (HnswLock.LockedRow rowLock = hnswLock.read(level, targetNode)) { - NeighborArray neighborArray = rowLock.row(); + Lock lock = hnswLock.read(level, targetNode); + try { + NeighborArray neighborArray = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); if (nodeBuffer == null || nodeBuffer.length < neighborArray.size()) { nodeBuffer = new int[neighborArray.size()]; } size = neighborArray.size(); - if (size >= 0) System.arraycopy(neighborArray.nodes(), 0, nodeBuffer, 0, size); + System.arraycopy(neighborArray.nodes(), 0, nodeBuffer, 0, size); + } finally { + lock.unlock(); } upto = -1; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 1f5253ef7f85..c5a55f2d5cd2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -27,6 +27,7 @@ import java.util.Objects; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.FixedBitSet; @@ -338,9 +339,12 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates) } int nbr = candidates.nodes()[i]; if (hnswLock != null) { - try (HnswLock.LockedRow rowLock = hnswLock.write(level, nbr)) { - NeighborArray nbrsOfNbr = rowLock.row(); + Lock lock = hnswLock.write(level, nbr); + try { + NeighborArray nbrsOfNbr = getGraph().getNeighbors(level, nbr); nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier); + } finally { + lock.unlock(); } } else { NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswLock.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswLock.java index d8b12f67f4a6..15ec34c1e712 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswLock.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswLock.java @@ -17,49 +17,39 @@ package org.apache.lucene.util.hnsw; -import java.io.Closeable; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantReadWriteLock; /** - * Provide (read-and-write) locked access to rows of an OnHeapHnswGraph. For use by - * HnswConcurrentMerger and its HnswGraphBuilders. + * Provide (read-and-write) striped locks for access to nodes of an {@link OnHeapHnswGraph}. For use + * by {@link HnswConcurrentMergeBuilder} and its HnswGraphBuilders. */ final class HnswLock { private static final int NUM_LOCKS = 512; private final ReentrantReadWriteLock[] locks; - private final OnHeapHnswGraph graph; - HnswLock(OnHeapHnswGraph graph) { - this.graph = graph; + HnswLock() { locks = new ReentrantReadWriteLock[NUM_LOCKS]; for (int i = 0; i < NUM_LOCKS; i++) { locks[i] = new ReentrantReadWriteLock(); } } - LockedRow read(int level, int node) { + Lock read(int level, int node) { int lockid = hash(level, node) % NUM_LOCKS; Lock lock = locks[lockid].readLock(); lock.lock(); - return new LockedRow(graph.getNeighbors(level, node), lock); + return lock; } - LockedRow write(int level, int node) { + Lock write(int level, int node) { int lockid = hash(level, node) % NUM_LOCKS; Lock lock = locks[lockid].writeLock(); lock.lock(); - return new LockedRow(graph.getNeighbors(level, node), lock); + return lock; } - record LockedRow(NeighborArray row, Lock lock) implements Closeable { - @Override - public void close() { - lock.unlock(); - } - } - - static int hash(int v1, int v2) { + private static int hash(int v1, int v2) { return v1 * 31 + v2; } }