From d28cd8b0e08fe9a25072739d6a0ef6d4727592ea Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 16 Dec 2024 15:20:52 -0500 Subject: [PATCH] Don't throw on multiple inference IDs --- .../SemanticQueryRewriteInterceptor.java | 36 ++++++++----- ...icSparseVectorQueryRewriteInterceptor.java | 46 +++++++++++++---- .../46_semantic_text_sparse_vector.yml | 51 +++++++++++++++++++ 3 files changed, 110 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java index 1ba491a2617ac..a98db6eea1aa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field. @@ -132,22 +133,29 @@ public Collection getInferenceIndices() { return inferenceIndicesMetadata.keySet(); } - public String getSearchInferenceIdForIndex(String index) { - return inferenceIndicesMetadata.get(index).getSearchInferenceId(); - } - - public String getSearchInferenceId() { - List searchInferenceIds = inferenceIndicesMetadata.values() + public Map> getInferenceIdsIndices() { + return inferenceIndicesMetadata.entrySet() .stream() - .map(InferenceFieldMetadata::getSearchInferenceId) - .distinct() - .toList(); - if (searchInferenceIds.size() > 1) { - throw new IllegalStateException( - "Conflicting searchInferenceIds for field [" + fieldName + "]: Found [" + searchInferenceIds + "]" + .collect( + Collectors.groupingBy( + entry -> entry.getValue().getSearchInferenceId(), + Collectors.mapping(Map.Entry::getKey, Collectors.toList()) + ) ); - } - return searchInferenceIds.getFirst(); } + + // public String getSearchInferenceId() { + // List searchInferenceIds = inferenceIndicesMetadata.values() + // .stream() + // .map(InferenceFieldMetadata::getSearchInferenceId) + // .distinct() + // .toList(); + // if (searchInferenceIds.size() > 1) { + // throw new IllegalStateException( + // "Conflicting searchInferenceIds for field [" + fieldName + "]: Found [" + searchInferenceIds + "]" + // ); + // } + // return searchInferenceIds.getFirst(); + // } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index 6f73a6acf935f..a35e83450c55a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -15,6 +15,9 @@ import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import java.util.List; +import java.util.Map; + public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( @@ -39,8 +42,31 @@ protected String getQuery(QueryBuilder queryBuilder) { @Override protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { - String searchInferenceId = indexInformation.getSearchInferenceId(); - return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId); + Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); + if (inferenceIdsIndices.size() == 1) { + // Simple case, everything uses the same inference ID + String searchInferenceId = inferenceIdsIndices.keySet().iterator().next(); + return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId); + } else { + // Multiple inference IDs, construct a boolean query + return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices); + } + } + + private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( + QueryBuilder queryBuilder, + Map> inferenceIdsIndices + ) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + for (String inferenceId : inferenceIdsIndices.keySet()) { + boolQueryBuilder.should( + createSubQueryForIndices( + inferenceIdsIndices.get(inferenceId), + buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId) + ) + ); + } + return boolQueryBuilder; } @Override @@ -50,7 +76,7 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ) { assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; - String searchInferenceId = indexInformation.getSearchInferenceId(); + Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.should( @@ -61,12 +87,14 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( ); // We always perform nested subqueries on semantic_text fields, to support // sparse_vector queries using query vectors. - boolQueryBuilder.should( - createSubQueryForIndices( - indexInformation.getInferenceIndices(), - buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, searchInferenceId) - ) - ); + for (String inferenceId : inferenceIdsIndices.keySet()) { + boolQueryBuilder.should( + createSubQueryForIndices( + inferenceIdsIndices.get(inferenceId), + buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId) + ) + ); + } return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml index 05cc2e7125e26..71b147cb99f66 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml @@ -18,6 +18,21 @@ setup: } } + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id-2 + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: indices.create: index: test-semantic-text-index @@ -28,6 +43,16 @@ setup: type: semantic_text inference_id: sparse-inference-id + - do: + indices.create: + index: test-semantic-text-index-2 + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id-2 + - do: indices.create: index: test-sparse-vector-index @@ -45,6 +70,13 @@ setup: inference_field: [ "inference test", "another inference test" ] refresh: true + - do: + index: + index: test-semantic-text-index-2 + id: doc_3 + body: + inference_field: [ "inference test", "another inference test" ] + refresh: true - do: index: @@ -202,3 +234,22 @@ setup: - match: { hits.hits.0._id: "doc_1" } - match: { hits.hits.1._id: "doc_2" } + +--- +"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields": + + - do: + search: + index: + - test-semantic-text-index + - test-semantic-text-index-2 + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } +