Skip to content

Commit

Permalink
Update the POC to rewrite to a bool query when combined inference and…
Browse files Browse the repository at this point in the history
… non-inference fields
  • Loading branch information
kderusso committed Nov 25, 2024
1 parent 6956ffb commit 3235130
Showing 1 changed file with 46 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.search.FuzzyQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ResolvedIndices;
Expand All @@ -24,16 +25,21 @@
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.support.QueryParsers;
import org.elasticsearch.index.search.MatchQueryParser;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
* Match query is a query that analyzes the text and constructs a query as the
Expand Down Expand Up @@ -376,10 +382,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();

boolean foundNonInferenceField = false;
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
String inferenceFieldQueryName = null;
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
if (inferenceFieldQueryName != null
Expand All @@ -388,25 +395,56 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
}

inferenceFieldQueryName = inferenceFieldMetadata.getQueryName();
inferenceIndices.add(indexName);
} else {
foundNonInferenceField = true;
nonInferenceIndices.add(indexName);
}
}

if (foundNonInferenceField && inferenceFieldQueryName != null) {
throw new IllegalArgumentException("Cannot query inference fields and non-inference fields at the same time");
}

if (inferenceFieldQueryName != null) {
if (inferenceIndices.isEmpty() == false && nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createInferenceSubQuery(
queryRewriteContext.getQueryBuilderService(),
inferenceFieldQueryName,
inferenceIndices,
fieldName,
value
)
);
boolQueryBuilder.should(createNonInferenceSubQuery(nonInferenceIndices, rewritten));
rewritten = boolQueryBuilder;
} else if (inferenceFieldQueryName != null) {
rewritten = queryRewriteContext.getQueryBuilderService()
.getQueryBuilder(inferenceFieldQueryName, fieldName, value.toString());
}
}
}

// LogManager.getLogger(MatchQueryBuilder.class).info("rewritten: " + rewritten);
return rewritten;
}

private QueryBuilder createInferenceSubQuery(
QueryBuilderService queryBuilderService,
String inferenceFieldQueryName,
List<String> indices,
String fieldName,
Object value
) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilderService.getQueryBuilder(inferenceFieldQueryName, fieldName, value.toString()));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

private QueryBuilder createNonInferenceSubQuery(List<String> indices, QueryBuilder rewritten) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(rewritten);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

@Override
protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException {
if (fuzziness != null || lenient) {
Expand Down

0 comments on commit 3235130

Please sign in to comment.