Skip to content

Commit

Permalink
Code refactoring so that we use the _termStatistics reserved variable.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed May 17, 2024
1 parent 6dcb376 commit 8af06b0
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public boolean needs_score() {
return needsScores;
}

@Override
// This script implementation does not read per-term statistics, so the wrapping
// query can skip collecting the child query's terms for this script.
public boolean needs_termStatistics() {
return false;
}

@Override
public ScoreScript newInstance(final DocReader reader) throws IOException {
// Use DocReader to get the leaf context while transitioning to DocReader for Painless. DocReader for expressions should follow.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@ class org.elasticsearch.script.ScoreScript @no_import {
class org.elasticsearch.script.ScoreScript$Factory @no_import {
}

class org.apache.lucene.index.Term {
String field()
String text()
}

class org.elasticsearch.script.TermStatsReader {
Set terms()
long uniqueTermsCount()
long matchedTermsCount()
DoubleSummaryStatistics docFreq()
DoubleSummaryStatistics totalTermFreq()
DoubleSummaryStatistics termFreq()
DoubleSummaryStatistics termPositions()
}

static_import {
double saturation(double, double) from_class org.elasticsearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.elasticsearch.script.ScoreScriptUtils
Expand All @@ -31,7 +46,5 @@ static_import {
double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
java.util.DoubleSummaryStatistics documentFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$DocumentFrequencyStatistics
java.util.DoubleSummaryStatistics termFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$TermFrequencyStatistics
}

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -31,13 +30,10 @@
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScoreScript.ExplanationHolder;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.TermStatsReader;
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

Expand Down Expand Up @@ -93,10 +89,17 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return subQuery.createWeight(searcher, scoreMode, boost);
}
boolean needsScore = scriptBuilder.needs_score();
ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
boolean needsTermStatistics = scriptBuilder.needs_termStatistics();

ScoreMode subQueryScoreMode = needsScore || needsTermStatistics ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);

TermStatsReader termStatsReader = createTermStatsReader(searcher);
// We collect the different terms used in the child query.
final Set<Term> terms = new HashSet<>();

if (needsTermStatistics) {
this.visit(QueryVisitor.termCollector(terms));
}

return new Weight(this) {
@Override
Expand Down Expand Up @@ -173,9 +176,12 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}

private ScoreScript makeScoreScript(LeafReaderContext context) throws IOException {
final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context, termStatsReader));
final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context));
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
if (needsTermStatistics) {
scoreScript._setTerms(terms);
}
return scoreScript;
}

Expand All @@ -187,25 +193,6 @@ public boolean isCacheable(LeafReaderContext ctx) {
};
}

/**
 * Builds a {@link TermStatsReader} over the terms that appear in the wrapped child query.
 *
 * @param searcher the searcher used to build per-term states and statistics
 * @return a reader over the collected terms; its term-states map only contains terms
 *         with a non-null {@link TermStates} and a document frequency greater than zero
 * @throws IOException if building the term states fails
 */
private TermStatsReader createTermStatsReader(IndexSearcher searcher) throws IOException {
// We collect the different terms used in the child query.
final Set<Term> terms = new HashSet<>();
this.visit(QueryVisitor.termCollector(terms));

// For each terms build the term states.
final Map<Term, TermStates> termContexts = new HashMap<>();

for (Term term: terms) {
TermStates termStates = TermStates.build(searcher, term, true);
if (termStates != null && termStates.docFreq() > 0) {
termContexts.put(term, termStates);
// NOTE(review): the TermStatistics returned by this call is discarded —
// presumably invoked for its side effects (e.g. stats validation/warming);
// confirm whether the result should be retained.
searcher.termStatistics(term, termStates.docFreq(), termStates.totalTermFreq());
}
}

return new TermStatsReader(searcher, terms, termContexts);
}

@Override
public void visit(QueryVisitor visitor) {
// Highlighters must visit the child query to extract terms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.elasticsearch.script.ScriptCompiler;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptFactory;
import org.elasticsearch.script.TermStatsReader;
import org.elasticsearch.search.NestedDocuments;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.lookup.LeafFieldLookupProvider;
Expand Down Expand Up @@ -508,7 +509,8 @@ public void setLookupProviders(
new FieldDataContext(getFullyQualifiedIndex().getName(), searchLookup, this::sourcePath, fielddataOperation)
),
sourceProvider,
fieldLookupProvider
fieldLookupProvider,
(leafReaderContext, docIdSupplier) -> new TermStatsReader(searcher, docIdSupplier, leafReaderContext)
);
}

Expand Down
32 changes: 19 additions & 13 deletions server/src/main/java/org/elasticsearch/script/DocReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.Field;
import org.elasticsearch.search.lookup.Source;
Expand All @@ -26,26 +23,35 @@
* only reads doc-values.
*/
public interface DocReader {
/** New-style field access */
/**
* New-style field access
*/
Field<?> field(String fieldName);

/** New-style field iterator */
/**
* New-style field iterator
*/
Stream<Field<?>> fields(String fieldGlob);

/** Set the underlying docId */
/**
* Set the underlying docId
*/
void setDocument(int docID);

// Compatibility APIS
/** Old-style doc access for contexts that map some doc contents in params */

/**
* Old-style doc access for contexts that map some doc contents in params
*/
Map<String, Object> docAsMap();

/** Old-style doc['field'] access */
/**
* Old-style doc['field'] access
*/
Map<String, ScriptDocValues<?>> doc();

/** Helper for source access */
/**
* Helper for source access
*/
Supplier<Source> source();

Map<Term, TermStatistics> termStatistics();

Map<Term, PostingsEnum> postings(int flags);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
package org.elasticsearch.script;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.EmptyField;
import org.elasticsearch.script.field.Field;
Expand All @@ -20,7 +17,6 @@
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.lookup.Source;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
Expand All @@ -34,20 +30,14 @@ public class DocValuesDocReader implements DocReader, LeafReaderContextSupplier

// provide access to the leaf context reader for expressions
protected final LeafReaderContext leafReaderContext;
private final TermStatsReader termStatsReader;

/** A leaf lookup for the bound segment this proxy will operate on. */
protected LeafSearchLookup leafSearchLookup;

public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext) {
this(searchLookup, leafContext, null);
}

public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext, TermStatsReader termStatsReader) {
this.searchLookup = searchLookup;
this.leafReaderContext = leafContext;
this.leafSearchLookup = searchLookup.getLeafSearchLookup(leafReaderContext);
this.termStatsReader = termStatsReader;
}

@Override
Expand Down Expand Up @@ -90,22 +80,4 @@ public LeafReaderContext getLeafReaderContext() {
// Helper for source access: delegates to the leaf lookup bound to this segment.
public Supplier<Source> source() {
return leafSearchLookup.source();
}

/**
 * Maps each term collected by the {@code termStatsReader} to its {@link TermStatistics}.
 * NOTE(review): this dereferences {@code termStatsReader}, which may be null when this
 * reader was built via the two-argument constructor — confirm callers gate on
 * {@code needs_termStatistics()} before invoking this.
 */
@Override
public Map<Term, TermStatistics> termStatistics() {
Map<Term, TermStatistics> termStatistics = new HashMap<>();
for (Term term: termStatsReader.terms()) {
termStatistics.put(term, termStatsReader.termStatistics(term));
}
return termStatistics;
}

/**
 * Maps each term collected by the {@code termStatsReader} to a {@link PostingsEnum}
 * for the current leaf, requested with the given posting {@code flags}.
 * NOTE(review): like {@code termStatistics()}, this assumes a non-null
 * {@code termStatsReader} — confirm callers.
 */
@Override
public Map<Term, PostingsEnum> postings(int flags) {
Map<Term, PostingsEnum> postings = new HashMap<>();
for (Term term: termStatsReader.terms()) {
postings.put(term, termStatsReader.postings(leafReaderContext, term, flags));
}
return postings;
}
}
23 changes: 15 additions & 8 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
*/
package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.search.lookup.SearchLookup;
Expand All @@ -21,6 +20,7 @@
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.DoubleSupplier;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -85,6 +85,8 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private int shardId = -1;
private String indexName = null;

private final TermStatsReader termStatsReader;

public ScoreScript(Map<String, Object> params, SearchLookup searchLookup, DocReader docReader) {
// searchLookup parameter is ignored but part of the ScriptFactory contract. It is part of that contract because it's required
// for expressions. Expressions should eventually be transitioned to using DocReader.
Expand All @@ -95,11 +97,15 @@ public ScoreScript(Map<String, Object> params, SearchLookup searchLookup, DocRea
this.params = null;
;
this.docBase = 0;
this.termStatsReader = null;
} else {
params = new HashMap<>(params);
params.putAll(docReader.docAsMap());
this.params = new DynamicMap(params, PARAMS_FUNCTIONS);
this.docBase = ((DocValuesDocReader) docReader).getLeafReaderContext().docBase;
LeafReaderContext leafReaderContext = ((DocValuesDocReader) docReader).getLeafReaderContext();
this.docBase = leafReaderContext.docBase;
this.termStatsReader = searchLookup.getTermStatsReader(leafReaderContext, this::_getDocId);

}
}

Expand Down Expand Up @@ -192,13 +198,12 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}


public Map<Term, TermStatistics> _termStatistics() {
return this.docReader.termStatistics();
public void _setTerms(Set<Term> terms) {
this.termStatsReader._setTerms(terms);
}

public Map<Term, PostingsEnum> _postings(int flags) {
return this.docReader.postings(flags);
public TermStatsReader get_termStatistics() {
return termStatsReader;
}


Expand All @@ -210,6 +215,8 @@ public interface LeafFactory {
*/
boolean needs_score();

boolean needs_termStatistics();

ScoreScript newInstance(DocReader reader) throws IOException;
}

Expand Down
Loading

0 comments on commit 8af06b0

Please sign in to comment.