diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
index e9e4e90421adc..697c13da5c6ea 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java
@@ -143,6 +143,14 @@ public List<WeightedToken> getQueryVectors() {
         return queryVectors;
     }
 
+    public String getInferenceId() {
+        return inferenceId;
+    }
+
+    public String getQuery() {
+        return query;
+    }
+
     public boolean shouldPruneTokens() {
         return shouldPruneTokens;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 93743a5485c2c..169c8f87043e8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -80,6 +80,7 @@
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
+import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
@@ -440,7 +441,7 @@ public List<QuerySpec<?>> getQueries() {
 
     @Override
     public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
-        return List.of(new SemanticMatchQueryRewriteInterceptor());
+        return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
     }
 
     @Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java
new file mode 100644
index 0000000000000..dadbb58335145
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.elasticsearch.action.ResolvedIndices;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.index.mapper.IndexFieldMapper;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.TermsQueryBuilder;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Static helpers for query rewrite interceptors that must treat indices differently depending on
+ * whether a field is registered as an inference field in each index's metadata.
+ */
+public final class SemanticQueryInterceptionUtils {
+
+    private SemanticQueryInterceptionUtils() {}
+
+    /**
+     * Splits the concrete local indices of a request into those whose metadata lists {@code fieldName}
+     * as an inference field and those whose metadata does not.
+     *
+     * @param fieldName the field to look up in each index's inference field metadata
+     * @param resolvedIndices the indices resolved for the request; may be {@code null}
+     * @return the two groups of index names, or {@code null} when {@code resolvedIndices} is {@code null}
+     */
+    public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
+        if (resolvedIndices == null) {
+            // Without resolved indices there is nothing to partition; callers must handle null.
+            return null;
+        }
+
+        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
+        List<String> inferenceIndices = new ArrayList<>();
+        List<String> nonInferenceIndices = new ArrayList<>();
+        for (IndexMetadata indexMetadata : indexMetadataCollection) {
+            String indexName = indexMetadata.getIndex().getName();
+            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
+            if (inferenceFieldMetadata != null) {
+                inferenceIndices.add(indexName);
+            } else {
+                nonInferenceIndices.add(indexName);
+            }
+        }
+
+        return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices);
+    }
+
+    /**
+     * Builds a semantic query on {@code fieldName} that only applies to the given indices.
+     *
+     * @param indices the index names the sub query is restricted to
+     * @param fieldName the field targeted by the semantic query
+     * @param value the value handed to the semantic query
+     * @return a bool query with the semantic query as {@code must} and an index terms {@code filter}
+     */
+    public static QueryBuilder createSemanticSubQueryForIndices(List<String> indices, String fieldName, String value) {
+        return createSubQueryForIndices(indices, new SemanticQueryBuilder(fieldName, value, true));
+    }
+
+    /**
+     * Restricts {@code queryBuilder} to the given indices.
+     *
+     * @param indices the index names the sub query is restricted to
+     * @param queryBuilder the query to run against those indices
+     * @return a bool query with {@code queryBuilder} as {@code must} and an index terms {@code filter}
+     */
+    public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryBuilder queryBuilder) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.must(queryBuilder);
+        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
+        return boolQueryBuilder;
+    }
+
+    /**
+     * Index names split by whether the field is an inference field in the index.
+     *
+     * @param semanticMappedIndices indices whose metadata lists the field as an inference field
+     * @param otherIndices the remaining indices
+     */
+    public record SemanticTextIndexInformationForField(List<String> semanticMappedIndices, List<String> otherIndices) {}
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
new file mode 100644
index 0000000000000..6c3d4809aa814
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.queries;
+
+import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
+
+/**
+ * Query rewrite interceptor for {@link SparseVectorQueryBuilder}. When the queried field is an
+ * inference field in some of the resolved indices, the query is redirected for those indices to a
+ * nested query on the field's {@code .inference.chunks.embeddings} sub field.
+ */
+public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {
+
+    public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
+        "search.semantic_sparse_vector_query_rewrite_interception_supported"
+    );
+
+    private static final String NESTED_FIELD_PATH = ".inference.chunks";
+    private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings";
+
+    public SemanticSparseVectorQueryRewriteInterceptor() {}
+
+    /**
+     * Rewrites a sparse vector query according to how its field is mapped across the resolved indices.
+     *
+     * @param context the rewrite context, used to obtain the resolved indices
+     * @param queryBuilder the query to intercept; expected to be a {@link SparseVectorQueryBuilder}
+     * @return the original query when no index has the field as an inference field, a nested query when
+     *         every index does, or a bool query with one index-filtered {@code should} clause per group
+     */
+    @Override
+    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
+        assert (queryBuilder instanceof SparseVectorQueryBuilder);
+        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
+        SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField indexInformation = SemanticQueryInterceptionUtils
+            .resolveIndicesForField(sparseVectorQueryBuilder.getFieldName(), context.getResolvedIndices());
+
+        if (indexInformation == null || indexInformation.semanticMappedIndices().isEmpty()) {
+            // No semantic text fields, return original query
+            return queryBuilder;
+        }
+        if (indexInformation.otherIndices().isEmpty()) {
+            // Only semantic text fields
+            return buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder);
+        }
+
+        // Combined semantic and sparse vector fields
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        // sparse_vector fields should be passed in as their own clause; a single index filter is enough
+        boolQueryBuilder.should(
+            SemanticQueryInterceptionUtils.createSubQueryForIndices(indexInformation.otherIndices(), sparseVectorQueryBuilder)
+        );
+        // semantic text fields should be passed in as nested sub queries
+        boolQueryBuilder.should(
+            SemanticQueryInterceptionUtils.createSubQueryForIndices(
+                indexInformation.semanticMappedIndices(),
+                buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder)
+            )
+        );
+        return boolQueryBuilder;
+    }
+
+    /**
+     * Copies the given query onto the nested embeddings field and wraps it in a nested query that
+     * scores with {@link ScoreMode#Max}.
+     */
+    private QueryBuilder buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuilder sparseVectorQueryBuilder) {
+        return QueryBuilders.nestedQuery(
+            getNestedFieldPath(sparseVectorQueryBuilder.getFieldName()),
+            new SparseVectorQueryBuilder(
+                getNestedEmbeddingsField(sparseVectorQueryBuilder.getFieldName()),
+                sparseVectorQueryBuilder.getQueryVectors(),
+                sparseVectorQueryBuilder.getInferenceId(),
+                sparseVectorQueryBuilder.getQuery(),
+                sparseVectorQueryBuilder.shouldPruneTokens(),
+                sparseVectorQueryBuilder.getTokenPruningConfig()
+            ),
+            ScoreMode.Max
+        );
+    }
+
+    private static String getNestedFieldPath(String fieldName) {
+        return fieldName + NESTED_FIELD_PATH;
+    }
+
+    private static String getNestedEmbeddingsField(String fieldName) {
+        return fieldName + NESTED_EMBEDDINGS_FIELD;
+    }
+
+    @Override
+    public String getQueryName() {
+        return SparseVectorQueryBuilder.NAME;
+    }
+}