Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.15] Ensure vector similarity correctly limits inner_hits returned for nested kNN (#111363) #111426

Merged
merged 4 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/111363.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 111363
summary: Ensure vector similarity correctly limits `inner_hits` returned for nested
kNN
area: Vector Search
type: bug
issues: [111093]
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,53 @@ setup:

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "2"}
---
"nested Knn search with required similarity appropriately filters inner_hits":
- requires:
cluster_features: "gte_v8.15.0"
reason: 'bugfix for 8.15'

- do:
search:
index: test
body:
query:
nested:
path: nested
inner_hits:
size: 3
_source: false
fields:
- nested.paragraph_id
query:
knn:
field: nested.vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
num_candidates: 3
similarity: 10.5

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "2"}
- length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

- do:
search:
index: test
body:
knn:
field: nested.vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
num_candidates: 3
k: 3
similarity: 10.5
inner_hits:
size: 3
_source: false
fields:
- nested.paragraph_id

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "2"}
- length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ static TransportVersion def(int id) {
public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0);
public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER_BACKPORT_8_15 = def(8_702_00_1);
public static final TransportVersion FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15 = def(8_702_00_2);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
QueryBuilder query = new KnnScoreDocQueryBuilder(
scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS),
source.knnSearch().get(i).getField(),
source.knnSearch().get(i).getQueryVector()
source.knnSearch().get(i).getQueryVector(),
source.knnSearch().get(i).getSimilarity()
).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
if (nestedPath != null) {
query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1711,17 +1711,21 @@ public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}

public Query createExactKnnQuery(VectorData queryVector) {
public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}
return switch (elementType) {
Query knnQuery = switch (elementType) {
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
};
if (vectorSimilarity != null) {
knnQuery = new VectorSimilarityQuery(knnQuery, vectorSimilarity, similarity.score(vectorSimilarity, elementType, dims));
}
return knnQuery;
}

private Query createExactKnnBitQuery(byte[] queryVector) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,18 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
public static final String NAME = "exact_knn";
private final String field;
private final VectorData query;
private final Float vectorSimilarity;

/**
* Creates a query builder.
*
* @param query the query vector
* @param field the field that was used for the kNN query
*/
public ExactKnnQueryBuilder(float[] query, String field) {
this(VectorData.fromFloats(query), field);
}

/**
* Creates a query builder.
*
* @param query the query vector
* @param field the field that was used for the kNN query
*/
public ExactKnnQueryBuilder(VectorData query, String field) {
public ExactKnnQueryBuilder(VectorData query, String field, Float vectorSimilarity) {
this.query = query;
this.field = field;
this.vectorSimilarity = vectorSimilarity;
}

public ExactKnnQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -62,6 +54,11 @@ public ExactKnnQueryBuilder(StreamInput in) throws IOException {
this.query = VectorData.fromFloats(in.readFloatArray());
}
this.field = in.readString();
if (in.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
this.vectorSimilarity = in.readOptionalFloat();
} else {
this.vectorSimilarity = null;
}
}

String getField() {
Expand All @@ -72,6 +69,10 @@ VectorData getQuery() {
return query;
}

Float vectorSimilarity() {
return vectorSimilarity;
}

@Override
public String getWriteableName() {
return NAME;
Expand All @@ -85,13 +86,19 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeFloatArray(query.asFloatVector());
}
out.writeString(field);
if (out.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
out.writeOptionalFloat(vectorSimilarity);
}
}

@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.field("query", query);
builder.field("field", field);
if (vectorSimilarity != null) {
builder.field("similarity", vectorSimilarity);
}
boostAndQueryNameToXContent(builder);
builder.endObject();
}
Expand All @@ -108,17 +115,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
);
}
final DenseVectorFieldMapper.DenseVectorFieldType vectorFieldType = (DenseVectorFieldMapper.DenseVectorFieldType) fieldType;
return vectorFieldType.createExactKnnQuery(query);
return vectorFieldType.createExactKnnQuery(query, vectorSimilarity);
}

@Override
protected boolean doEquals(ExactKnnQueryBuilder other) {
return field.equals(other.field) && Objects.equals(query, other.query);
return field.equals(other.field) && Objects.equals(query, other.query) && Objects.equals(vectorSimilarity, other.vectorSimilarity);
}

@Override
protected int doHashCode() {
return Objects.hash(field, Objects.hashCode(query));
return Objects.hash(field, Objects.hashCode(query), vectorSimilarity);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,19 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
private final ScoreDoc[] scoreDocs;
private final String fieldName;
private final VectorData queryVector;
private final Float vectorSimilarity;

/**
* Creates a query builder.
*
* @param scoreDocs the docs and scores this query should match. The array must be
* sorted in order of ascending doc IDs.
*/
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, float[] queryVector) {
this(scoreDocs, fieldName, VectorData.fromFloats(queryVector));
}

/**
* Creates a query builder.
*
* @param scoreDocs the docs and scores this query should match. The array must be
* sorted in order of ascending doc IDs.
*/
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector) {
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector, Float vectorSimilarity) {
this.scoreDocs = scoreDocs;
this.fieldName = fieldName;
this.queryVector = queryVector;
this.vectorSimilarity = vectorSimilarity;
}

public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -78,6 +70,11 @@ public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
this.fieldName = null;
this.queryVector = null;
}
if (in.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
this.vectorSimilarity = in.readOptionalFloat();
} else {
this.vectorSimilarity = null;
}
}

@Override
Expand All @@ -97,6 +94,10 @@ VectorData queryVector() {
return queryVector;
}

Float vectorSimilarity() {
return vectorSimilarity;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeArray(Lucene::writeScoreDoc, scoreDocs);
Expand All @@ -113,6 +114,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
}
if (out.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
out.writeOptionalFloat(vectorSimilarity);
}
}

@Override
Expand All @@ -129,6 +133,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
if (queryVector != null) {
builder.field("query", queryVector);
}
if (vectorSimilarity != null) {
builder.field("similarity", vectorSimilarity);
}
boostAndQueryNameToXContent(builder);
builder.endObject();
}
Expand All @@ -154,7 +161,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return new MatchNoneQueryBuilder("The \"" + getName() + "\" query was rewritten to a \"match_none\" query.");
}
if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && queryVector != null && fieldName != null) {
return new ExactKnnQueryBuilder(queryVector, fieldName);
return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
}
return super.doRewrite(queryRewriteContext);
}
Expand Down Expand Up @@ -193,7 +200,9 @@ protected boolean doEquals(KnnScoreDocQueryBuilder other) {
return false;
}
}
return Objects.equals(fieldName, other.fieldName) && Objects.equals(queryVector, other.queryVector);
return Objects.equals(fieldName, other.fieldName)
&& Objects.equals(queryVector, other.queryVector)
&& Objects.equals(vectorSimilarity, other.vectorSimilarity);
}

@Override
Expand All @@ -203,7 +212,7 @@ protected int doHashCode() {
int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
result = 31 * result + hashCode;
}
return Objects.hash(result, fieldName, Objects.hashCode(queryVector));
return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ public KnnVectorQueryBuilder toQueryBuilder() {
.addFilterQueries(filterQueries);
}

public Float getSimilarity() {
return similarity;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
).queryName(queryName).addFilterQueries(filterQueries);
}
if (ctx.convertToInnerHitsRewriteContext() != null) {
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName);
return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName);
}
boolean changed = false;
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.search.rank.TestRankBuilder;
import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.transport.Transport;
Expand Down Expand Up @@ -351,12 +352,14 @@ public void testRewriteShardSearchRequestWithRank() {
KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder(
new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) },
"vector",
new float[] { 0.0f }
VectorData.fromFloats(new float[] { 0.0f }),
null
);
KnnScoreDocQueryBuilder ksdqb1 = new KnnScoreDocQueryBuilder(
new ScoreDoc[] { new ScoreDoc(1, 2.0f, 1) },
"vector2",
new float[] { 0.0f }
VectorData.fromFloats(new float[] { 0.0f }),
null
);
assertEquals(
List.of(bm25, ksdqb0, ksdqb1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public void testExactKnnQuery() {
for (int i = 0; i < dims; i++) {
queryVector[i] = randomFloat();
}
Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector));
Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector), null);
assertTrue(query instanceof DenseVectorQuery.Floats);
}
{
Expand All @@ -233,7 +233,7 @@ public void testExactKnnQuery() {
for (int i = 0; i < dims; i++) {
queryVector[i] = randomByte();
}
Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector));
Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector), null);
assertTrue(query instanceof DenseVectorQuery.Bytes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.query.InnerHitsRewriteContext;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -296,6 +297,22 @@ private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery
}
}

public void testRewriteForInnerHits() throws IOException {
SearchExecutionContext context = createSearchExecutionContext();
InnerHitsRewriteContext innerHitsRewriteContext = new InnerHitsRewriteContext(context.getParserConfig(), System::currentTimeMillis);
KnnVectorQueryBuilder queryBuilder = createTestQueryBuilder();
queryBuilder.boost(randomFloat());
queryBuilder.queryName(randomAlphaOfLength(10));
QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext);
assertTrue(rewritten instanceof ExactKnnQueryBuilder);
ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten;
assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery());
assertEquals(queryBuilder.getFieldName(), exactKnnQueryBuilder.getField());
assertEquals(queryBuilder.boost(), exactKnnQueryBuilder.boost(), 0.0001f);
assertEquals(queryBuilder.queryName(), exactKnnQueryBuilder.queryName());
assertEquals(queryBuilder.getVectorSimilarity(), exactKnnQueryBuilder.vectorSimilarity());
}

public void testRewriteWithQueryVectorBuilder() throws Exception {
int dims = randomInt(1024);
float[] expectedArray = new float[dims];
Expand Down
Loading