Skip to content

Commit

Permalink
Fix threshold frequency computation in Suggesters (#34312)
Browse files Browse the repository at this point in the history
The `term` and `phrase` suggesters have different options to filter candidates
based on their frequencies. The `popular` mode for instance filters candidate
terms that occur in less docs than the original term. However when we compute this threshold
we use the total term frequency of a term instead of the document frequency. This is not inline
with the actual filtering which is always based on the document frequency. This change fixes
this discrepancy and clarifies the meaning of the different frequencies in use in the suggesters.
It also ensures that the threshold doesn't overflow the maximum allowed value (Integer.MAX_VALUE).

Closes #34282
  • Loading branch information
jimczi authored and kcm committed Oct 30, 2018
1 parent 86ae7dd commit ea62c16
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 114 deletions.
5 changes: 0 additions & 5 deletions buildSrc/src/main/resources/checkstyle_suppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,6 @@
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]ContextMapping.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]GeoContextMapping.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]context[/\\]GeoQueryContext.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]CandidateScorer.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]NoisyChannelSpellChecker.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]WordScorer.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]RestoreService.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]SnapshotShardFailure.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]main[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]SnapshotShardsService.java" checks="LineLength" />
Expand Down Expand Up @@ -564,7 +561,6 @@
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]CorruptedTranslogIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]IndexStoreTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]store[/\\]StoreTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]suggest[/\\]stats[/\\]SuggestStatsIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]index[/\\]translog[/\\]TranslogTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]indexing[/\\]IndexActionIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]indexlifecycle[/\\]IndexLifecycleActionIT.java" checks="LineLength" />
Expand Down Expand Up @@ -644,7 +640,6 @@
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]ContextCompletionSuggestSearchIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]CategoryContextMappingTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]completion[/\\]GeoContextMappingTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]search[/\\]suggest[/\\]phrase[/\\]NoisyChannelSpellCheckerTests.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]similarity[/\\]SimilarityIT.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]AbstractSnapshotIntegTestCase.java" checks="LineLength" />
<suppress files="server[/\\]src[/\\]test[/\\]java[/\\]org[/\\]elasticsearch[/\\]snapshots[/\\]DedicatedClusterSnapshotRestoreIT.java" checks="LineLength" />
Expand Down
7 changes: 7 additions & 0 deletions docs/reference/migration/migrate_7_0/search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ removed.
* `levenstein` - replaced by `levenshtein`
* `jarowinkler` - replaced by `jaro_winkler`

[float]
==== `popular` mode for Suggesters

The `popular` mode for Suggesters (`term` and `phrase`) now uses the doc frequency
(instead of the sum of the doc frequency) of the input terms to compute the frequency
threshold for candidate suggestions.

[float]
==== Limiting the number of terms that can be used in a Terms Query request

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.search.suggest.phrase;

import org.apache.lucene.codecs.TermStats;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidate;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.CandidateSet;
Expand All @@ -29,22 +30,22 @@ public abstract class CandidateGenerator {

public abstract boolean isKnownWord(BytesRef term) throws IOException;

public abstract long frequency(BytesRef term) throws IOException;
public abstract TermStats termStats(BytesRef term) throws IOException;

public CandidateSet drawCandidates(BytesRef term) throws IOException {
CandidateSet set = new CandidateSet(Candidate.EMPTY, createCandidate(term, true));
return drawCandidates(set);
}

public Candidate createCandidate(BytesRef term, boolean userInput) throws IOException {
return createCandidate(term, frequency(term), 1.0, userInput);
return createCandidate(term, termStats(term), 1.0, userInput);
}
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return createCandidate(term, frequency, channelScore, false);
public Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore) throws IOException {
return createCandidate(term, termStats, channelScore, false);
}

public abstract Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException;
public abstract Candidate createCandidate(BytesRef term, TermStats termStats,
double channelScore, boolean userInput) throws IOException;

public abstract CandidateSet drawCandidates(CandidateSet set) throws IOException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,24 @@ public void findCandidates(CandidateSet[] candidates, Candidate[] path, int ord,
} else {
if (numMissspellingsLeft > 0) {
path[ord] = current.originalTerm;
findCandidates(candidates, path, ord + 1, numMissspellingsLeft, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, numMissspellingsLeft, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
for (int i = 0; i < current.candidates.length; i++) {
path[ord] = current.candidates[i];
findCandidates(candidates, path, ord + 1, numMissspellingsLeft - 1, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, numMissspellingsLeft - 1, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
}
} else {
path[ord] = current.originalTerm;
findCandidates(candidates, path, ord + 1, 0, corrections, cutoffScore, pathScore + scorer.score(path, candidates, ord, gramSize));
findCandidates(candidates, path, ord + 1, 0, corrections, cutoffScore,
pathScore + scorer.score(path, candidates, ord, gramSize));
}
}

}

private void updateTop(CandidateSet[] candidates, Candidate[] path, PriorityQueue<Correction> corrections, double cutoffScore, double score)
throws IOException {
private void updateTop(CandidateSet[] candidates, Candidate[] path,
PriorityQueue<Correction> corrections, double cutoffScore, double score) throws IOException {
score = Math.exp(score);
assert Math.abs(score - score(path, candidates)) < 0.00001 : "cur_score=" + score + ", path_score=" + score(path,candidates);
if (score > cutoffScore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.codecs.TermStats;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
Expand All @@ -48,6 +49,7 @@

import static java.lang.Math.log10;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.round;

public final class DirectCandidateGenerator extends CandidateGenerator {
Expand All @@ -57,20 +59,20 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
private final SuggestMode suggestMode;
private final TermsEnum termsEnum;
private final IndexReader reader;
private final long dictSize;
private final long sumTotalTermFreq;
private static final double LOG_BASE = 5;
private final long frequencyPlateau;
private final Analyzer preFilter;
private final Analyzer postFilter;
private final double nonErrorLikelihood;
private final boolean useTotalTermFrequency;
private final CharsRefBuilder spare = new CharsRefBuilder();
private final BytesRefBuilder byteSpare = new BytesRefBuilder();
private final int numCandidates;

public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader,
double nonErrorLikelihood, int numCandidates) throws IOException {
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood, numCandidates, null, null, MultiFields.getTerms(reader, field));
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood,
numCandidates, null, null, MultiFields.getTerms(reader, field));
}

public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader,
Expand All @@ -83,14 +85,12 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
this.numCandidates = numCandidates;
this.suggestMode = suggestMode;
this.reader = reader;
final long dictSize = terms.getSumTotalTermFreq();
this.useTotalTermFrequency = dictSize != -1;
this.dictSize = dictSize == -1 ? reader.maxDoc() : dictSize;
this.sumTotalTermFreq = terms.getSumTotalTermFreq() == -1 ? reader.maxDoc() : terms.getSumTotalTermFreq();
this.preFilter = preFilter;
this.postFilter = postFilter;
this.nonErrorLikelihood = nonErrorLikelihood;
float thresholdFrequency = spellchecker.getThresholdFrequency();
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int)(dictSize * thresholdFrequency);
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int) (reader.maxDoc() * thresholdFrequency);
termsEnum = terms.iterator();
}

Expand All @@ -99,24 +99,29 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
*/
@Override
public boolean isKnownWord(BytesRef term) throws IOException {
return frequency(term) > 0;
return termStats(term).docFreq > 0;
}

/* (non-Javadoc)
* @see org.elasticsearch.search.suggest.phrase.CandidateGenerator#frequency(org.apache.lucene.util.BytesRef)
*/
@Override
public long frequency(BytesRef term) throws IOException {
public TermStats termStats(BytesRef term) throws IOException {
term = preFilter(term, spare, byteSpare);
return internalFrequency(term);
return internalTermStats(term);
}


public long internalFrequency(BytesRef term) throws IOException {
public TermStats internalTermStats(BytesRef term) throws IOException {
if (termsEnum.seekExact(term)) {
return useTotalTermFrequency ? termsEnum.totalTermFreq() : termsEnum.docFreq();
return new TermStats(termsEnum.docFreq(),
/**
* We use the {@link TermsEnum#docFreq()} for fields that don't
* record the {@link TermsEnum#totalTermFreq()}.
*/
termsEnum.totalTermFreq() == -1 ? termsEnum.docFreq() : termsEnum.totalTermFreq());
}
return 0;
return new TermStats(0, 0);
}

public String getField() {
Expand All @@ -127,15 +132,28 @@ public String getField() {
public CandidateSet drawCandidates(CandidateSet set) throws IOException {
Candidate original = set.originalTerm;
BytesRef term = preFilter(original.term, spare, byteSpare);
final long frequency = original.frequency;
spellchecker.setThresholdFrequency(this.suggestMode == SuggestMode.SUGGEST_ALWAYS ? 0 : thresholdFrequency(frequency, dictSize));
if (suggestMode != SuggestMode.SUGGEST_ALWAYS) {
/**
* We use the {@link TermStats#docFreq} to compute the frequency threshold
* because that's what {@link DirectSpellChecker#suggestSimilar} expects
* when filtering terms.
*/
int threshold = thresholdTermFrequency(original.termStats.docFreq);
if (threshold == Integer.MAX_VALUE) {
// the threshold is the max possible frequency so we can skip the search
return set;
}
spellchecker.setThresholdFrequency(threshold);
}

SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
List<Candidate> candidates = new ArrayList<>(suggestSimilar.length);
for (int i = 0; i < suggestSimilar.length; i++) {
SuggestWord suggestWord = suggestSimilar[i];
BytesRef candidate = new BytesRef(suggestWord.string);
postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score,
score(suggestWord.freq, suggestWord.score, dictSize), false), spare, byteSpare, candidates);
TermStats termStats = internalTermStats(candidate);
postFilter(new Candidate(candidate, termStats,
suggestWord.score, score(termStats, suggestWord.score, sumTotalTermFreq), false), spare, byteSpare, candidates);
}
set.addCandidates(candidates);
return set;
Expand Down Expand Up @@ -171,28 +189,30 @@ public void nextToken() throws IOException {
BytesRef term = result.toBytesRef();
// We should not use frequency(term) here because it will analyze the term again
// If preFilter and postFilter are the same analyzer it would fail.
long freq = internalFrequency(term);
candidates.add(new Candidate(result.toBytesRef(), freq, candidate.stringDistance,
score(candidate.frequency, candidate.stringDistance, dictSize), false));
TermStats termStats = internalTermStats(term);
candidates.add(new Candidate(result.toBytesRef(), termStats, candidate.stringDistance,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq), false));
} else {
candidates.add(new Candidate(result.toBytesRef(), candidate.frequency, nonErrorLikelihood,
score(candidate.frequency, candidate.stringDistance, dictSize), false));
candidates.add(new Candidate(result.toBytesRef(), candidate.termStats, nonErrorLikelihood,
score(candidate.termStats, candidate.stringDistance, sumTotalTermFreq), false));
}
}
}, spare);
}
}

private double score(long frequency, double errorScore, long dictionarySize) {
return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1));
private double score(TermStats termStats, double errorScore, long dictionarySize) {
return errorScore * (((double)termStats.totalTermFreq + 1) / ((double)dictionarySize +1));
}

protected long thresholdFrequency(long termFrequency, long dictionarySize) {
if (termFrequency > 0) {
return max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1));
// package protected for test
int thresholdTermFrequency(int docFreq) {
if (docFreq > 0) {
return (int) min(
max(0, round(docFreq * (log10(docFreq - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)), Integer.MAX_VALUE
);
}
return 0;

}

public abstract static class TokenConsumer {
Expand Down Expand Up @@ -249,12 +269,12 @@ public static class Candidate implements Comparable<Candidate> {
public static final Candidate[] EMPTY = new Candidate[0];
public final BytesRef term;
public final double stringDistance;
public final long frequency;
public final TermStats termStats;
public final double score;
public final boolean userInput;

public Candidate(BytesRef term, long frequency, double stringDistance, double score, boolean userInput) {
this.frequency = frequency;
public Candidate(BytesRef term, TermStats termStats, double stringDistance, double score, boolean userInput) {
this.termStats = termStats;
this.term = term;
this.stringDistance = stringDistance;
this.score = score;
Expand All @@ -266,7 +286,7 @@ public String toString() {
return "Candidate [term=" + term.utf8ToString()
+ ", stringDistance=" + stringDistance
+ ", score=" + score
+ ", frequency=" + frequency
+ ", termStats=" + termStats
+ (userInput ? ", userInput" : "") + "]";
}

Expand Down Expand Up @@ -305,8 +325,8 @@ public int compareTo(Candidate other) {
}

@Override
public Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize), userInput);
public Candidate createCandidate(BytesRef term, TermStats termStats, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, termStats, channelScore, score(termStats, channelScore, sumTotalTermFreq), userInput);
}

public static int analyze(Analyzer analyzer, BytesRef toAnalyze, String field, TokenConsumer consumer, CharsRefBuilder spare)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ protected double scoreUnigram(Candidate word) throws IOException {
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
join(separator, spare, w_1.term, word.term);
return (alpha + frequency(spare.get())) / (w_1.frequency + alpha * numTerms);
return (alpha + frequency(spare.get())) / (w_1.termStats.totalTermFreq + alpha * numTerms);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
if (count < 1) {
return unigramLambda * scoreUnigram(word);
}
return bigramLambda * (count / (0.5d + w_1.frequency)) + unigramLambda * scoreUnigram(word);
return bigramLambda * (count / (0.5d + w_1.termStats.totalTermFreq)) + unigramLambda * scoreUnigram(word);
}

@Override
Expand Down
Loading

0 comments on commit ea62c16

Please sign in to comment.