diff --git a/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderService.java b/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderService.java index fc6ebb2e7a71c..a8f0a7b28f01c 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderService.java +++ b/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderService.java @@ -8,16 +8,16 @@ package org.elasticsearch.index.query; -import java.util.function.BiFunction; +import org.elasticsearch.common.TriFunction; public class InferenceQueryBuilderService { - private final BiFunction> defaultInferenceQueryBuilder; + private final TriFunction> defaultInferenceQueryBuilder; - InferenceQueryBuilderService(BiFunction> defaultInferenceQueryBuilder) { + InferenceQueryBuilderService(TriFunction> defaultInferenceQueryBuilder) { this.defaultInferenceQueryBuilder = defaultInferenceQueryBuilder; } - public AbstractQueryBuilder getDefaultInferenceQueryBuilder(String fieldName, String query) { - return defaultInferenceQueryBuilder.apply(fieldName, query); + public AbstractQueryBuilder getDefaultInferenceQueryBuilder(String fieldName, String query, boolean throwOnUnsupportedQueries) { + return defaultInferenceQueryBuilder.apply(fieldName, query, throwOnUnsupportedQueries); } } diff --git a/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderServiceBuilder.java b/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderServiceBuilder.java index 5fec65839debf..6bcd67dd51585 100644 --- a/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderServiceBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/InferenceQueryBuilderServiceBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.index.query; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.SearchPlugin; @@ -27,7 +28,7 @@ public InferenceQueryBuilderServiceBuilder pluginsService(PluginsService plugins public InferenceQueryBuilderService build() { Objects.requireNonNull(pluginsService); - List>> definedInferenceQueryBuilders = new ArrayList<>(); + List>> definedInferenceQueryBuilders = new ArrayList<>(); List searchPlugins = pluginsService.filterPlugins(SearchPlugin.class).toList(); for (SearchPlugin searchPlugin : searchPlugins) { diff --git a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java index e7feafe8bc047..ded6ad59fb206 100644 --- a/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/MatchQueryBuilder.java @@ -88,10 +88,12 @@ public class MatchQueryBuilder extends AbstractQueryBuilder { private boolean autoGenerateSynonymsPhraseQuery = true; + private final boolean inferenceFieldsIdentified; + /** * Constructs a new match query. */ - public MatchQueryBuilder(String fieldName, Object value) { + public MatchQueryBuilder(String fieldName, Object value, boolean inferenceFieldsIdentified) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); } @@ -100,6 +102,11 @@ public MatchQueryBuilder(String fieldName, Object value) { } this.fieldName = fieldName; this.value = value; + this.inferenceFieldsIdentified = inferenceFieldsIdentified; + } + + public MatchQueryBuilder(String fieldName, Object value) { + this(fieldName, value, false); } /** @@ -125,6 +132,7 @@ public MatchQueryBuilder(StreamInput in) throws IOException { in.readOptionalFloat(); } autoGenerateSynonymsPhraseQuery = in.readBoolean(); + inferenceFieldsIdentified = false; } @Override @@ -375,7 +383,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { QueryBuilder rewritten = super.doRewrite(queryRewriteContext); - if (rewritten == this && queryRewriteContext.getClass() == QueryRewriteContext.class) { + if (rewritten == this && inferenceFieldsIdentified == false && queryRewriteContext.getClass() == QueryRewriteContext.class) { ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices(); if (resolvedIndices != null) { Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); @@ -395,14 +403,13 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); for (String inferenceIndexName : inferenceIndices) { // Add a separate clause for each inference query, because they may be using different inference endpoints - boolQueryBuilder.should( - createInferenceSubQuery(queryRewriteContext.getQueryBuilderService(), inferenceIndexName, fieldName, value) - ); + boolQueryBuilder.should(createInferenceSubQuery(queryRewriteContext.getQueryBuilderService(), inferenceIndexName)); } - boolQueryBuilder.should(createNonInferenceSubQuery(nonInferenceIndices, rewritten)); + boolQueryBuilder.should(createNonInferenceSubQuery(nonInferenceIndices)); rewritten = boolQueryBuilder; } else if (inferenceIndices.isEmpty() == false) { - rewritten = queryRewriteContext.getQueryBuilderService().getDefaultInferenceQueryBuilder(fieldName, value.toString()); + rewritten = queryRewriteContext.getQueryBuilderService() + .getDefaultInferenceQueryBuilder(fieldName, value.toString(), true); } } } @@ -410,21 +417,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return rewritten; } - private QueryBuilder createInferenceSubQuery( - InferenceQueryBuilderService inferenceQueryBuilderService, - String indexName, - String fieldName, - Object value - ) { + private QueryBuilder createInferenceSubQuery(InferenceQueryBuilderService inferenceQueryBuilderService, String indexName) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(inferenceQueryBuilderService.getDefaultInferenceQueryBuilder(fieldName, value.toString())); + boolQueryBuilder.must(inferenceQueryBuilderService.getDefaultInferenceQueryBuilder(fieldName, value.toString(), false)); boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName)); return boolQueryBuilder; } - private QueryBuilder createNonInferenceSubQuery(List indices, QueryBuilder rewritten) { + private QueryBuilder createNonInferenceSubQuery(List indices) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(rewritten); + boolQueryBuilder.must(new MatchQueryBuilder(fieldName, value, true)); boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); return boolQueryBuilder; } diff --git a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java index 5f8f079c8d605..f787ec324c037 100644 --- a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.GenericNamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.StreamInput; @@ -129,7 +130,7 @@ default List> getQueries() { return emptyList(); } - default BiFunction> getDefaultInferenceQueryBuilder() { + default TriFunction> getDefaultInferenceQueryBuilder() { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 17820575a4877..bfd685e224ba1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; @@ -109,7 +110,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -407,7 +407,7 @@ public List> getQueries() { } @Override - public BiFunction> getDefaultInferenceQueryBuilder() { + public TriFunction> getDefaultInferenceQueryBuilder() { return SemanticQueryBuilder::new; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index d648db2fbfdbc..e20202e92c6a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -75,8 +75,13 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsSupplier; private final InferenceResults inferenceResults; private final boolean noInferenceResults; + private final boolean throwOnUnsupportedField; public SemanticQueryBuilder(String fieldName, String query) { + this(fieldName, query, true); + } + + public SemanticQueryBuilder(String fieldName, String query, boolean throwOnUnsupportedField) { if (fieldName == null) { throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value"); } @@ -88,6 +93,7 @@ public SemanticQueryBuilder(String fieldName, String query) { this.inferenceResults = null; this.inferenceResultsSupplier = null; this.noInferenceResults = false; + this.throwOnUnsupportedField = throwOnUnsupportedField; } public SemanticQueryBuilder(StreamInput in) throws IOException { @@ -97,6 +103,7 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); this.noInferenceResults = in.readBoolean(); this.inferenceResultsSupplier = null; + this.throwOnUnsupportedField = true; } @Override @@ -123,6 +130,7 @@ private SemanticQueryBuilder( this.inferenceResultsSupplier = inferenceResultsSupplier; this.inferenceResults = inferenceResults; this.noInferenceResults = noInferenceResults; + this.throwOnUnsupportedField = other.throwOnUnsupportedField; } @Override @@ -171,11 +179,12 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx } return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); - } else { + } else if (throwOnUnsupportedField) { throw new IllegalArgumentException( "Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries" ); } + return this; } private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {