Skip to content

Commit

Permalink
Fix too many rewrite rounds error by injecting booleans in constructo…
Browse files Browse the repository at this point in the history
…rs for match query builder and semantic text
  • Loading branch information
kderusso committed Dec 2, 2024
1 parent 720be53 commit 6197541
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

package org.elasticsearch.index.query;

import java.util.function.BiFunction;
import org.elasticsearch.common.TriFunction;

public class InferenceQueryBuilderService {
private final BiFunction<String, String, AbstractQueryBuilder<?>> defaultInferenceQueryBuilder;
private final TriFunction<String, String, Boolean, AbstractQueryBuilder<?>> defaultInferenceQueryBuilder;

InferenceQueryBuilderService(BiFunction<String, String, AbstractQueryBuilder<?>> defaultInferenceQueryBuilder) {
InferenceQueryBuilderService(TriFunction<String, String, Boolean, AbstractQueryBuilder<?>> defaultInferenceQueryBuilder) {
this.defaultInferenceQueryBuilder = defaultInferenceQueryBuilder;
}

public AbstractQueryBuilder<?> getDefaultInferenceQueryBuilder(String fieldName, String query) {
return defaultInferenceQueryBuilder.apply(fieldName, query);
public AbstractQueryBuilder<?> getDefaultInferenceQueryBuilder(String fieldName, String query, boolean throwOnUnsupportedQueries) {
return defaultInferenceQueryBuilder.apply(fieldName, query, throwOnUnsupportedQueries);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.query;

import org.elasticsearch.common.TriFunction;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.plugins.SearchPlugin;

Expand All @@ -27,7 +28,7 @@ public InferenceQueryBuilderServiceBuilder pluginsService(PluginsService plugins
public InferenceQueryBuilderService build() {
Objects.requireNonNull(pluginsService);

List<BiFunction<String, String, AbstractQueryBuilder<?>>> definedInferenceQueryBuilders = new ArrayList<>();
List<TriFunction<String, String, Boolean, AbstractQueryBuilder<?>>> definedInferenceQueryBuilders = new ArrayList<>();

List<SearchPlugin> searchPlugins = pluginsService.filterPlugins(SearchPlugin.class).toList();
for (SearchPlugin searchPlugin : searchPlugins) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ public class MatchQueryBuilder extends AbstractQueryBuilder<MatchQueryBuilder> {

private boolean autoGenerateSynonymsPhraseQuery = true;

private final boolean inferenceFieldsIdentified;

/**
* Constructs a new match query.
*/
public MatchQueryBuilder(String fieldName, Object value) {
public MatchQueryBuilder(String fieldName, Object value, boolean inferenceFieldsIdentified) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
}
Expand All @@ -100,6 +102,11 @@ public MatchQueryBuilder(String fieldName, Object value) {
}
this.fieldName = fieldName;
this.value = value;
this.inferenceFieldsIdentified = inferenceFieldsIdentified;
}

public MatchQueryBuilder(String fieldName, Object value) {
this(fieldName, value, false);
}

/**
Expand All @@ -125,6 +132,7 @@ public MatchQueryBuilder(StreamInput in) throws IOException {
in.readOptionalFloat();
}
autoGenerateSynonymsPhraseQuery = in.readBoolean();
inferenceFieldsIdentified = false;
}

@Override
Expand Down Expand Up @@ -375,7 +383,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = super.doRewrite(queryRewriteContext);

if (rewritten == this && queryRewriteContext.getClass() == QueryRewriteContext.class) {
if (rewritten == this && inferenceFieldsIdentified == false && queryRewriteContext.getClass() == QueryRewriteContext.class) {
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Expand All @@ -395,36 +403,30 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each inference query, because they may be using different inference endpoints
boolQueryBuilder.should(
createInferenceSubQuery(queryRewriteContext.getQueryBuilderService(), inferenceIndexName, fieldName, value)
);
boolQueryBuilder.should(createInferenceSubQuery(queryRewriteContext.getQueryBuilderService(), inferenceIndexName));
}
boolQueryBuilder.should(createNonInferenceSubQuery(nonInferenceIndices, rewritten));
boolQueryBuilder.should(createNonInferenceSubQuery(nonInferenceIndices));
rewritten = boolQueryBuilder;
} else if (inferenceIndices.isEmpty() == false) {
rewritten = queryRewriteContext.getQueryBuilderService().getDefaultInferenceQueryBuilder(fieldName, value.toString());
rewritten = queryRewriteContext.getQueryBuilderService()
.getDefaultInferenceQueryBuilder(fieldName, value.toString(), true);
}
}
}

return rewritten;
}

private QueryBuilder createInferenceSubQuery(
InferenceQueryBuilderService inferenceQueryBuilderService,
String indexName,
String fieldName,
Object value
) {
private QueryBuilder createInferenceSubQuery(InferenceQueryBuilderService inferenceQueryBuilderService, String indexName) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(inferenceQueryBuilderService.getDefaultInferenceQueryBuilder(fieldName, value.toString()));
boolQueryBuilder.must(inferenceQueryBuilderService.getDefaultInferenceQueryBuilder(fieldName, value.toString(), false));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
return boolQueryBuilder;
}

private QueryBuilder createNonInferenceSubQuery(List<String> indices, QueryBuilder rewritten) {
private QueryBuilder createNonInferenceSubQuery(List<String> indices) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(rewritten);
boolQueryBuilder.must(new MatchQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.apache.lucene.search.Query;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.io.stream.GenericNamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -129,7 +130,7 @@ default List<QuerySpec<?>> getQueries() {
return emptyList();
}

default BiFunction<String, String, AbstractQueryBuilder<?>> getDefaultInferenceQueryBuilder() {
default TriFunction<String, String, Boolean, AbstractQueryBuilder<?>> getDefaultInferenceQueryBuilder() {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.action.support.MappedActionFilter;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
Expand Down Expand Up @@ -109,7 +110,6 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -407,7 +407,7 @@ public List<QuerySpec<?>> getQueries() {
}

@Override
public BiFunction<String, String, AbstractQueryBuilder<?>> getDefaultInferenceQueryBuilder() {
public TriFunction<String, String, Boolean, AbstractQueryBuilder<?>> getDefaultInferenceQueryBuilder() {
return SemanticQueryBuilder::new;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,13 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
private final SetOnce<InferenceServiceResults> inferenceResultsSupplier;
private final InferenceResults inferenceResults;
private final boolean noInferenceResults;
private final boolean throwOnUnsupportedField;

public SemanticQueryBuilder(String fieldName, String query) {
this(fieldName, query, true);
}

public SemanticQueryBuilder(String fieldName, String query, boolean throwOnUnsupportedField) {
if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
}
Expand All @@ -88,6 +93,7 @@ public SemanticQueryBuilder(String fieldName, String query) {
this.inferenceResults = null;
this.inferenceResultsSupplier = null;
this.noInferenceResults = false;
this.throwOnUnsupportedField = throwOnUnsupportedField;
}

public SemanticQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -97,6 +103,7 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
this.noInferenceResults = in.readBoolean();
this.inferenceResultsSupplier = null;
this.throwOnUnsupportedField = true;
}

@Override
Expand All @@ -123,6 +130,7 @@ private SemanticQueryBuilder(
this.inferenceResultsSupplier = inferenceResultsSupplier;
this.inferenceResults = inferenceResults;
this.noInferenceResults = noInferenceResults;
this.throwOnUnsupportedField = other.throwOnUnsupportedField;
}

@Override
Expand Down Expand Up @@ -171,11 +179,12 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
}

return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName());
} else {
} else if (throwOnUnsupportedField) {
throw new IllegalArgumentException(
"Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries"
);
}
return this;
}

private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
Expand Down

0 comments on commit 6197541

Please sign in to comment.