diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 6dfa72702243b..876ff01812064 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -10,7 +10,6 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 0f26f6577860f..faeac9dc1853f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -56,11 +56,11 @@ public record SemanticTextField(String fieldName, List originalValues, I ToXContentObject { static final String TEXT_FIELD = "text"; - static final String INFERENCE_FIELD = "inference"; + public static final String INFERENCE_FIELD = "inference"; static final String INFERENCE_ID_FIELD = "inference_id"; static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; - static final String CHUNKS_FIELD = "chunks"; - static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; + public static final String 
CHUNKS_FIELD = "chunks"; + public static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; public static final String CHUNKED_TEXT_FIELD = "text"; static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String TASK_TYPE_FIELD = "task_type"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index a4a8123935c3e..705ff12ac7c7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -7,23 +7,19 @@ package org.elasticsearch.xpack.inference.queries; -import org.elasticsearch.action.ResolvedIndices; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.IndexFieldMapper; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; -import java.util.ArrayList; -import java.util.Collection; import java.util.List; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField; + public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor { public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( @@ -37,41 +33,37 @@ public 
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; QueryBuilder rewritten = queryBuilder; - ResolvedIndices resolvedIndices = context.getResolvedIndices(); - if (resolvedIndices != null) { - Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); - List inferenceIndices = new ArrayList<>(); - List nonInferenceIndices = new ArrayList<>(); - for (IndexMetadata indexMetadata : indexMetadataCollection) { - String indexName = indexMetadata.getIndex().getName(); - InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName()); - if (inferenceFieldMetadata != null) { - inferenceIndices.add(indexName); - } else { - nonInferenceIndices.add(indexName); - } - } + InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( + matchQueryBuilder.fieldName(), + context.getResolvedIndices() + ); - if (inferenceIndices.isEmpty()) { - return rewritten; - } else if (nonInferenceIndices.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - for (String inferenceIndexName : inferenceIndices) { - // Add a separate clause for each semantic query, because they may be using different inference endpoints - // TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints - boolQueryBuilder.should( - createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value()) - ); - } - boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder)); - rewritten = boolQueryBuilder; - } else { - rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false); - } + if (inferenceIndexInformationForField == null || 
inferenceIndexInformationForField.inferenceIndices().isEmpty()) { + // No inference fields, return original query + return rewritten; + } else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) { + // Combined inference and non inference fields + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should( + createSemanticSubQuery( + inferenceIndexInformationForField.inferenceIndices(), + matchQueryBuilder.fieldName(), + (String) matchQueryBuilder.value() + ) + ); + boolQueryBuilder.should( + SemanticQueryInterceptionUtils.createSubQueryForIndices( + inferenceIndexInformationForField.nonInferenceIndices(), + matchQueryBuilder + ) + ); + rewritten = boolQueryBuilder; + } else { + // Only inference fields + rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false); } return rewritten; - } @Override @@ -79,16 +71,9 @@ public String getQueryName() { return MatchQueryBuilder.NAME; } - private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) { + private QueryBuilder createSemanticSubQuery(List indices, String fieldName, String value) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true)); - boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName)); - return boolQueryBuilder; - } - - private QueryBuilder createMatchSubQuery(List indices, MatchQueryBuilder matchQueryBuilder) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(matchQueryBuilder); boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); return boolQueryBuilder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java 
index dadbb58335145..8922c7bdf5ce4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java @@ -21,10 +21,9 @@ public class SemanticQueryInterceptionUtils { - private SemanticQueryInterceptionUtils() {} - public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { + public static InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { if (resolvedIndices != null) { Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); List inferenceIndices = new ArrayList<>(); @@ -39,18 +38,11 @@ public static SemanticTextIndexInformationForField resolveIndicesForField(String } } - return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices); + return new InferenceIndexInformationForField(inferenceIndices, nonInferenceIndices); } return null; } - public static QueryBuilder createSemanticSubQueryForIndices(List indices, String fieldName, String value) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true)); - boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); - return boolQueryBuilder; - } - public static QueryBuilder createSubQueryForIndices(List indices, QueryBuilder queryBuilder) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.must(queryBuilder); @@ -58,6 +50,5 @@ public static QueryBuilder createSubQueryForIndices(List indices, QueryB return boolQueryBuilder; } - public record SemanticTextIndexInformationForField(List semanticMappedIndices, List otherIndices) {} - + public record InferenceIndexInformationForField(List inferenceIndices, List nonInferenceIndices) 
{} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index 6c3d4809aa814..6ca71decc0e6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -15,6 +15,9 @@ import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; + +import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField; public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor { @@ -22,9 +25,6 @@ public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewrite "search.semantic_sparse_vector_query_rewrite_interception_supported" ); - private static final String NESTED_FIELD_PATH = ".inference.chunks"; - private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings"; - public SemanticSparseVectorQueryRewriteInterceptor() {} @Override @@ -32,34 +32,34 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; QueryBuilder rewritten = queryBuilder; - SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField semanticTextIndexInformationForField = - SemanticQueryInterceptionUtils.resolveIndicesForField(sparseVectorQueryBuilder.getFieldName(), 
context.getResolvedIndices()); + InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( + sparseVectorQueryBuilder.getFieldName(), + context.getResolvedIndices() + ); - if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) { - // No semantic text fields, return original query + if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) { + // No inference fields, return original query return rewritten; - } else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) { - // Combined semantic and sparse vector fields + } else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) { + // Combined inference and non inference fields BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - // sparse_vector fields should be passed in as their own clause boolQueryBuilder.should( SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.otherIndices(), + inferenceIndexInformationForField.nonInferenceIndices(), SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.otherIndices(), + inferenceIndexInformationForField.nonInferenceIndices(), sparseVectorQueryBuilder ) ) ); - // semantic text fields should be passed in as nested sub queries + // We always perform nested subqueries on semantic_text fields, to support + // sparse_vector queries using query vectors boolQueryBuilder.should( SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.semanticMappedIndices(), + inferenceIndexInformationForField.inferenceIndices(), buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder) ) - ); - rewritten = boolQueryBuilder; } else { // Only semantic text fields @@ -85,11 +85,11 @@ private QueryBuilder 
buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuil } private static String getNestedFieldPath(String fieldName) { - return fieldName + NESTED_FIELD_PATH; + return fieldName + "." + SemanticTextField.INFERENCE_FIELD + "." + SemanticTextField.CHUNKS_FIELD; /* dot-separated so the result stays <field>.inference.chunks, as with the removed NESTED_FIELD_PATH suffix */ } private static String getNestedEmbeddingsField(String fieldName) { - return fieldName + NESTED_EMBEDDINGS_FIELD; + return getNestedFieldPath(fieldName) + "." + SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; /* yields <field>.inference.chunks.embeddings, as with the removed NESTED_EMBEDDINGS_FIELD suffix */ } @Override