Skip to content

Commit

Permalink
Don't throw on multiple inference IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 16, 2024
1 parent 7bc962b commit d28cd8b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field.
Expand Down Expand Up @@ -132,22 +133,29 @@ public Collection<String> getInferenceIndices() {
return inferenceIndicesMetadata.keySet();
}

public String getSearchInferenceIdForIndex(String index) {
return inferenceIndicesMetadata.get(index).getSearchInferenceId();
}

public String getSearchInferenceId() {
List<String> searchInferenceIds = inferenceIndicesMetadata.values()
public Map<String, List<String>> getInferenceIdsIndices() {
return inferenceIndicesMetadata.entrySet()
.stream()
.map(InferenceFieldMetadata::getSearchInferenceId)
.distinct()
.toList();
if (searchInferenceIds.size() > 1) {
throw new IllegalStateException(
"Conflicting searchInferenceIds for field [" + fieldName + "]: Found [" + searchInferenceIds + "]"
.collect(
Collectors.groupingBy(
entry -> entry.getValue().getSearchInferenceId(),
Collectors.mapping(Map.Entry::getKey, Collectors.toList())
)
);
}
return searchInferenceIds.getFirst();
}

// public String getSearchInferenceId() {
// List<String> searchInferenceIds = inferenceIndicesMetadata.values()
// .stream()
// .map(InferenceFieldMetadata::getSearchInferenceId)
// .distinct()
// .toList();
// if (searchInferenceIds.size() > 1) {
// throw new IllegalStateException(
// "Conflicting searchInferenceIds for field [" + fieldName + "]: Found [" + searchInferenceIds + "]"
// );
// }
// return searchInferenceIds.getFirst();
// }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;

import java.util.List;
import java.util.Map;

public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -39,8 +42,31 @@ protected String getQuery(QueryBuilder queryBuilder) {

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
String searchInferenceId = indexInformation.getSearchInferenceId();
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
}
}

private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
QueryBuilder queryBuilder,
Map<String, List<String>> inferenceIdsIndices
) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceId : inferenceIdsIndices.keySet()) {
boolQueryBuilder.should(
createSubQueryForIndices(
inferenceIdsIndices.get(inferenceId),
buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId)
)
);
}
return boolQueryBuilder;
}

@Override
Expand All @@ -50,7 +76,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
String searchInferenceId = indexInformation.getSearchInferenceId();
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
Expand All @@ -61,12 +87,14 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
);
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors.
boolQueryBuilder.should(
createSubQueryForIndices(
indexInformation.getInferenceIndices(),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, searchInferenceId)
)
);
for (String inferenceId : inferenceIdsIndices.keySet()) {
boolQueryBuilder.should(
createSubQueryForIndices(
inferenceIdsIndices.get(inferenceId),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId)
)
);
}
return boolQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ setup:
}
}
- do:
inference.put:
task_type: sparse_embedding
inference_id: sparse-inference-id-2
body: >
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
- do:
indices.create:
index: test-semantic-text-index
Expand All @@ -28,6 +43,16 @@ setup:
type: semantic_text
inference_id: sparse-inference-id

- do:
indices.create:
index: test-semantic-text-index-2
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: sparse-inference-id-2

- do:
indices.create:
index: test-sparse-vector-index
Expand All @@ -45,6 +70,13 @@ setup:
inference_field: [ "inference test", "another inference test" ]
refresh: true

- do:
index:
index: test-semantic-text-index-2
id: doc_3
body:
inference_field: [ "inference test", "another inference test" ]
refresh: true

- do:
index:
Expand Down Expand Up @@ -202,3 +234,22 @@ setup:
- match: { hits.hits.0._id: "doc_1" }
- match: { hits.hits.1._id: "doc_2" }


---
"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields":

- do:
search:
index:
- test-semantic-text-index
- test-semantic-text-index-2
body:
query:
sparse_vector:
field: inference_field
query: "inference test"

- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "doc_1" }
- match: { hits.hits.1._id: "doc_3" }

0 comments on commit d28cd8b

Please sign in to comment.