diff --git a/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java new file mode 100644 index 0000000000000..415a1ca2eb083 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/InterceptedQueryBuilderWrapper.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; + +/** + * Wrapper for queries that have been intercepted using the {@link QueryRewriteInterceptor} that may need to + * break out of the rewrite phase. + * @param + */ +public class InterceptedQueryBuilderWrapper> extends AbstractQueryBuilder { + + protected final T queryBuilder; + + public InterceptedQueryBuilderWrapper(T queryBuilder) { + super(); + this.queryBuilder = queryBuilder; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + queryBuilder.writeTo(out); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + queryBuilder.toXContent(builder, params); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + return queryBuilder.doToQuery(context); + } + + @Override + protected boolean doEquals(T other) { + // Handle the edge case where we need to unwrap the incoming query builder + if (other instanceof InterceptedQueryBuilderWrapper) { + @SuppressWarnings("unchecked") + InterceptedQueryBuilderWrapper wrapper = (InterceptedQueryBuilderWrapper) other; + return queryBuilder.doEquals(wrapper.queryBuilder); + } else { + return queryBuilder.doEquals(other); + } + } + + @Override + protected int doHashCode() { + return queryBuilder.doHashCode(); + } + + @Override + public String getWriteableName() { + return queryBuilder.getWriteableName(); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return queryBuilder.getMinimalSupportedVersion(); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index 9f0fa36bcbd5e..fd704d39ca384 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -81,16 +81,10 @@ public class MatchQueryBuilder extends AbstractQueryBuilder { private boolean autoGenerateSynonymsPhraseQuery = true; - /** - * Indicates that this MatchQueryBuilder has already been intercepted and rewritten, - * so subsequent rewrite rounds can short-circuit interception. - */ - private final boolean interceptedAndRewritten; - /** * Constructs a new match query. */ - public MatchQueryBuilder(String fieldName, Object value, boolean interceptedAndRewritten) { + public MatchQueryBuilder(String fieldName, Object value) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); } @@ -99,11 +93,6 @@ public MatchQueryBuilder(String fieldName, Object value, boolean interceptedAndR } this.fieldName = fieldName; this.value = value; - this.interceptedAndRewritten = interceptedAndRewritten; - } - - public MatchQueryBuilder(String fieldName, Object value) { - this(fieldName, value, false); } /** @@ -129,7 +118,6 @@ public MatchQueryBuilder(StreamInput in) throws IOException { in.readOptionalFloat(); } autoGenerateSynonymsPhraseQuery = in.readBoolean(); - interceptedAndRewritten = false; } @Override @@ -203,10 +191,6 @@ public Fuzziness fuzziness() { return this.fuzziness; } - public boolean getInterceptedAndRewritten() { - return interceptedAndRewritten; - } - /** * Sets the length of a length of common (non-fuzzy) prefix for fuzzy match queries * @param prefixLength non-negative length of prefix diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index d1d5dab761f2a..45cf4dd6f698d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -13,6 +13,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.IndexFieldMapper; import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.InterceptedQueryBuilderWrapper; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; @@ -41,45 +42,40 @@ public QueryBuilder rewrite(QueryRewriteContext context, QueryBuilder queryBuild MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; QueryBuilder rewritten = queryBuilder; - if (matchQueryBuilder.getInterceptedAndRewritten() == false) { - ResolvedIndices resolvedIndices = context.getResolvedIndices(); - if (resolvedIndices != null) { - Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); - List inferenceIndices = new ArrayList<>(); - List nonInferenceIndices = new ArrayList<>(); - for (IndexMetadata indexMetadata : indexMetadataCollection) { - String indexName = indexMetadata.getIndex().getName(); - InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName()); - if (inferenceFieldMetadata != null) { - inferenceIndices.add(indexName); - } else { - nonInferenceIndices.add(indexName); - } + ResolvedIndices resolvedIndices = context.getResolvedIndices(); + if (resolvedIndices != null) { + Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); + List inferenceIndices = new ArrayList<>(); + List nonInferenceIndices = new ArrayList<>(); + for (IndexMetadata indexMetadata : indexMetadataCollection) { + String indexName = indexMetadata.getIndex().getName(); + InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName()); + if (inferenceFieldMetadata != null) { + inferenceIndices.add(indexName); + } else { + nonInferenceIndices.add(indexName); } + } - if (inferenceIndices.isEmpty()) { - return rewritten; - } else if (nonInferenceIndices.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - for (String inferenceIndexName : inferenceIndices) { - // Add a separate clause for each semantic query, because they may be using different inference endpoints - boolQueryBuilder.should( - createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value()) - ); - } + if (inferenceIndices.isEmpty()) { + return rewritten; + } else if (nonInferenceIndices.isEmpty() == false) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + for (String inferenceIndexName : inferenceIndices) { + // Add a separate clause for each semantic query, because they may be using different inference endpoints boolQueryBuilder.should( - createMatchSubQuery(nonInferenceIndices, matchQueryBuilder.fieldName(), matchQueryBuilder.value()) + createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value()) ); - rewritten = boolQueryBuilder; - } else { - rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), true); } + boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder.fieldName(), matchQueryBuilder.value())); + rewritten = boolQueryBuilder; + } else { + rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), true); } - - return rewritten; } - return queryBuilder; + return rewritten; + } @Override @@ -96,8 +92,14 @@ private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, private QueryBuilder createMatchSubQuery(List indices, String fieldName, Object value) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(new MatchQueryBuilder(fieldName, value, true)); + boolQueryBuilder.must(new InterceptedSemanticMatchQueryWrapper(fieldName, value)); boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); return boolQueryBuilder; } + + static class InterceptedSemanticMatchQueryWrapper extends InterceptedQueryBuilderWrapper { + InterceptedSemanticMatchQueryWrapper(String fieldName, Object value) { + super(new MatchQueryBuilder(fieldName, value)); + } + } }