Skip to content

Commit

Permalink
New version of the script_score term stats helpers.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed May 14, 2024
1 parent 0b1d71e commit 6dcb376
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ static_import {
double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
java.util.DoubleSummaryStatistics documentFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$DocumentFrequencyStatistics
java.util.DoubleSummaryStatistics termFrequencyStatistics(org.elasticsearch.script.ScoreScript) bound_to org.elasticsearch.script.TermStatsScriptUtils$TermFrequencyStatistics
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.elasticsearch.common.lucene.search.function;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -29,10 +31,15 @@
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScoreScript.ExplanationHolder;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.TermStatsReader;
import org.elasticsearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
* A query that uses a script to compute documents' scores.
Expand Down Expand Up @@ -89,6 +96,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);

TermStatsReader termStatsReader = createTermStatsReader(searcher);

return new Weight(this) {
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
Expand Down Expand Up @@ -164,7 +173,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}

private ScoreScript makeScoreScript(LeafReaderContext context) throws IOException {
final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context));
final ScoreScript scoreScript = scriptBuilder.newInstance(new DocValuesDocReader(lookup, context, termStatsReader));
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
return scoreScript;
Expand All @@ -178,6 +187,25 @@ public boolean isCacheable(LeafReaderContext ctx) {
};
}

private TermStatsReader createTermStatsReader(IndexSearcher searcher) throws IOException {
// We collect the different terms used in the child query.
final Set<Term> terms = new HashSet<>();
this.visit(QueryVisitor.termCollector(terms));

// For each terms build the term states.
final Map<Term, TermStates> termContexts = new HashMap<>();

for (Term term: terms) {
TermStates termStates = TermStates.build(searcher, term, true);
if (termStates != null && termStates.docFreq() > 0) {
termContexts.put(term, termStates);
searcher.termStatistics(term, termStates.docFreq(), termStates.totalTermFreq());
}
}

return new TermStatsReader(searcher, terms, termContexts);
}

@Override
public void visit(QueryVisitor visitor) {
// Highlighters must visit the child query to extract terms
Expand Down
7 changes: 7 additions & 0 deletions server/src/main/java/org/elasticsearch/script/DocReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.Field;
import org.elasticsearch.search.lookup.Source;
Expand Down Expand Up @@ -41,4 +44,8 @@ public interface DocReader {

/** Helper for source access */
Supplier<Source> source();

Map<Term, TermStatistics> termStatistics();

Map<Term, PostingsEnum> postings(int flags);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
package org.elasticsearch.script;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.EmptyField;
import org.elasticsearch.script.field.Field;
Expand All @@ -17,6 +20,7 @@
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.lookup.Source;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Stream;
Expand All @@ -30,14 +34,20 @@ public class DocValuesDocReader implements DocReader, LeafReaderContextSupplier

// provide access to the leaf context reader for expressions
protected final LeafReaderContext leafReaderContext;
private final TermStatsReader termStatsReader;

/** A leaf lookup for the bound segment this proxy will operate on. */
protected LeafSearchLookup leafSearchLookup;

public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext) {
this(searchLookup, leafContext, null);
}

public DocValuesDocReader(SearchLookup searchLookup, LeafReaderContext leafContext, TermStatsReader termStatsReader) {
this.searchLookup = searchLookup;
this.leafReaderContext = leafContext;
this.leafSearchLookup = searchLookup.getLeafSearchLookup(leafReaderContext);
this.termStatsReader = termStatsReader;
}

@Override
Expand Down Expand Up @@ -80,4 +90,22 @@ public LeafReaderContext getLeafReaderContext() {
public Supplier<Source> source() {
return leafSearchLookup.source();
}

@Override
public Map<Term, TermStatistics> termStatistics() {
Map<Term, TermStatistics> termStatistics = new HashMap<>();
for (Term term: termStatsReader.terms()) {
termStatistics.put(term, termStatsReader.termStatistics(term));
}
return termStatistics;
}

@Override
public Map<Term, PostingsEnum> postings(int flags) {
Map<Term, PostingsEnum> postings = new HashMap<>();
for (Term term: termStatsReader.terms()) {
postings.put(term, termStatsReader.postings(leafReaderContext, term, flags));
}
return postings;
}
}
13 changes: 13 additions & 0 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
*/
package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.search.lookup.SearchLookup;
Expand Down Expand Up @@ -189,6 +192,16 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}


public Map<Term, TermStatistics> _termStatistics() {
return this.docReader.termStatistics();
}

public Map<Term, PostingsEnum> _postings(int flags) {
return this.docReader.postings(flags);
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {

Expand Down
74 changes: 74 additions & 0 deletions server/src/main/java/org/elasticsearch/script/TermStatsReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.script;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermStatistics;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Set;

public class TermStatsReader {
private final IndexSearcher searcher;
private final Set<Term> terms;
private final Map<Term, TermStates> termContexts;

public TermStatsReader(IndexSearcher searcher, Set<Term> terms, Map<Term, TermStates> termContexts) {
this.searcher = searcher;
this.terms = terms;
this.termContexts = termContexts;
}

public Set<Term> terms() {
return terms;
}

public TermStatistics termStatistics(Term term) {
try {
if (termContexts.containsKey(term) == false) {
return searcher.termStatistics(term, 0, 0);
}

return searcher.termStatistics(term, termContexts.get(term).docFreq(), termContexts.get(term).totalTermFreq());
} catch (IllegalArgumentException e) {
return null;
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

public PostingsEnum postings(LeafReaderContext leafReaderContext, Term term, int flags) {
if (termContexts.containsKey(term) == false) {
return null;
}

try {
TermStates termContext = termContexts.get(term);
TermState state = termContext.get(leafReaderContext);
if (state == null || termContext.docFreq() == 0) {
return null;
}

TermsEnum termsEnum = leafReaderContext.reader().terms(term.field()).iterator();
termsEnum.seekExact(term.bytes(), state);
return termsEnum.postings(null, flags);

} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.script;

import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.search.TermStatistics;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.DoubleSummaryStatistics;
import java.util.function.Supplier;

;

public class TermStatsScriptUtils {

public static final class DocumentFrequencyStatistics {
private final Collection<TermStatistics> termsStatistics;

public DocumentFrequencyStatistics(ScoreScript scoreScript) {
this.termsStatistics = scoreScript._termStatistics().values();
}

public DoubleSummaryStatistics documentFrequencyStatistics() {
return termsStatistics.stream().mapToDouble(termStatistics -> termStatistics == null ? 0 : termStatistics.docFreq()).summaryStatistics();
}
}

public static final class TermFrequencyStatistics {
private final Collection<PostingsEnum> postings;
private final Supplier<Integer> docIdSupplier;

public TermFrequencyStatistics(ScoreScript scoreScript) {
postings = scoreScript._postings(PostingsEnum.FREQS).values();
docIdSupplier = scoreScript::_getDocId;
}

public DoubleSummaryStatistics termFrequencyStatistics() {
return postings.stream().mapToDouble(
currentPostings -> {
try {
int docId = docIdSupplier.get();
if (currentPostings == null || currentPostings.advance(docId) != docId) {
return 0;
}
return currentPostings.freq();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
).summaryStatistics();
}
}
}

0 comments on commit 6dcb376

Please sign in to comment.