Fix threshold frequency computation in Suggesters
The `term` and `phrase` suggesters have different options to filter candidates
based on their frequencies. The `popular` mode, for instance, filters out candidate
terms that occur in fewer documents than the original term. However, when we compute this threshold
we use the total term frequency of a term instead of its document frequency. This is not in line
with the actual filtering, which is always based on the document frequency. This change fixes
the discrepancy and clarifies the meaning of the different frequencies used in the suggesters.
It also ensures that the threshold does not overflow the maximum allowed value (Integer.MAX_VALUE).

Closes elastic#34282
jimczi committed Oct 4, 2018
1 parent 09aaed4 commit f59c1dd
Showing 14 changed files with 248 additions and 120 deletions.
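
The fix boils down to two things: derive the "popular" threshold from the original term's document frequency rather than its total term frequency, and clamp the result so it cannot overflow an int. The snippet below is a minimal, self-contained sketch of that computation, not the actual Elasticsearch class: LOG_BASE and frequencyPlateau mirror the names used in DirectCandidateGenerator further down, while the class name ThresholdSketch and the main method are purely illustrative.

import static java.lang.Math.log10;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.round;

public class ThresholdSketch {

    // Base of the logarithm used to scale the threshold, as in DirectCandidateGenerator.
    private static final double LOG_BASE = 5;

    /**
     * Minimum document frequency a candidate must reach to be considered more
     * "popular" than the original term. The input is the original term's
     * document frequency (not its total term frequency), and the result is
     * clamped so it never overflows Integer.MAX_VALUE.
     */
    static int thresholdTermFrequency(int docFreq, long frequencyPlateau) {
        if (docFreq > 0) {
            return (int) min(
                max(0, round(docFreq * (log10(docFreq - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)),
                Integer.MAX_VALUE);
        }
        return 0;
    }

    public static void main(String[] args) {
        // A moderately frequent original term yields a small threshold...
        System.out.println(thresholdTermFrequency(10, 0));
        // ...while a huge document frequency is clamped instead of overflowing.
        System.out.println(thresholdTermFrequency(2_000_000_000, 0));
    }
}

As in drawCandidates below, a threshold equal to Integer.MAX_VALUE means no candidate can occur in more documents than the original term, so the suggester can skip the candidate lookup entirely.
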
5 changes: 0 additions & 5 deletions buildSrc/src/main/resources/checkstyle_suppressions.xml
@@ -398,9 +398,6 @@
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]ContextMapping.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]GeoContextMapping.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]GeoQueryContext.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]CandidateScorer.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]NoisyChannelSpellChecker.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]WordScorer.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]RestoreService.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]SnapshotShardFailure.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]SnapshotShardsService.java" checks="LineLength" />
@@ -601,7 +598,6 @@
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]DirectoryUtilsTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]IndexStoreTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]StoreTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]suggest[/\\]stats[/\\]SuggestStatsIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]translog[/\\]TranslogTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]indexing[/\\]IndexActionIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]indexlifecycle[/\\]IndexLifecycleActionIT.java" checks="LineLength" />
@@ -688,7 +684,6 @@
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]CustomSuggester.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]CategoryContextMappingTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]GeoContextMappingTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]NoisyChannelSpellCheckerTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]similarity[/\\]SimilarityIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]AbstractSnapshotIntegTestCase.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]DedicatedClusterSnapshotRestoreIT.java" checks="LineLength" />
server/src/main/java/org/elasticsearch/search/suggest/phrase/CandidateGenerator.java
@@ -18,6 +18,7 @@
*/
package org.elasticsearch.search.suggest.phrase;

import org.apache.lucene.codecs.TermStats;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidate;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.CandidateSet;
@@ -29,22 +30,21 @@ public abstract class CandidateGenerator {

public abstract boolean isKnownWord(BytesRef term) throws IOException;

public abstract long frequency(BytesRef term) throws IOException;
public abstract TermStats termStats(BytesRef term) throws IOException;

public CandidateSet drawCandidates(BytesRef term) throws IOException {
CandidateSet set = new CandidateSet(Candidate.EMPTY, createCandidate(term, true));
return drawCandidates(set);
}

public Candidate createCandidate(BytesRef term, boolean userInput) throws IOException {
return createCandidate(term, frequency(term), 1.0, userInput);
return createCandidate(term, termStats(term), 1.0, userInput);
}
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return createCandidate(term, frequency, channelScore, false);
public Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore) throws IOException {
return createCandidate(term, termStats, channelScore, false);
}

public abstract Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException;
public abstract Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore, boolean userInput) throws IOException;

public abstract CandidateSet drawCandidates(CandidateSet set) throws IOException;

}
server/src/main/java/org/elasticsearch/search/suggest/phrase/CandidateScorer.java
@@ -77,21 +77,24 @@ public void findCandidates(CandidateSet[] candidates, Candidate[] path, int ord,
} else {
if (numMissspellingsLeft > 0) {
path[ord] = current.originalTerm;
findCandidates(candidates, path, ord + 1, numMissspellingsLeft, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, numMissspellingsLeft, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
for (int i = 0; i < current.candidates.length; i++) {
path[ord] = current.candidates[i];
findCandidates(candidates, path, ord + 1, numMissspellingsLeft - 1, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, numMissspellingsLeft - 1, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
}
} else {
path[ord] = current.originalTerm;
findCandidates(candidates, path, ord + 1, 0, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, 0, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
}
}

}

private void updateTop(CandidateSet[] candidates, Candidate[] path, PriorityQueue<Correction> corrections, double cutoffScore, double score)
throws IOException {
private void updateTop(CandidateSet[] candidates, Candidate[] path,
PriorityQueue<Correction> corrections, double cutoffScore, double score) throws IOException {
score = Math.exp(score);
assert Math.abs(score - score(path, candidates)) < 0.00001 : "cur_score=" + score + ", path_score=" + score(path,candidates);
if (score > cutoffScore) {
server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGenerator.java
@@ -23,6 +23,7 @@
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.codecs.TermStats;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
@@ -48,6 +49,7 @@

import static java.lang.Math.log10;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.round;

public final class DirectCandidateGenerator extends CandidateGenerator {
@@ -57,20 +59,20 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
private final SuggestMode suggestMode;
private final TermsEnum termsEnum;
private final IndexReader reader;
private final long dictSize;
private final long sumTotalTermFreq;
private static final double LOG_BASE = 5;
private final long frequencyPlateau;
private final Analyzer preFilter;
private final Analyzer postFilter;
private final double nonErrorLikelihood;
private final boolean useTotalTermFrequency;
private final CharsRefBuilder spare = new CharsRefBuilder();
private final BytesRefBuilder byteSpare = new BytesRefBuilder();
private final int numCandidates;

public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader,
double nonErrorLikelihood, int numCandidates) throws IOException {
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood, numCandidates, null, null, MultiFields.getTerms(reader, field));
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood,
numCandidates, null, null, MultiFields.getTerms(reader, field));
}

public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader,
@@ -83,14 +85,12 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
this.numCandidates = numCandidates;
this.suggestMode = suggestMode;
this.reader = reader;
final long dictSize = terms.getSumTotalTermFreq();
this.useTotalTermFrequency = dictSize != -1;
this.dictSize = dictSize == -1 ? reader.maxDoc() : dictSize;
this.sumTotalTermFreq = terms.getSumTotalTermFreq() == -1 ? reader.maxDoc() : terms.getSumTotalTermFreq();
this.preFilter = preFilter;
this.postFilter = postFilter;
this.nonErrorLikelihood = nonErrorLikelihood;
float thresholdFrequency = spellchecker.getThresholdFrequency();
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int)(dictSize * thresholdFrequency);
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int) (reader.maxDoc() * thresholdFrequency);
termsEnum = terms.iterator();
}

@@ -99,24 +99,29 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
*/
@Override
public boolean isKnownWord(BytesRef term) throws IOException {
return frequency(term) > 0;
return termStats(term).docFreq > 0;
}

/* (non-Javadoc)
* @see org.elasticsearch.search.suggest.phrase.CandidateGenerator#frequency(org.apache.lucene.util.BytesRef)
*/
@Override
public long frequency(BytesRef term) throws IOException {
public TermStats termStats(BytesRef term) throws IOException {
term = preFilter(term, spare, byteSpare);
return internalFrequency(term);
return internalTermStats(term);
}


public long internalFrequency(BytesRef term) throws IOException {
public TermStats internalTermStats(BytesRef term) throws IOException {
if (termsEnum.seekExact(term)) {
return useTotalTermFrequency ? termsEnum.totalTermFreq() : termsEnum.docFreq();
return new TermStats(termsEnum.docFreq(),
/**
* We use the {@link TermsEnum#docFreq()} for fields that don't
* record the {@link TermsEnum#totalTermFreq()}.
*/
termsEnum.totalTermFreq() == -1 ? termsEnum.docFreq() : termsEnum.totalTermFreq());
}
return 0;
return new TermStats(0, 0);
}

public String getField() {
@@ -127,15 +132,28 @@ public String getField() {
public CandidateSet drawCandidates(CandidateSet set) throws IOException {
Candidate original = set.originalTerm;
BytesRef term = preFilter(original.term, spare, byteSpare);
final long frequency = original.frequency;
spellchecker.setThresholdFrequency(this.suggestMode == SuggestMode.SUGGEST_ALWAYS ? 0 : thresholdFrequency(frequency, dictSize));
if (suggestMode != SuggestMode.SUGGEST_ALWAYS) {
/**
* We use the {@link TermStats#docFreq} to compute the frequency threshold
* because that's what {@link DirectSpellChecker#suggestSimilar} expects
* when filtering terms.
*/
int threshold = thresholdTermFrequency(original.termStats.docFreq);
if (threshold == Integer.MAX_VALUE) {
// the threshold is the max possible frequency so we can skip the search
return set;
}
spellchecker.setThresholdFrequency(threshold);
}

SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
List<Candidate> candidates = new ArrayList<>(suggestSimilar.length);
for (int i = 0; i < suggestSimilar.length; i++) {
SuggestWord suggestWord = suggestSimilar[i];
BytesRef candidate = new BytesRef(suggestWord.string);
postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score,
score(suggestWord.freq, suggestWord.score, dictSize), false), spare, byteSpare, candidates);
TermStats termStats = internalTermStats(candidate);
postFilter(new Candidate(candidate, termStats,
suggestWord.score, score(termStats, suggestWord.score, sumTotalTermFreq), false), spare, byteSpare, candidates);
}
set.addCandidates(candidates);
return set;
Expand Down Expand Up @@ -171,28 +189,29 @@ public void nextToken() throws IOException {
BytesRef term = result.toBytesRef();
// We should not use frequency(term) here because it will analyze the term again
// If preFilter and postFilter are the same analyzer it would fail.
long freq = internalFrequency(term);
candidates.add(new Candidate(result.toBytesRef(), freq, candidate.stringDistance,
score(candidate.frequency, candidate.stringDistance, dictSize), false));
TermStats termStats = internalTermStats(term);
candidates.add(new Candidate(result.toBytesRef(), termStats, candidate.stringDistance,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq), false));
} else {
candidates.add(new Candidate(result.toBytesRef(), candidate.frequency, nonErrorLikelihood,
score(candidate.frequency, candidate.stringDistance, dictSize), false));
candidates.add(new Candidate(result.toBytesRef(), candidate.termStats, nonErrorLikelihood,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq), false));
}
}
}, spare);
}
}

private double score(long frequency, double errorScore, long dictionarySize) {
return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1));
private double score(TermStats termStats, double errorScore, long dictionarySize) {
return errorScore * (((double)termStats.totalTermFreq + 1) / ((double)dictionarySize +1));
}

protected long thresholdFrequency(long termFrequency, long dictionarySize) {
if (termFrequency > 0) {
return max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1));
protected int thresholdTermFrequency(int docFreq) {
if (docFreq > 0) {
return (int) min(
max(0, round(docFreq * (log10(docFreq - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)), Integer.MAX_VALUE
);
}
return 0;

}

public abstract static class TokenConsumer {
Expand Down Expand Up @@ -249,12 +268,12 @@ public static class Candidate implements Comparable<Candidate> {
public static final Candidate[] EMPTY = new Candidate[0];
public final BytesRef term;
public final double stringDistance;
public final long frequency;
public final TermStats termStats;
public final double score;
public final boolean userInput;

public Candidate(BytesRef term, long frequency, double stringDistance, double score, boolean userInput) {
this.frequency = frequency;
public Candidate(BytesRef term, TermStats termStats, double stringDistance, double score, boolean userInput) {
this.termStats = termStats;
this.term = term;
this.stringDistance = stringDistance;
this.score = score;
Expand All @@ -266,7 +285,7 @@ public String toString() {
return "Candidate [term=" + term.utf8ToString()
+ ", stringDistance=" + stringDistance
+ ", score=" + score
+ ", frequency=" + frequency
+ ", termStats=" + termStats
+ (userInput ? ", userInput" : "") + "]";
}

@@ -305,8 +324,8 @@ public int compareTo(Candidate other) {
}

@Override
public Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize), userInput);
public Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, termStats, channelScore, score(termStats, channelScore, sumTotalTermFreq), userInput);
}

public static int analyze(Analyzer analyzer, BytesRef toAnalyze, String field, TokenConsumer consumer, CharsRefBuilder spare)
server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGeneratorBuilder.java
@@ -187,9 +187,9 @@ Integer size() {
* possible values:
* <ol>
* <li><code>score</code> - Sort should first be based on score, then
* document frequency and then the term itself.
* <li><code>frequency</code> - Sort should first be based on document
* frequency, then score and then the term itself.
* document totalTermFrequency and then the term itself.
* <li><code>totalTermFrequency</code> - Sort should first be based on document
* totalTermFrequency, then score and then the term itself.
* </ol>
* <p>
* What the score is depends on the suggester being used.
@@ -268,8 +268,8 @@ Integer maxInspections() {
* frequencies. If an value higher than 1 is specified then fractional
* can not be specified. Defaults to {@code 0.01}.
* <p>
* This can be used to exclude high frequency terms from being
* suggested. High frequency terms are usually spelled correctly on top
* This can be used to exclude high totalTermFrequency terms from being
* suggested. High totalTermFrequency terms are usually spelled correctly on top
* of this this also improves the suggest performance.
*/
public DirectCandidateGeneratorBuilder maxTermFreq(float maxTermFreq) {
@@ -313,7 +313,7 @@ Integer minWordLength() {
* Sets a minimal threshold in number of documents a suggested term
* should appear in. This can be specified as an absolute number or as a
* relative percentage of number of documents. This can improve quality
* by only suggesting high frequency terms. Defaults to 0f and is not
* by only suggesting high totalTermFrequency terms. Defaults to 0f and is not
* enabled. If a value higher than 1 is specified then the number cannot
* be fractional.
*/
