Skip to content

Commit

Permalink
Refactor significance heuristic tests for easier extensibility (#75264)
Browse files Browse the repository at this point in the history
The significant terms heuristic tests do not lend themselves well to adding new heuristics.

This commit extracts common code and builds an abstract significant heuristic test class.

This way new heuristics get the common suite of tests by extending a test class.
  • Loading branch information
benwtrent authored Jul 13, 2021
1 parent 4410d76 commit 379fad1
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 186 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,35 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.SignificantTermsAggregatorFactory.ExecutionMode;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.ChiSquare;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.GND;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.JLHScore;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.MutualInformation;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.PercentageScore;
import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.aggregations.support.ValuesSourceType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.elasticsearch.search.aggregations.AggregationBuilders.significantTerms;
import static org.hamcrest.Matchers.equalTo;

public class SignificantTermsAggregatorTests extends AggregatorTestCase {

/**
 * Builds one instance of every significance heuristic listed below and returns one of them,
 * picked at random. The boolean constructor flags of the heuristics that take them are
 * randomized as well.
 *
 * @return a randomly selected heuristic instance
 */
static SignificanceHeuristic getRandomSignificanceheuristic() {
    List<SignificanceHeuristic> heuristics = new ArrayList<>();
    heuristics.add(new JLHScore());
    heuristics.add(new MutualInformation(randomBoolean(), randomBoolean()));
    heuristics.add(new GND(randomBoolean()));
    heuristics.add(new ChiSquare(randomBoolean(), randomBoolean()));
    heuristics.add(new PercentageScore());
    // Derive the upper bound from the list itself so that a heuristic added above can never be
    // silently left unselectable. The bound is inclusive, matching the previous randomInt(4)
    // for this five-element list.
    return heuristics.get(randomInt(heuristics.size() - 1));
}

@Override
protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) {
return new SignificantTermsAggregationBuilder("foo").field(fieldName);
Expand All @@ -75,10 +93,7 @@ protected List<String> unsupportedMappedFieldTypes() {
);
}

/**
* Uses the significant terms aggregation to find the keywords in text fields
*/
public void testSignificance() throws IOException {
public void testSignificance(SignificanceHeuristic heuristic) throws IOException {
TextFieldType textFieldType = new TextFieldType("text");
textFieldType.setFielddata(true);

Expand Down Expand Up @@ -135,7 +150,7 @@ public void testSignificance() throws IOException {
String evenStrings[] = new String[] {"even", "regular"};

sigAgg.includeExclude(new IncludeExclude(oddStrings, evenStrings));
sigAgg.significanceHeuristic(SignificanceHeuristicTests.getRandomSignificanceheuristic());
sigAgg.significanceHeuristic(heuristic);
terms = searchAndReduce(searcher, new TermQuery(new Term("text", "odd")), sigAgg, textFieldType);
assertThat(terms.getSubsetSize(), equalTo(5L));
assertEquals(1, terms.getBuckets().size());
Expand All @@ -159,6 +174,13 @@ public void testSignificance() throws IOException {
}
}

/**
 * Finds the keywords in text fields with the significant terms aggregation, delegating to the
 * parameterized variant of this test with a randomly chosen significance heuristic.
 */
public void testSignificance() throws IOException {
    final SignificanceHeuristic randomHeuristic = getRandomSignificanceheuristic();
    testSignificance(randomHeuristic);
}

/**
* Uses the significant terms aggregation to find the keywords in numeric
* fields
Expand All @@ -167,8 +189,6 @@ public void testNumericSignificance() throws IOException {
NumberFieldType longFieldType
= new NumberFieldMapper.NumberFieldType("long_field", NumberFieldMapper.NumberType.LONG);

TextFieldType textFieldType = new TextFieldType("text");

IndexWriterConfig indexWriterConfig = newIndexWriterConfig();
indexWriterConfig.setMaxBufferedDocs(100);
indexWriterConfig.setRAMBufferSizeMB(100); // flush on open to have a single segment
Expand Down Expand Up @@ -257,8 +277,6 @@ public void testUnmapped() throws IOException {
*/
public void testRangeField() throws IOException {
RangeType rangeType = RangeType.DOUBLE;
final RangeFieldMapper.Range range1 = new RangeFieldMapper.Range(rangeType, 1.0D, 5.0D, true, true);
final RangeFieldMapper.Range range2 = new RangeFieldMapper.Range(rangeType, 6.0D, 10.0D, true, true);
final String fieldName = "rangeField";
MappedFieldType fieldType = new RangeFieldMapper.RangeFieldType(fieldName, rangeType);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.aggregations.bucket.terms.heuristic;

import org.elasticsearch.search.aggregations.bucket.terms.AbstractSignificanceHeuristicTests;

/**
 * Plugs {@link ChiSquare} into the test hooks declared by {@link AbstractSignificanceHeuristicTests}.
 */
public class ChiSquareTests extends AbstractSignificanceHeuristicTests {

    /** Supplies a {@link ChiSquare} whose two boolean constructor flags are both randomized. */
    @Override
    protected SignificanceHeuristic getHeuristic() {
        final boolean firstFlag = randomBoolean();
        final boolean secondFlag = randomBoolean();
        return new ChiSquare(firstFlag, secondFlag);
    }

    /** Always {@code false} for this heuristic. */
    @Override
    protected boolean testZeroScore() {
        return false;
    }

    /** Hands the base class two instances that differ only in their second constructor flag. */
    @Override
    public void testAssertions() {
        final ChiSquare secondFlagSet = new ChiSquare(true, true);
        final ChiSquare secondFlagCleared = new ChiSquare(true, false);
        testBackgroundAssertions(secondFlagSet, secondFlagCleared);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.aggregations.bucket.terms.heuristic;

import org.elasticsearch.search.aggregations.bucket.terms.AbstractSignificanceHeuristicTests;

import static org.hamcrest.Matchers.equalTo;

/**
 * Plugs {@link GND} into the test hooks declared by {@link AbstractSignificanceHeuristicTests}
 * and adds a GND-specific corner-case test.
 */
public class GNDTests extends AbstractSignificanceHeuristicTests {

    /** Supplies a {@link GND} whose boolean constructor flag is randomized. */
    @Override
    protected SignificanceHeuristic getHeuristic() {
        final boolean flag = randomBoolean();
        return new GND(flag);
    }

    /** Always {@code true} for this heuristic. */
    @Override
    protected boolean testZeroScore() {
        return true;
    }

    /** Hands the base class one instance built with {@code true} and one built with {@code false}. */
    @Override
    public void testAssertions() {
        final GND flagSet = new GND(true);
        final GND flagCleared = new GND(false);
        testBackgroundAssertions(flagSet, flagCleared);
    }

    /**
     * Pins the exact scores returned for degenerate inputs.
     * <p>
     * The term is only in the subset and not at all in the other set, but that is because the
     * other set is empty. This should actually not happen because only terms that are in the
     * subset are considered now; however, in this case the score should be 0 because a term that
     * does not exist cannot be relevant.
     */
    public void testGNDCornerCases() {
        final GND flagSet = new GND(true);

        // Random arguments are drawn into locals in the same left-to-right order in which they
        // would be evaluated as inline call arguments.
        final int emptyCaseSecond = randomIntBetween(1, 2);
        final int emptyCaseFourth = randomIntBetween(2, 3);
        assertThat(flagSet.getScore(0, emptyCaseSecond, 0, emptyCaseFourth), equalTo(0.0));

        // the terms do not co-occur at all - should be 0
        final int disjointSecond = randomIntBetween(1, 2);
        final int disjointThird = randomIntBetween(2, 3);
        final int disjointFourth = randomIntBetween(5, 6);
        assertThat(flagSet.getScore(0, disjointSecond, disjointThird, disjointFourth), equalTo(0.0));

        // comparison between two terms that do not exist - probably not relevant
        final int missingFourth = randomIntBetween(1, 2);
        assertThat(flagSet.getScore(0, 0, 0, missingFourth), equalTo(0.0));

        // terms co-occur perfectly - should be 1
        assertThat(flagSet.getScore(1, 1, 1, 1), equalTo(1.0));

        final GND flagCleared = new GND(false);
        assertThat(flagCleared.getScore(0, 0, 0, 0), equalTo(0.0));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.aggregations.bucket.terms.heuristic;

import org.elasticsearch.search.aggregations.bucket.terms.AbstractSignificanceHeuristicTests;

/**
 * Plugs {@link JLHScore} into the test hooks declared by {@link AbstractSignificanceHeuristicTests}.
 */
public class JLHScoreTests extends AbstractSignificanceHeuristicTests {

    /** Supplies a fresh {@link JLHScore}; its constructor takes no arguments here. */
    @Override
    protected SignificanceHeuristic getHeuristic() {
        final JLHScore jlhScore = new JLHScore();
        return jlhScore;
    }

    /** Always {@code true} for this heuristic. */
    @Override
    protected boolean testZeroScore() {
        return true;
    }

    /** Delegates to the inherited single-argument overload with a fresh {@link JLHScore}. */
    @Override
    public void testAssertions() {
        final JLHScore jlhScore = new JLHScore();
        testAssertions(jlhScore);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.aggregations.bucket.terms.heuristic;

import org.elasticsearch.search.aggregations.bucket.terms.AbstractSignificanceHeuristicTests;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

/**
 * Plugs {@link MutualInformation} into the test hooks declared by
 * {@link AbstractSignificanceHeuristicTests} and adds score checks specific to this heuristic.
 */
public class MutualInformationTests extends AbstractSignificanceHeuristicTests {

    /** Supplies a {@link MutualInformation} whose two boolean constructor flags are both randomized. */
    @Override
    protected SignificanceHeuristic getHeuristic() {
        return new MutualInformation(randomBoolean(), randomBoolean());
    }

    /** Always {@code false} for this heuristic. */
    @Override
    protected boolean testZeroScore() {
        return false;
    }

    /** Hands the base class two instances that differ only in their second constructor flag. */
    @Override
    public void testAssertions() {
        testBackgroundAssertions(new MutualInformation(true, true), new MutualInformation(true, false));
    }

    /**
     * Pins exact scores for a handful of fixed inputs and checks that scores stay within
     * {@code [0.0, 1.0]} for the remaining inputs exercised here.
     */
    public void testScoreMutual() {
        SignificanceHeuristic heuristic = new MutualInformation(true, true);
        assertThat(heuristic.getScore(1, 1, 1, 3), greaterThan(0.0));
        assertThat(heuristic.getScore(1, 1, 2, 3), lessThan(heuristic.getScore(1, 1, 1, 3)));
        assertThat(heuristic.getScore(2, 2, 2, 4), equalTo(1.0));
        assertThat(heuristic.getScore(0, 2, 2, 4), equalTo(1.0));
        assertThat(heuristic.getScore(2, 2, 4, 4), equalTo(0.0));
        assertThat(heuristic.getScore(1, 2, 2, 4), equalTo(0.0));
        assertThat(heuristic.getScore(3, 6, 9, 18), equalTo(0.0));

        double score = 0.0;
        try {
            long a = randomLong();
            long b = randomLong();
            long c = randomLong();
            long d = randomLong();
            score = heuristic.getScore(a, b, c, d);
        } catch (IllegalArgumentException ignored) {
            // Deliberately ignored: four arbitrary longs may not form an argument combination
            // that getScore accepts. When they are rejected, score keeps its initial 0.0, which
            // the range checks below accept.
        }
        assertThat(score, lessThanOrEqualTo(1.0));
        assertThat(score, greaterThanOrEqualTo(0.0));

        // With the first constructor flag cleared, this input is expected to score negative infinity.
        heuristic = new MutualInformation(false, true);
        assertThat(heuristic.getScore(0, 1, 2, 3), equalTo(Double.NEGATIVE_INFINITY));

        // With the second constructor flag cleared, scores must still fall within [0.0, 1.0].
        heuristic = new MutualInformation(true, false);
        score = heuristic.getScore(2, 3, 1, 4);
        assertThat(score, greaterThanOrEqualTo(0.0));
        assertThat(score, lessThanOrEqualTo(1.0));
        score = heuristic.getScore(1, 4, 2, 3);
        assertThat(score, greaterThanOrEqualTo(0.0));
        assertThat(score, lessThanOrEqualTo(1.0));
        score = heuristic.getScore(1, 3, 4, 4);
        assertThat(score, greaterThanOrEqualTo(0.0));
        assertThat(score, lessThanOrEqualTo(1.0));
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.aggregations.bucket.terms.heuristic;

import org.elasticsearch.search.aggregations.bucket.terms.AbstractSignificanceHeuristicTests;

/**
 * Plugs {@link PercentageScore} into the test hooks declared by
 * {@link AbstractSignificanceHeuristicTests}.
 */
public class PercentageScoreTests extends AbstractSignificanceHeuristicTests {

    /** Supplies a fresh {@link PercentageScore}; its constructor takes no arguments here. */
    @Override
    protected SignificanceHeuristic getHeuristic() {
        final PercentageScore percentageScore = new PercentageScore();
        return percentageScore;
    }

    /** Always {@code true} for this heuristic. */
    @Override
    protected boolean testZeroScore() {
        return true;
    }

    /** Delegates to the inherited single-argument overload with a fresh {@link PercentageScore}. */
    @Override
    public void testAssertions() {
        final PercentageScore percentageScore = new PercentageScore();
        testAssertions(percentageScore);
    }
}

0 comments on commit 379fad1

Please sign in to comment.