Skip to content

Commit

Permalink
Ensure we return non-negative scores when scoring scalar dot-products (
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent authored May 13, 2024
1 parent 364a6f2 commit e352345
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/108522.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 108522
summary: Ensure we return non-negative scores when scoring scalar dot-products
area: Vector Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ public float score(int firstOrd, int secondOrd) throws IOException {

if (firstSeg != null && secondSeg != null) {
int dotProduct = dotProduct7u(firstSeg, secondSeg, length);
assert dotProduct >= 0;
float adjustedDistance = dotProduct * scoreCorrectionConstant + firstOffset + secondOffset;
return (1 + adjustedDistance) / 2;
return Math.max((1 + adjustedDistance) / 2, 0f);
} else {
return fallbackScore(firstByteOffset, secondByteOffset);
return Math.max(fallbackScore(firstByteOffset, secondByteOffset), 0f);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.elasticsearch.vec.VectorSimilarityType.EUCLIDEAN;
import static org.elasticsearch.vec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100)
public class VectorScorerFactoryTests extends AbstractVectorTestCase {
Expand Down Expand Up @@ -96,6 +97,51 @@ void testSimpleImpl(long maxChunkSize) throws IOException {
}
}

public void testNonNegativeDotProduct() throws IOException {
assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get();

try (Directory dir = new MMapDirectory(createTempDir(getTestName()), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) {
// keep vecs `0` so dot product is `0`
byte[] vec1 = new byte[32];
byte[] vec2 = new byte[32];
String fileName = getTestName() + "-32";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var negativeOffset = floatToByteArray(-5f);
byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
out.writeBytes(bytes, 0, bytes.length);
}
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
// dot product
float expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(DOT_PRODUCT, vec1, vec2,
// 1, -5, -5);
var scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, DOT_PRODUCT, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// max inner product
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, MAXIMUM_INNER_PRODUCT, in).get();
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// cosine
expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(COSINE, vec1, vec2, 1, -5,
// -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, COSINE, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, EUCLIDEAN, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
}
}
}

public void testRandom() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
Expand Down

0 comments on commit e352345

Please sign in to comment.