Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure we return non-negative scores when scoring scalar dot-products #108522

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ public float score(int firstOrd, int secondOrd) throws IOException {

if (firstSeg != null && secondSeg != null) {
int dotProduct = dotProduct7u(firstSeg, secondSeg, length);
assert dotProduct >= 0;
float adjustedDistance = dotProduct * scoreCorrectionConstant + firstOffset + secondOffset;
return (1 + adjustedDistance) / 2;
return Math.max((1 + adjustedDistance) / 2, 0f);
} else {
return fallbackScore(firstByteOffset, secondByteOffset);
return Math.max(fallbackScore(firstByteOffset, secondByteOffset), 0f);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.elasticsearch.vec.VectorSimilarityType.EUCLIDEAN;
import static org.elasticsearch.vec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100)
public class VectorScorerFactoryTests extends AbstractVectorTestCase {
Expand Down Expand Up @@ -92,6 +93,51 @@ void testSimpleImpl(long maxChunkSize) throws IOException {
}
}

public void testNonNegativeDotProduct() throws IOException {
assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get();

try (Directory dir = new MMapDirectory(createTempDir(getTestName()), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) {
// keep vecs `0` so dot product is `0`
byte[] vec1 = new byte[32];
byte[] vec2 = new byte[32];
String fileName = getTestName() + "-32";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var negativeOffset = floatToByteArray(-5f);
byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
out.writeBytes(bytes, 0, bytes.length);
}
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
// dot product
float expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(DOT_PRODUCT, vec1, vec2,
// 1, -5, -5);
var scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, DOT_PRODUCT, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// max inner product
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, MAXIMUM_INNER_PRODUCT, in).get();
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// cosine
expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(COSINE, vec1, vec2, 1, -5,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ha. I see this now. Thanks

// -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, COSINE, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, EUCLIDEAN, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
}
}
}

public void testRandom() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
Expand Down