Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance: simplify and optimize kth-greatest computation (96% recall at 195 qps) #616

Merged
merged 7 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|337.457|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.446|281.828|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|272.814|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|232.698|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|303.686|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|254.121|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|215.233|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|190.689|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.379|353.162|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|295.007|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|286.531|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|245.690|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|312.826|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.846|265.204|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.921|221.817|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|195.653|
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

/**
* Use an array of counts to count hits. The index of the array is the doc id.
* Hopefully there's a way to do this that doesn't require O(num docs in segment) time and memory,
Expand All @@ -14,29 +12,36 @@ public class ArrayHitCounter implements HitCounter {
private int minKey;
private int maxKey;

private short maxValue;

public ArrayHitCounter(int capacity) {
counts = new short[capacity];
numHits = 0;
minKey = capacity;
maxKey = 0;
maxValue = 0;
}

@Override
public void increment(int key) {
if (counts[key]++ == 0) {
short after = ++counts[key];
if (after == 1) {
numHits++;
minKey = Math.min(key, minKey);
maxKey = Math.max(key, maxKey);
}
if (after > maxValue) maxValue = after;
}

@Override
public void increment(int key, short count) {
if ((counts[key] += count) == count) {
short after = (counts[key] += count);
if (after == count) {
numHits++;
minKey = Math.min(key, minKey);
maxKey = Math.max(key, maxKey);
}
if (after > maxValue) maxValue = after;
}

@Override
Expand Down Expand Up @@ -70,8 +75,34 @@ public int maxKey() {
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return KthGreatest.kthGreatest(counts, Math.min(k, counts.length - 1));
}
/**
 * Finds the kth greatest hit count (k is zero-based: k = 0 selects the maximum count).
 *
 * Runs in time linear in the length of {@code counts} and allocates a histogram of
 * {@code maxValue + 1} buckets.
 *
 * @param k zero-based rank of the count to select.
 * @return the kth greatest count, the number of counts strictly greater than it,
 *         and the number of hits recorded by this counter.
 */
public KthGreatestResult kthGreatest(int k) {
    // This implementation exploits the fact that we're specifically counting document hit counts.
    // Counts are integers, and they're likely to be pretty small, since we're unlikely to match
    // the same document many times.

    // Start by building a histogram of all counts.
    // e.g., if the counts are [0, 4, 1, 1, 2],
    // then the histogram is [1, 2, 1, 0, 1],
    // because 0 occurs once, 1 occurs twice, 2 occurs once, 3 occurs zero times, and 4 occurs once.
    // The buckets are ints rather than shorts: a single bucket can hold up to counts.length
    // entries, which would silently wrap a short as soon as more than 32767 docs share a count.
    int[] hist = new int[maxValue + 1];
    for (short c : counts) hist[c]++;

    // Now we start at the max value and iterate backwards through the histogram,
    // accumulating counts of counts until we've exceeded k.
    int numGreaterEqual = 0;
    short kthGreatest = maxValue;
    while (kthGreatest > 0) {
        numGreaterEqual += hist[kthGreatest];
        if (numGreaterEqual > k) break;
        else kthGreatest--;
    }

    // Finally we find the number that were greater than the kth greatest count.
    // There's a special case if kthGreatest is zero: the zero bucket was never accumulated,
    // so the number that were greater is simply the number of hits.
    int numGreater = (kthGreatest == 0) ? numHits : numGreaterEqual - hist[kthGreatest];
    return new KthGreatestResult(kthGreatest, numGreater, numHits);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

public final class EmptyHitCounter implements HitCounter {

@Override
Expand Down Expand Up @@ -41,7 +39,7 @@ public int maxKey() {
}

@Override
public KthGreatest.Result kthGreatest(int k) {
return new KthGreatest.Result((short) 0, 0, 0);
public KthGreatestResult kthGreatest(int k) {
return new KthGreatestResult((short) 0, 0, 0);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.klibisz.elastiknn.search;

import org.apache.lucene.search.KthGreatest;

/**
* Abstraction for counting hits for a particular query.
*/
Expand All @@ -23,6 +21,6 @@ public interface HitCounter {

int maxKey();

KthGreatest.Result kthGreatest(int k);
KthGreatestResult kthGreatest(int k);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.klibisz.elastiknn.search;

public class KthGreatestResult {
public final short kthGreatest;
public final int numGreaterThan;
public final int numNonZero;
public KthGreatestResult(short kthGreatest, int numGreaterThan, int numNonZero) {
this.kthGreatest = kthGreatest;
this.numGreaterThan = numGreaterThan;
this.numNonZero = numNonZero;
}

@Override
public boolean equals(Object o) {
if (o == this) {
return true;
} else if (!(o instanceof KthGreatestResult other)) {
return false;
} else {
return kthGreatest == other.kthGreatest && numGreaterThan == other.numGreaterThan && numNonZero == other.numNonZero;
}
}

@Override
public String toString() {
return String.format("KthGreatestResult(kthGreatest=%d, numGreaterThan=%d, numNonZero=%d)", kthGreatest, numGreaterThan, numNonZero);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.klibisz.elastiknn.search.ArrayHitCounter;
import com.klibisz.elastiknn.search.EmptyHitCounter;
import com.klibisz.elastiknn.search.HitCounter;
import com.klibisz.elastiknn.search.KthGreatestResult;
import org.apache.lucene.index.*;
import org.apache.lucene.util.BytesRef;

Expand Down Expand Up @@ -101,7 +102,7 @@ private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
if (counter.isEmpty()) return DocIdSetIterator.empty();
else {

KthGreatest.Result kgr = counter.kthGreatest(candidates);
KthGreatestResult kgr = counter.kthGreatest(candidates);

// Return an iterator over the doc ids >= the min candidate count.
return new DocIdSetIterator() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package com.klibisz.elastiknn.search

import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.util.Random

final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {

  /** Naive, map-backed HitCounter that acts as the oracle the ArrayHitCounter is compared against. */
  final class Reference(capacity: Int) extends HitCounter {

    // Every key in [0, capacity) starts out with a zero count.
    private val counts = {
      val zeroed = scala.collection.mutable.Map.empty[Int, Short]
      (0 until capacity).foreach(key => zeroed.update(key, 0.toShort))
      zeroed
    }

    // Adds delta to the stored count, truncating back to a Short.
    private def bumpBy(key: Int, delta: Int): Unit =
      counts(key) = (counts(key) + delta).toShort

    // Keys whose count is strictly positive.
    private def hitKeys: Iterator[Int] =
      counts.iterator.collect { case (key, hits) if hits > 0 => key }

    override def increment(key: Int): Unit = bumpBy(key, 1)

    override def increment(key: Int, count: Short): Unit = bumpBy(key, count.toInt)

    override def isEmpty: Boolean = counts.values.forall(_ <= 0)

    override def get(key: Int): Short = counts(key)

    override def numHits(): Int = counts.count { case (_, hits) => hits > 0 }

    override def capacity(): Int = capacity

    override def minKey(): Int = hitKeys.min

    override def maxKey(): Int = hitKeys.max

    override def kthGreatest(k: Int): KthGreatestResult = {
      // Sort all counts from largest to smallest, then read the answer off position k.
      val descending = counts.values.toArray.sortWith(_ > _)
      val kth = descending(k)
      val strictlyGreater = descending.count(_ > kth)
      val nonZero = descending.count(_ != 0)
      new KthGreatestResult(kth, strictlyGreater, nonZero)
    }
  }

  "reference examples" - {
    "example 1" in {
      val c = new Reference(10)

      // Asserts the aggregate view of the counter in one go.
      def expectSummary(hits: Int, min: Int, max: Int): Unit = {
        c.numHits() shouldBe hits
        c.minKey() shouldBe min
        c.maxKey() shouldBe max
      }

      c.isEmpty shouldBe true
      c.capacity() shouldBe 10

      // Single increment of key 0.
      c.get(0) shouldBe 0
      c.increment(0)
      c.get(0) shouldBe 1
      expectSummary(hits = 1, min = 0, max = 0)

      // Bulk increment of key 5.
      c.get(5) shouldBe 0
      c.increment(5, 5)
      c.get(5) shouldBe 5
      expectSummary(hits = 2, min = 0, max = 5)

      // Two single increments of key 9.
      c.get(9) shouldBe 0
      c.increment(9)
      c.get(9) shouldBe 1
      c.increment(9)
      c.get(9) shouldBe 2
      expectSummary(hits = 3, min = 0, max = 9)

      val kgr = c.kthGreatest(2)
      kgr.kthGreatest shouldBe 1
      kgr.numGreaterThan shouldBe 2
      kgr.numNonZero shouldBe 3
    }
  }

  "randomized comparison to reference" in {
    val seed = System.currentTimeMillis()
    val rng = new Random(seed)
    val numDocs = 60000
    val numMatches = numDocs / 2
    // Report the seed so a failing run can be reproduced.
    info(s"Using seed $seed")
    (1 to 99).foreach { _ =>
      // All matching doc ids are drawn up front, before any increments are applied.
      val matches = Vector.fill(numMatches)(rng.nextInt(numDocs))
      val ref = new Reference(numDocs)
      val ahc = new ArrayHitCounter(numDocs)
      for (doc <- matches) {
        // Single increment on both counters, then compare.
        ref.increment(doc)
        ahc.increment(doc)
        ahc.get(doc) shouldBe ref.get(doc)
        // Bulk increment by a random amount in [0, 10), then compare again.
        val count = rng.nextInt(10).toShort
        ref.increment(doc, count)
        ahc.increment(doc, count)
        ahc.get(doc) shouldBe ref.get(doc)
      }
      ahc.minKey() shouldBe ref.minKey()
      ahc.maxKey() shouldBe ref.maxKey()
      ahc.numHits() shouldBe ref.numHits()
      val k = rng.nextInt(numDocs)
      val expected = ref.kthGreatest(k)
      val actual = ahc.kthGreatest(k)
      actual shouldBe expected
    }
  }
}
Loading