From d7fb29cdb6e21c989581e9e41ad4167f9b2c9af6 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Wed, 20 Mar 2024 08:04:04 +0100 Subject: [PATCH] Initial work on the term stat query --- .../xpack/ml/MachineLearning.java | 6 + .../ml/queries/TermStatsScriptQuery.java | 167 +++++++++++++ .../queries/TermStatsScriptQueryBuilder.java | 224 ++++++++++++++++++ 3 files changed, 397 insertions(+) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQuery.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQueryBuilder.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5ef7311179e4f..faa23ecdf3601 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -376,6 +376,7 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeStorageProvider; +import org.elasticsearch.xpack.ml.queries.TermStatsScriptQueryBuilder; import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder; import org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder; import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; @@ -1765,6 +1766,11 @@ public List> getQueries() { WeightedTokensQueryBuilder.NAME, WeightedTokensQueryBuilder::new, WeightedTokensQueryBuilder::fromXContent + ), + new QuerySpec( + TermStatsScriptQueryBuilder.NAME, + TermStatsScriptQueryBuilder::new, + TermStatsScriptQueryBuilder::fromXContent ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQuery.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQuery.java new 
file mode 100644
index 0000000000000..76efbb3433ffc
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQuery.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.queries;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TwoPhaseIterator;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Lucene query matching the documents accepted by a filter query and, when scores are requested,
+ * scoring them with a value intended to be derived from statistics about a set of terms.
+ * <p>
+ * Work in progress: term statistics are not computed yet, every match is scored with
+ * {@code PLACEHOLDER_SCORE} multiplied by the query boost.
+ */
+public class TermStatsScriptQuery extends Query {
+    /** Constant score used until the script based scoring is implemented. */
+    private static final float PLACEHOLDER_SCORE = 12.0f;
+
+    private final Query filter;
+
+    private final Set<Term> terms;
+
+    // TODO: instantiate a TermStatsCollector per leaf once the statistics are wired into the score.
+
+    /**
+     * @param filter query selecting the documents to score, must not be null
+     * @param terms terms whose statistics will feed the score, must not be null
+     */
+    public TermStatsScriptQuery(Query filter, Set<Term> terms) {
+        this.filter = Objects.requireNonNull(filter, "Filter must not be null");
+        // Defensive copy: equals and hashCode depend on the terms, later changes to the caller's set must not leak in.
+        this.terms = Set.copyOf(Objects.requireNonNull(terms, "Terms must not be null"));
+    }
+
+    @Override
+    public String toString(String field) {
+        return "TermStatsScript(" + this.filter.toString(field) + ", terms: " + terms + ")";
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+        this.filter.visit(visitor.getSubVisitor(BooleanClause.Occur.FILTER, this));
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, final ScoreMode scoreMode, float boost) throws IOException {
+        // The filter only selects documents, it never contributes to the score.
+        Weight innerWeight = filter.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, boost);
+
+        if (scoreMode.needsScores() == false) {
+            return innerWeight;
+        }
+
+        return new Weight(this) {
+            @Override
+            public boolean isCacheable(LeafReaderContext leafReaderContext) {
+                return false;
+            }
+
+            @Override
+            public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException {
+                // Weight#explain must not return null, so report the filter outcome and the placeholder score.
+                Explanation filterExplanation = innerWeight.explain(leafReaderContext, doc);
+                if (filterExplanation.isMatch() == false) {
+                    return Explanation.noMatch("document does not match the filter", filterExplanation);
+                }
+                return Explanation.match(PLACEHOLDER_SCORE * boost, "placeholder term stats script score", filterExplanation);
+            }
+
+            @Override
+            public Scorer scorer(LeafReaderContext context) throws IOException {
+                Scorer innerFilterScorer = innerWeight.scorer(context);
+                if (innerFilterScorer == null) {
+                    return null;
+                }
+                // Reuse the scorer obtained above instead of asking the inner weight for a second one.
+                return new TermStatsScriptScorer(this, innerFilterScorer, boost);
+            }
+        };
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this.sameClassAs(other)) {
+            TermStatsScriptQuery that = (TermStatsScriptQuery) other;
+            return this.filter.equals(that.filter) && this.terms.equals(that.terms);
+        }
+
+        return false;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classHash(), filter, terms);
+    }
+
+    /** Scorer iterating over the filter matches and giving each of them the placeholder score. */
+    private static class TermStatsScriptScorer extends Scorer {
+        private final Scorer subQueryScorer;
+        private final float boost;
+
+        TermStatsScriptScorer(Weight weight, Scorer subQueryScorer, float boost) {
+            super(weight);
+            this.subQueryScorer = subQueryScorer;
+            this.boost = boost;
+        }
+
+        @Override
+        public float score() throws IOException {
+            return PLACEHOLDER_SCORE * boost;
+        }
+
+        @Override
+        public int docID() {
+            return subQueryScorer.docID();
+        }
+
+        @Override
+        public DocIdSetIterator iterator() {
+            return subQueryScorer.iterator();
+        }
+
+        @Override
+        public TwoPhaseIterator twoPhaseIterator() {
+            return subQueryScorer.twoPhaseIterator();
+        }
+
+        @Override
+        public float getMaxScore(int upTo) {
+            return Float.MAX_VALUE;
+        }
+    }
+
+    /**
+     * Gathers statistics about the query terms for the collected documents.
+     * Work in progress: nothing is read from the index yet.
+     */
+    private static class TermStatsCollector {
+        private final IndexSearcher searcher;
+        private final LeafReaderContext context;
+        private final Set<Term> terms;
+
+        TermStatsCollector(IndexSearcher searcher, LeafReaderContext context, Set<Term> terms) {
+            this.searcher = searcher;
+            this.context = context;
+            this.terms = terms;
+        }
+
+        public void collect(int docId) {
+            if (docId == DocIdSetIterator.NO_MORE_DOCS) {
+                // The iterator is exhausted: there is no document to collect statistics for.
+                return;
+            }
+            for (Term term : terms) {
+                // TODO: gather the statistics of the term for the document.
+            }
+        }
+
+        /**
+         * Summarizes a list of samples.
+         *
+         * @param samples values to summarize, may be empty
+         * @return the count, min, max, sum, mean and sample variance of the values, keyed by name;
+         *         every statistic is 0 when there are no samples
+         */
+        private Map<String, Double> stats(List<Double> samples) {
+            double count = samples.size();
+            double sum = samples.stream().collect(Collectors.summingDouble(Double::doubleValue));
+            double min = samples.stream().min(Comparator.comparingDouble(Double::doubleValue)).orElse(0.0d);
+            double max = samples.stream().max(Comparator.comparingDouble(Double::doubleValue)).orElse(0.0d);
+            // An empty list would otherwise produce 0 / 0 = NaN.
+            double mean = count > 0 ? sum / count : 0.0d;
+            double variance = 0;
+            if (count > 1) {
+                // Sample variance: no square root here, that would be the standard deviation.
+                variance = samples.stream().collect(Collectors.summingDouble(value -> Math.pow(value - mean, 2))) / (count - 1);
+            }
+            return Map.ofEntries(
+                Map.entry("count", count),
+                Map.entry("min", min),
+                Map.entry("max", max),
+                Map.entry("sum", sum),
+                Map.entry("mean", mean),
+                Map.entry("variance", variance)
+            );
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQueryBuilder.java
new file mode 100644
index 0000000000000..8a130ff23eba8
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TermStatsScriptQueryBuilder.java
@@ -0,0 +1,224 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.queries;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.Query;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.MatchNoneQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Builder of the {@code term_stats_script} query.
+ * <p>
+ * The {@code query} text is analyzed into terms of {@code field}; the Lucene query built from {@code filter}
+ * is then handed, together with those terms, to a {@link TermStatsScriptQuery}.
+ */
+public class TermStatsScriptQueryBuilder extends AbstractQueryBuilder<TermStatsScriptQueryBuilder> {
+
+    public static final String NAME = "term_stats_script";
+
+    private static final ParseField INNER_FILTER_FIELD = new ParseField("filter");
+    private static final ParseField FIELD_NAME_FIELD = new ParseField("field");
+    private static final ParseField QUERY_FIELD = new ParseField("query");
+    private static final ParseField ANALYZER_FIELD = new ParseField("analyzer");
+
+    private final QueryBuilder filterBuilder;
+    private final String field;
+    private final String query;
+    private final String analyzer;
+
+    /**
+     * @param field name of the field the terms belong to, must not be null or blank
+     * @param query text analyzed to produce the terms, must not be null
+     * @param analyzer name of the analyzer applied to the text, or null to use the search analyzer of the field
+     * @param filterBuilder query selecting the documents to score, must not be null
+     */
+    public TermStatsScriptQueryBuilder(String field, String query, String analyzer, QueryBuilder filterBuilder) {
+        this.field = Strings.requireNonBlank(field, "field name can not be null or blank");
+        this.query = Objects.requireNonNull(query, "query can not be null");
+        this.analyzer = analyzer;
+        this.filterBuilder = Objects.requireNonNull(filterBuilder, "inner clause [filter] cannot be null.");
+    }
+
+    /** Reads the builder from a stream, in the order used by {@link #doWriteTo(StreamOutput)}. */
+    public TermStatsScriptQueryBuilder(StreamInput in) throws IOException {
+        super(in);
+        field = in.readString();
+        query = in.readString();
+        analyzer = in.readOptionalString();
+        filterBuilder = in.readNamedWriteable(QueryBuilder.class);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        // TODO: bump a new transport version before merging.
+        return TransportVersion.current();
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(field);
+        out.writeString(query);
+        out.writeOptionalString(analyzer);
+        out.writeNamedWriteable(filterBuilder);
+    }
+
+    @Override
+    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+        QueryBuilder rewrite = filterBuilder.rewrite(queryRewriteContext);
+
+        if (rewrite instanceof MatchNoneQueryBuilder) {
+            return rewrite; // we won't match anyway
+        }
+
+        if (rewrite != filterBuilder) {
+            return new TermStatsScriptQueryBuilder(field, query, analyzer, rewrite);
+        }
+
+        return this;
+    }
+
+    /**
+     * Parses the query from its XContent representation.
+     * The {@code field} and {@code query} elements are required; {@code analyzer}, {@code filter} (a match all
+     * query when absent), the boost and the query name are optional.
+     *
+     * @throws ParsingException if a required element is missing or an unsupported element or token is found
+     */
+    public static TermStatsScriptQueryBuilder fromXContent(XContentParser parser) throws IOException {
+        String field = null;
+        boolean hasField = false;
+
+        String query = null;
+        boolean hasQuery = false;
+
+        String analyzer = null;
+
+        QueryBuilder innerFilter = QueryBuilders.matchAllQuery();
+        String queryName = null;
+        float boost = AbstractQueryBuilder.DEFAULT_BOOST;
+
+        String currentFieldName = null;
+        XContentParser.Token token;
+        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+            if (token == XContentParser.Token.FIELD_NAME) {
+                currentFieldName = parser.currentName();
+            } else if (token == XContentParser.Token.START_OBJECT) {
+                if (INNER_FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    innerFilter = parseInnerQueryBuilder(parser);
+                } else {
+                    throw new ParsingException(
+                        parser.getTokenLocation(),
+                        "[" + NAME + "] query does not support [" + currentFieldName + "]"
+                    );
+                }
+            } else if (token.isValue()) {
+                if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    queryName = parser.text();
+                } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    boost = parser.floatValue();
+                } else if (FIELD_NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    field = parser.text();
+                    hasField = true;
+                } else if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    query = parser.text();
+                    hasQuery = true;
+                } else if (ANALYZER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    analyzer = parser.text();
+                } else {
+                    throw new ParsingException(
+                        parser.getTokenLocation(),
+                        "[" + NAME + "] query does not support [" + currentFieldName + "]"
+                    );
+                }
+            } else {
+                throw new ParsingException(parser.getTokenLocation(), "unexpected token [" + token + "]");
+            }
+        }
+
+        if (hasField == false) {
+            throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] requires a 'field' element");
+        }
+
+        if (hasQuery == false) {
+            throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] requires a 'query' element");
+        }
+
+        TermStatsScriptQueryBuilder termStatsScriptQueryBuilder = new TermStatsScriptQueryBuilder(field, query, analyzer, innerFilter);
+        termStatsScriptQueryBuilder.boost(boost);
+        termStatsScriptQueryBuilder.queryName(queryName);
+
+        return termStatsScriptQueryBuilder;
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(NAME);
+        {
+            builder.field(FIELD_NAME_FIELD.getPreferredName(), field);
+            builder.field(QUERY_FIELD.getPreferredName(), query);
+            if (analyzer != null) {
+                // The analyzer is optional; a null value could not be parsed back by fromXContent, so only write it when set.
+                builder.field(ANALYZER_FIELD.getPreferredName(), analyzer);
+            }
+            builder.field(INNER_FILTER_FIELD.getPreferredName());
+            filterBuilder.toXContent(builder, params);
+            printBoostAndQueryName(builder);
+        }
+        builder.endObject();
+    }
+
+    @Override
+    protected Query doToQuery(SearchExecutionContext context) throws IOException {
+        Set<Term> terms = extractTerms(context);
+        Query innerFilter = filterBuilder.toQuery(context);
+
+        return new TermStatsScriptQuery(innerFilter, terms);
+    }
+
+    @Override
+    protected boolean doEquals(TermStatsScriptQueryBuilder other) {
+        return Objects.equals(field, other.field)
+            && Objects.equals(query, other.query)
+            && Objects.equals(analyzer, other.analyzer)
+            && Objects.equals(filterBuilder, other.filterBuilder);
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hash(field, query, analyzer, filterBuilder);
+    }
+
+    /**
+     * Analyzes the query text and returns the distinct terms it produces for the field.
+     *
+     * @throws IllegalArgumentException if the analyzer cannot be resolved, see {@link #getAnalyzer(SearchExecutionContext)}
+     */
+    private Set<Term> extractTerms(SearchExecutionContext context) throws IOException {
+        Set<Term> terms = new HashSet<>();
+        // Named differently from the "analyzer" field, which holds the analyzer name.
+        Analyzer termAnalyzer = getAnalyzer(context);
+
+        try (TokenStream ts = termAnalyzer.tokenStream(field, query)) {
+            TermToBytesRefAttribute termAttr = ts.getAttribute(TermToBytesRefAttribute.class);
+            ts.reset();
+            while (ts.incrementToken()) {
+                terms.add(new Term(field, termAttr.getBytesRef()));
+            }
+            // The TokenStream workflow requires end() once every token has been consumed, before close().
+            ts.end();
+        }
+
+        return terms;
+    }
+
+    /**
+     * Returns the analyzer registered under the configured name or, when no name was given,
+     * the search analyzer of the field.
+     *
+     * @throws IllegalArgumentException if the named analyzer or the field cannot be found
+     */
+    private Analyzer getAnalyzer(SearchExecutionContext context) {
+        if (analyzer != null) {
+            Analyzer namedAnalyzer = context.getIndexAnalyzers().get(analyzer);
+            if (namedAnalyzer == null) {
+                // Fail with a clear message rather than a NullPointerException when the tokens are read.
+                throw new IllegalArgumentException("[" + NAME + "] analyzer [" + analyzer + "] not found");
+            }
+            return namedAnalyzer;
+        }
+
+        // TODO: additional checks on the field.
+        MappedFieldType fieldType = context.getFieldType(field);
+        if (fieldType == null) {
+            throw new IllegalArgumentException("[" + NAME + "] field [" + field + "] does not exist");
+        }
+        return fieldType.getTextSearchInfo().searchAnalyzer();
+    }
+}