Skip to content

Commit

Permalink
Bugfix for mixed version cluster queries using text expansion (elasti…
Browse files Browse the repository at this point in the history
…c#105912)

* Bugfix for CCR queries using text expansion

* Fix test

* PR feedback

* Fix test

* Minor cleanup

* Edit comment

* One more comment clarification

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
kderusso and elasticmachine authored Mar 5, 2024
1 parent 335afe5 commit fe13a04
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -67,12 +68,7 @@ public String getTypeName() {
}

public static boolean isFieldTypeAllowed(String typeName) {
for (AllowedFieldType fieldType : values()) {
if (fieldType.getTypeName().equals(typeName)) {
return true;
}
}
return false;
return Arrays.stream(values()).anyMatch(value -> value.typeName.equals(typeName));
}

public static String getAllowedFieldTypesAsString() {
Expand Down Expand Up @@ -168,8 +164,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {

protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (weightedTokensSupplier != null) {
if (weightedTokensSupplier.get() == null) {
return this;
Expand All @@ -188,8 +183,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
queryRewriteContext.registerAsyncAction((client, listener) -> {
executeAsyncWithOrigin(
queryRewriteContext.registerAsyncAction(
(client, listener) -> executeAsyncWithOrigin(
client,
ML_ORIGIN,
CoordinatedInferenceAction.INSTANCE,
Expand Down Expand Up @@ -220,21 +215,34 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
);
}
}, listener::onFailure)
);
});
)
);

return new TextExpansionQueryBuilder(this, textExpansionResultsSupplier);
}

private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) {
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
fieldName,
textExpansionResults.getWeightedTokens(),
tokenPruningConfig
);
weightedTokensQueryBuilder.queryName(queryName);
weightedTokensQueryBuilder.boost(boost);
return weightedTokensQueryBuilder;
if (tokenPruningConfig != null) {
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
fieldName,
textExpansionResults.getWeightedTokens(),
tokenPruningConfig
);
weightedTokensQueryBuilder.queryName(queryName);
weightedTokensQueryBuilder.boost(boost);
return weightedTokensQueryBuilder;
}
// Note: Weighted tokens queries were introduced in 8.13.0. To support mixed version clusters prior to 8.13.0,
// if no token pruning configuration is specified we fall back to a boolean query.
// TODO this should be updated to always use a WeightedTokensQueryBuilder once it's in all supported versions.
var boolQuery = QueryBuilders.boolQuery();
for (var weightedToken : textExpansionResults.getWeightedTokens()) {
boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
}
boolQuery.minimumShouldMatch(1);
boolQuery.boost(boost);
boolQuery.queryName(queryName);
return boolQuery;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.Plugin;
Expand Down Expand Up @@ -259,6 +260,10 @@ public void testThatTokensAreCorrectlyPruned() {
SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder();
QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext);
assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder);
if (queryBuilder.getTokenPruningConfig() == null) {
assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder);
} else {
assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,4 @@ setup:
source_text:
model_id: text_expansion_model
model_text: "octopus comforter smells"
pruning_config: {}

0 comments on commit fe13a04

Please sign in to comment.