Skip to content

Commit

Permalink
Cap threshold frequency computation in Suggesters
Browse files Browse the repository at this point in the history
This change ensures that the term frequency threshold computed by the term/phrase
suggesters doesn't overflow the maximum allowed value (Integer.MAX_VALUE).

Closes #34282
Relates #34312
  • Loading branch information
jimczi committed Oct 19, 2018
1 parent 1a9c100 commit 147d80f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@
import java.util.List;
import java.util.Set;

import static java.lang.Math.log10;
import static java.lang.Math.min;
import static java.lang.Math.max;
import static java.lang.Math.round;
import static java.lang.Math.log10;

public final class DirectCandidateGenerator extends CandidateGenerator {

Expand Down Expand Up @@ -187,12 +188,14 @@ private double score(long frequency, double errorScore, long dictionarySize) {
return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1));
}

protected long thresholdFrequency(long termFrequency, long dictionarySize) {
// package protected for tests
long thresholdFrequency(long termFrequency, long dictionarySize) {
if (termFrequency > 0) {
return max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1));
return min(
max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)), Integer.MAX_VALUE
);
}
return 0;

}

public abstract static class TokenConsumer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,20 @@

package org.elasticsearch.search.suggest.phrase;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.spell.DirectSpellChecker;
import org.apache.lucene.search.spell.JaroWinklerDistance;
import org.apache.lucene.search.spell.LevensteinDistance;
import org.apache.lucene.search.spell.LuceneLevenshteinDistance;
import org.apache.lucene.search.spell.NGramDistance;
import org.apache.lucene.search.spell.SuggestMode;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContent;
Expand All @@ -33,7 +42,6 @@
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.search.suggest.phrase.PhraseSuggestionContext.DirectCandidateGenerator;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -143,7 +151,8 @@ public void testFromXContent() throws IOException {
}
}

public static void assertEqualGenerators(DirectCandidateGenerator first, DirectCandidateGenerator second) {
public static void assertEqualGenerators(PhraseSuggestionContext.DirectCandidateGenerator first,
PhraseSuggestionContext.DirectCandidateGenerator second) {
assertEquals(first.field(), second.field());
assertEquals(first.accuracy(), second.accuracy(), Float.MIN_VALUE);
assertEquals(first.maxTermFreq(), second.maxTermFreq(), Float.MIN_VALUE);
Expand Down Expand Up @@ -195,6 +204,56 @@ public void testIllegalXContent() throws IOException {
"[direct_generator] size doesn't support values of type: START_ARRAY");
}

public void testFrequencyThreshold() throws Exception {
try (Directory dir = newDirectory()) {
IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());
long numDocs = randomIntBetween(10, 20);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
if (i == 0) {
doc.add(new TextField("field", "fooz", Field.Store.NO));
} else {
doc.add(new TextField("field", "foo", Field.Store.NO));
}
writer.addDocument(doc);
}
try (IndexReader reader = DirectoryReader.open(writer)) {
writer.close();
DirectSpellChecker spellchecker = new DirectSpellChecker();
DirectCandidateGenerator generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_MORE_POPULAR,
reader, 0f, 10);
DirectCandidateGenerator.CandidateSet candidateSet =
generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY,
generator.createCandidate(new BytesRef("fooz"), false)));
assertThat(candidateSet.candidates.length, equalTo(1));
assertThat(candidateSet.candidates[0].frequency, equalTo(numDocs - 1));
// test that it doesn't overflow
assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE));
spellchecker = new DirectSpellChecker();
spellchecker.setThresholdFrequency(0.5f);
generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_MORE_POPULAR,
reader, 0f, 10);
candidateSet =
generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY,
generator.createCandidate(new BytesRef("fooz"), false)));
assertThat(candidateSet.candidates.length, equalTo(1));
assertThat(candidateSet.candidates[0].frequency, equalTo(numDocs - 1));
// test that it doesn't overflow
assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE));
spellchecker = new DirectSpellChecker();
spellchecker.setThresholdFrequency(0.5f);
generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_ALWAYS,
reader, 0f, 10);
candidateSet =
generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY,
generator.createCandidate(new BytesRef("fooz"), false)));
assertThat(candidateSet.candidates.length, equalTo(1));
// test that it doesn't overflow
assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE));
}
}
}

private void assertIllegalXContent(String directGenerator, Class<? extends Exception> exceptionClass, String exceptionMsg)
throws IOException {
XContentParser parser = createParser(JsonXContent.jsonXContent, directGenerator);
Expand Down

0 comments on commit 147d80f

Please sign in to comment.