diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt index e1769d28e2269..cd9fb843a9a2d 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt @@ -31,5 +31,7 @@ static_import { double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct + java.util.DoubleSummaryStatistics documentFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$DocumentFrequencyStatistics + java.util.DoubleSummaryStatistics termFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$TermFrequencyStatistics } diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java index 7ddaa4bc681fa..82e4328cf0c52 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java @@ -9,6 +9,8 @@ package org.elasticsearch.common.lucene.search.function; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanClause; import 
org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -29,10 +31,15 @@
 import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.ScoreScript.ExplanationHolder;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.script.TermStatsReader;
 import org.elasticsearch.search.lookup.SearchLookup;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * A query that uses a script to compute documents' scores.
@@ -89,6 +96,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
         ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
         Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);
 
+        TermStatsReader termStatsReader = createTermStatsReader(searcher);
+
         return new Weight(this) {
             @Override
             public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
@@ -164,7 +173,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
             }
 
             private ScoreScript makeScoreScript(LeafReaderContext context) throws IOException {
-                final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context));
+                final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context, termStatsReader));
                 scoreScript._setIndexName(indexName);
                 scoreScript._setShard(shardId);
                 return scoreScript;
@@ -178,6 +187,34 @@ public boolean isCacheable(LeafReaderContext ctx) {
         };
     }
 
+    /**
+     * Builds the reader that gives the score script access to the statistics and
+     * postings of the terms visited in this query.
+     *
+     * @param searcher the searcher the term states are built against
+     * @return a reader over every collected term; only terms matching at least one
+     *         document have term states attached
+     * @throws IOException if building the term states fails
+     */
+    private TermStatsReader createTermStatsReader(IndexSearcher searcher) throws IOException {
+        // Collect the distinct terms exposed through visit(QueryVisitor).
+        final Set<Term> terms = new HashSet<>();
+        this.visit(QueryVisitor.termCollector(terms));
+
+        // For each term, build its term states.
+        final Map<Term, TermStates> termContexts = new HashMap<>();
+
+        for (Term term : terms) {
+            TermStates termStates = TermStates.build(searcher, term, true);
+            // Only keep terms that match at least one document.
+            if (termStates != null && termStates.docFreq() > 0) {
+                termContexts.put(term, termStates);
+            }
+        }
+
+        return new TermStatsReader(searcher, terms, termContexts);
+    }
+
     @Override
     public void visit(QueryVisitor visitor) {
         // Highlighters must visit the child query to extract terms
diff --git a/server/src/main/java/org/elasticsearch/script/DocReader.java b/server/src/main/java/org/elasticsearch/script/DocReader.java
index 061c7033afa80..27bc94109f217 100644
--- a/server/src/main/java/org/elasticsearch/script/DocReader.java
+++ b/server/src/main/java/org/elasticsearch/script/DocReader.java
@@ -8,6 +8,9 @@
 
 package org.elasticsearch.script;
 
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermStatistics;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
 import org.elasticsearch.script.field.Field;
 import org.elasticsearch.search.lookup.Source;
@@ -41,4 +44,14 @@ public interface DocReader {
 
     /** Helper for source access */
     Supplier<Source> source();
+
+    /** Per-term statistics made available to scripts, keyed by term. */
+    Map<Term, TermStatistics> termStatistics();
+
+    /**
+     * Per-term postings made available to scripts, keyed by term.
+     *
+     * @param flags flags used when requesting each postings enum, e.g. {@code PostingsEnum.FREQS}
+     */
+    Map<Term, PostingsEnum> postings(int flags);
 }
diff --git a/server/src/main/java/org/elasticsearch/script/DocValuesDocReader.java b/server/src/main/java/org/elasticsearch/script/DocValuesDocReader.java
index 3c83dce4b5afd..fb1298813dceb 100644
--- a/server/src/main/java/org/elasticsearch/script/DocValuesDocReader.java
+++ b/server/src/main/java/org/elasticsearch/script/DocValuesDocReader.java
@@ -9,6 +9,9 @@
 package org.elasticsearch.script;
 
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.TermStatistics;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
 import org.elasticsearch.script.field.EmptyField;
import org.elasticsearch.script.field.Field;
@@ -17,6 +20,7 @@
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.search.lookup.Source;
 
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Supplier;
 import java.util.stream.Stream;
@@ -30,14 +34,24 @@ public class DocValuesDocReader implements DocReader, LeafReaderContextSupplier
 
     // provide access to the leaf context reader for expressions
     protected final LeafReaderContext leafReaderContext;
+    private final TermStatsReader termStatsReader;
 
     /** A leaf lookup for the bound segment this proxy will operate on. */
     protected LeafSearchLookup leafSearchLookup;
 
     public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext) {
+        this(searchLookup, leafContext, null);
+    }
+
+    /**
+     * @param termStatsReader source of term statistics and postings, or {@code null}
+     *                        when none is available
+     */
+    public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext, TermStatsReader termStatsReader) {
         this.searchLookup = searchLookup;
         this.leafReaderContext = leafContext;
         this.leafSearchLookup = searchLookup.getLeafSearchLookup(leafReaderContext);
+        this.termStatsReader = termStatsReader;
     }
 
     @Override
@@ -80,4 +94,41 @@ public LeafReaderContext getLeafReaderContext() {
     public Supplier<Source> source() {
         return leafSearchLookup.source();
     }
+
+    /**
+     * Collects the statistics of every term known to the term stats reader.
+     *
+     * @return a new map from term to its statistics; empty when this reader was
+     *         created without a term stats reader
+     */
+    @Override
+    public Map<Term, TermStatistics> termStatistics() {
+        Map<Term, TermStatistics> termStatistics = new HashMap<>();
+        // The two-argument constructor leaves termStatsReader null.
+        if (termStatsReader != null) {
+            for (Term term : termStatsReader.terms()) {
+                termStatistics.put(term, termStatsReader.termStatistics(term));
+            }
+        }
+        return termStatistics;
+    }
+
+    /**
+     * Collects, for the bound leaf, the postings of every term known to the term stats reader.
+     *
+     * @param flags flags passed through to the term stats reader for each postings enum
+     * @return a new map from term to its postings; empty when this reader was
+     *         created without a term stats reader
+     */
+    @Override
+    public Map<Term, PostingsEnum> postings(int flags) {
+        Map<Term, PostingsEnum> postings = new HashMap<>();
+        // The two-argument constructor leaves termStatsReader null.
+        if (termStatsReader != null) {
+            for (Term term : termStatsReader.terms()) {
+                postings.put(term, termStatsReader.postings(leafReaderContext, term, flags));
+            }
+        }
+        return postings;
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java
index 503bd11fb434a..cd4da66f5a129 100644
--- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java
+++
b/server/src/main/java/org/elasticsearch/script/ScoreScript.java
@@ -7,8 +7,11 @@
  */
 package org.elasticsearch.script;
 
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.TermStatistics;
 import org.elasticsearch.common.logging.DeprecationCategory;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.search.lookup.SearchLookup;
@@ -189,6 +192,16 @@ public void _setIndexName(String indexName) {
         this.indexName = indexName;
     }
 
+    /** Returns the per-term statistics supplied by this script's doc reader. */
+    public Map<Term, TermStatistics> _termStatistics() {
+        return this.docReader.termStatistics();
+    }
+
+    /** Returns the per-term postings supplied by this script's doc reader for the given {@code flags}. */
+    public Map<Term, PostingsEnum> _postings(int flags) {
+        return this.docReader.postings(flags);
+    }
+
     /** A factory to construct {@link ScoreScript} instances. */
     public interface LeafFactory {
 
diff --git a/server/src/main/java/org/elasticsearch/script/TermStatsReader.java b/server/src/main/java/org/elasticsearch/script/TermStatsReader.java
new file mode 100644
index 0000000000000..fcd3bc50b8665
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/TermStatsReader.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermState;
+import org.apache.lucene.index.TermStates;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.TermStatistics;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Reads the statistics and postings of a fixed set of terms.
+ */
+public class TermStatsReader {
+    private final IndexSearcher searcher;
+    private final Set<Term> terms;
+    private final Map<Term, TermStates> termContexts;
+
+    /**
+     * @param searcher     searcher asked for the statistics of each term
+     * @param terms        the terms this reader reports on
+     * @param termContexts term states per term; a term without an entry is looked up
+     *                     with zero frequencies and has no postings
+     */
+    public TermStatsReader(IndexSearcher searcher, Set<Term> terms, Map<Term, TermStates> termContexts) {
+        this.searcher = searcher;
+        this.terms = terms;
+        this.termContexts = termContexts;
+    }
+
+    /** Returns the terms this reader reports on. */
+    public Set<Term> terms() {
+        return terms;
+    }
+
+    /**
+     * Asks the searcher for the statistics of {@code term}.
+     *
+     * @param term the term to look up
+     * @return the statistics, or {@code null} when the lookup is rejected with an
+     *         {@link IllegalArgumentException}
+     * @throws UncheckedIOException if the lookup fails with an {@link IOException}
+     */
+    public TermStatistics termStatistics(Term term) {
+        // Look the term states up once instead of containsKey followed by two gets.
+        final TermStates termContext = termContexts.get(term);
+        try {
+            if (termContext == null) {
+                // NOTE(review): presumably kept so the searcher can still answer for a term
+                // without local term states - confirm against the searcher in use.
+                return searcher.termStatistics(term, 0, 0);
+            }
+            return searcher.termStatistics(term, termContext.docFreq(), termContext.totalTermFreq());
+        } catch (IllegalArgumentException e) {
+            // Deliberately mapped to "no statistics" rather than propagated.
+            return null;
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+
+    /**
+     * Returns the postings of {@code term} in the given leaf.
+     *
+     * @param leafReaderContext the leaf to read the postings from
+     * @param term              the term to look up
+     * @param flags             flags passed to {@link TermsEnum#postings(PostingsEnum, int)}
+     * @return the postings, or {@code null} when the term has no term states, no state in
+     *         this leaf, or a document frequency of zero
+     * @throws UncheckedIOException if reading from the leaf fails with an {@link IOException}
+     */
+    public PostingsEnum postings(LeafReaderContext leafReaderContext, Term term, int flags) {
+        final TermStates termContext = termContexts.get(term);
+        if (termContext == null) {
+            return null;
+        }
+
+        try {
+            TermState state = termContext.get(leafReaderContext);
+            if (state == null || termContext.docFreq() == 0) {
+                return null;
+            }
+
+            // Seek with the recorded state rather than searching for the term again.
+            TermsEnum termsEnum = leafReaderContext.reader().terms(term.field()).iterator();
+            termsEnum.seekExact(term.bytes(), state);
+            return termsEnum.postings(null, flags);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/TermStatsScriptUtils.java
b/server/src/main/java/org/elasticsearch/script/TermStatsScriptUtils.java
new file mode 100644
index 0000000000000..2c344a4981131
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/TermStatsScriptUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.search.TermStatistics;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Collection;
+import java.util.DoubleSummaryStatistics;
+import java.util.function.Supplier;
+
+/**
+ * Classes backing the {@code documentFrequencyStatistics} and {@code termFrequencyStatistics}
+ * functions that the score script whitelist binds to them.
+ */
+public class TermStatsScriptUtils {
+
+    /** Summarizes the document frequencies of the terms reported by a score script. */
+    public static final class DocumentFrequencyStatistics {
+        private final Collection<TermStatistics> termsStatistics;
+
+        public DocumentFrequencyStatistics(ScoreScript scoreScript) {
+            this.termsStatistics = scoreScript._termStatistics().values();
+        }
+
+        /**
+         * @return summary statistics over the document frequency of every term; a term
+         *         without statistics counts as a frequency of 0
+         */
+        public DoubleSummaryStatistics documentFrequencyStatistics() {
+            return termsStatistics.stream()
+                .mapToDouble(termStatistics -> termStatistics == null ? 0 : termStatistics.docFreq())
+                .summaryStatistics();
+        }
+    }
+
+    /** Summarizes, for the current document, the frequencies of the terms reported by a score script. */
+    public static final class TermFrequencyStatistics {
+        private final Collection<PostingsEnum> postings;
+        private final Supplier<Integer> docIdSupplier;
+
+        public TermFrequencyStatistics(ScoreScript scoreScript) {
+            postings = scoreScript._postings(PostingsEnum.FREQS).values();
+            docIdSupplier = scoreScript::_getDocId;
+        }
+
+        /**
+         * @return summary statistics over the frequency of every term in the current
+         *         document; a term absent from the document counts as a frequency of 0
+         * @throws UncheckedIOException if reading the postings fails with an {@link IOException}
+         */
+        public DoubleSummaryStatistics termFrequencyStatistics() {
+            final int docId = docIdSupplier.get();
+            final DoubleSummaryStatistics statistics = new DoubleSummaryStatistics();
+            try {
+                for (PostingsEnum currentPostings : postings) {
+                    int frequency = 0;
+                    if (currentPostings != null) {
+                        // Only move forward: advancing to a target at or before the current
+                        // position is undefined, e.g. when this is called twice for one document.
+                        int position = currentPostings.docID();
+                        if (position < docId) {
+                            position = currentPostings.advance(docId);
+                        }
+                        if (position == docId) {
+                            frequency = currentPostings.freq();
+                        }
+                    }
+                    statistics.accept(frequency);
+                }
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+            return statistics;
+        }
+    }
+}