diff --git a/server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGenerator.java b/server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGenerator.java index 678b00aa13dca..d0a8cbf5e2dbb 100644 --- a/server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGenerator.java +++ b/server/src/main/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGenerator.java @@ -46,9 +46,10 @@ import java.util.List; import java.util.Set; -import static java.lang.Math.log10; +import static java.lang.Math.min; import static java.lang.Math.max; import static java.lang.Math.round; +import static java.lang.Math.log10; public final class DirectCandidateGenerator extends CandidateGenerator { @@ -187,12 +188,14 @@ private double score(long frequency, double errorScore, long dictionarySize) { return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1)); } - protected long thresholdFrequency(long termFrequency, long dictionarySize) { + // package protected for tests + long thresholdFrequency(long termFrequency, long dictionarySize) { if (termFrequency > 0) { - return max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)); + return min( + max(0, round(termFrequency * (log10(termFrequency - frequencyPlateau) * (1.0 / log10(LOG_BASE))) + 1)), Integer.MAX_VALUE + ); } return 0; - } public abstract static class TokenConsumer { diff --git a/server/src/test/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGeneratorTests.java b/server/src/test/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGeneratorTests.java index 4a07537b123c5..8ee3af043bb84 100644 --- a/server/src/test/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGeneratorTests.java +++ b/server/src/test/java/org/elasticsearch/search/suggest/phrase/DirectCandidateGeneratorTests.java @@ -19,11 +19,20 @@ package org.elasticsearch.search.suggest.phrase; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; import org.apache.lucene.search.spell.DirectSpellChecker; import org.apache.lucene.search.spell.JaroWinklerDistance; import org.apache.lucene.search.spell.LevensteinDistance; import org.apache.lucene.search.spell.LuceneLevenshteinDistance; import org.apache.lucene.search.spell.NGramDistance; +import org.apache.lucene.search.spell.SuggestMode; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContent; @@ -33,7 +42,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.search.suggest.phrase.PhraseSuggestionContext.DirectCandidateGenerator; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -143,7 +151,8 @@ public void testFromXContent() throws IOException { } } - public static void assertEqualGenerators(DirectCandidateGenerator first, DirectCandidateGenerator second) { + public static void assertEqualGenerators(PhraseSuggestionContext.DirectCandidateGenerator first, + PhraseSuggestionContext.DirectCandidateGenerator second) { assertEquals(first.field(), second.field()); assertEquals(first.accuracy(), second.accuracy(), Float.MIN_VALUE); assertEquals(first.maxTermFreq(), second.maxTermFreq(), Float.MIN_VALUE); @@ -195,6 +204,56 @@ public void testIllegalXContent() throws IOException { "[direct_generator] size doesn't support values of type: START_ARRAY"); } + public void testFrequencyThreshold() throws Exception { + try (Directory dir = newDirectory()) { + IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); + long numDocs = randomIntBetween(10, 20); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (i == 0) { + doc.add(new TextField("field", "fooz", Field.Store.NO)); + } else { + doc.add(new TextField("field", "foo", Field.Store.NO)); + } + writer.addDocument(doc); + } + try (IndexReader reader = DirectoryReader.open(writer)) { + writer.close(); + DirectSpellChecker spellchecker = new DirectSpellChecker(); + DirectCandidateGenerator generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_MORE_POPULAR, + reader, 0f, 10); + DirectCandidateGenerator.CandidateSet candidateSet = + generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY, + generator.createCandidate(new BytesRef("fooz"), false))); + assertThat(candidateSet.candidates.length, equalTo(1)); + assertThat(candidateSet.candidates[0].frequency, equalTo(numDocs - 1)); + // test that it doesn't overflow + assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE)); + spellchecker = new DirectSpellChecker(); + spellchecker.setThresholdFrequency(0.5f); + generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_MORE_POPULAR, + reader, 0f, 10); + candidateSet = + generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY, + generator.createCandidate(new BytesRef("fooz"), false))); + assertThat(candidateSet.candidates.length, equalTo(1)); + assertThat(candidateSet.candidates[0].frequency, equalTo(numDocs - 1)); + // test that it doesn't overflow + assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE)); + spellchecker = new DirectSpellChecker(); + spellchecker.setThresholdFrequency(0.5f); + generator = new DirectCandidateGenerator(spellchecker, "field", SuggestMode.SUGGEST_ALWAYS, + reader, 0f, 10); + candidateSet = + generator.drawCandidates(new DirectCandidateGenerator.CandidateSet(DirectCandidateGenerator.Candidate.EMPTY, + generator.createCandidate(new BytesRef("fooz"), false))); + assertThat(candidateSet.candidates.length, equalTo(1)); + // test that it doesn't overflow + assertThat(generator.thresholdFrequency(Integer.MAX_VALUE, -1), equalTo((long) Integer.MAX_VALUE)); + } + } + } + private void assertIllegalXContent(String directGenerator, Class exceptionClass, String exceptionMsg) throws IOException { XContentParser parser = createParser(JsonXContent.jsonXContent, directGenerator);