From 51c6f69e02f07edf4ced256f9a3e880993f0d98d Mon Sep 17 00:00:00 2001 From: Igor Motov Date: Mon, 13 Apr 2020 12:28:58 -0400 Subject: [PATCH] [7.x] Add support for filters to T-Test aggregation (#54980) (#55066) Adds support for filters to T-Test aggregation. The filters can be used to select populations based on some criteria and use values from the same or different fields. Closes #53692 --- docs/build.gradle | 14 +- .../metrics/t-test-aggregation.asciidoc | 75 ++++++++- .../WeightedAvgAggregationBuilder.java | 14 +- .../MultiValuesSourceAggregationBuilder.java | 6 +- .../support/MultiValuesSourceFieldConfig.java | 58 +++++-- .../support/MultiValuesSourceParseHelper.java | 4 +- .../MultiValuesSourceFieldConfigTests.java | 22 ++- .../TopMetricsAggregationBuilder.java | 3 +- .../ttest/TTestAggregationBuilder.java | 25 ++- .../ttest/TTestAggregatorFactory.java | 54 +++++- .../ttest/UnpairedTTestAggregator.java | 33 +++- .../ttest/TTestAggregationBuilderTests.java | 59 ++++--- .../analytics/ttest/TTestAggregatorTests.java | 154 ++++++++++++++++-- .../rest-api-spec/test/analytics/t_test.yml | 72 ++++++++ 14 files changed, 510 insertions(+), 83 deletions(-) create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/analytics/t_test.yml diff --git a/docs/build.gradle b/docs/build.gradle index a60870685c6f9..6dd85ad79fd93 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -548,7 +548,7 @@ buildRestTests.setups['node_upgrade'] = ''' number_of_replicas: 1 mappings: properties: - name: + group: type: keyword startup_time_before: type: long @@ -560,17 +560,17 @@ buildRestTests.setups['node_upgrade'] = ''' refresh: true body: | {"index":{}} - {"name": "A", "startup_time_before": 102, "startup_time_after": 89} + {"group": "A", "startup_time_before": 102, "startup_time_after": 89} {"index":{}} - {"name": "B", "startup_time_before": 99, "startup_time_after": 93} + {"group": "A", "startup_time_before": 99, "startup_time_after": 93} {"index":{}} - {"name": "C", "startup_time_before": 111, "startup_time_after": 72} + {"group": "A", "startup_time_before": 111, "startup_time_after": 72} {"index":{}} - {"name": "D", "startup_time_before": 97, "startup_time_after": 98} + {"group": "B", "startup_time_before": 97, "startup_time_after": 98} {"index":{}} - {"name": "E", "startup_time_before": 101, "startup_time_after": 102} + {"group": "B", "startup_time_before": 101, "startup_time_after": 102} {"index":{}} - {"name": "F", "startup_time_before": 99, "startup_time_after": 98}''' + {"group": "B", "startup_time_before": 99, "startup_time_after": 98}''' // Used by iprange agg buildRestTests.setups['iprange'] = ''' diff --git a/docs/reference/aggregations/metrics/t-test-aggregation.asciidoc b/docs/reference/aggregations/metrics/t-test-aggregation.asciidoc index 342b9733b6b09..4fec228667206 100644 --- a/docs/reference/aggregations/metrics/t-test-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/t-test-aggregation.asciidoc @@ -1,7 +1,7 @@ [role="xpack"] [testenv="basic"] [[search-aggregations-metrics-ttest-aggregation]] -=== TTest Aggregation +=== T-Test Aggregation A `t_test` metrics aggregation that performs a statistical hypothesis test in which the test statistic follows a Student's t-distribution under the null hypothesis on numeric values extracted from the aggregated documents or generated by provided scripts. 
In practice, this
@@ -43,8 +43,8 @@ GET node_upgrade/_search
 }
 --------------------------------------------------
 // TEST[setup:node_upgrade]
-<1> The field `startup_time_before` must be a numeric field
-<2> The field `startup_time_after` must be a numeric field
+<1> The field `startup_time_before` must be a numeric field.
+<2> The field `startup_time_after` must be a numeric field.
 <3> Since we have data from the same nodes, we are using paired t-test.
 
 The response will return the p-value or probability value for the test. It is the probability of obtaining results at least as extreme as
@@ -74,6 +74,69 @@ The `t_test` aggregation supports unpaired and paired two-sample t-tests. The ty
 `"type": "homoscedastic"`:: performs two-sample equal variance test
 `"type": "heteroscedastic"`:: performs two-sample unequal variance test (this is the default)
 
+==== Filters
+
+It is also possible to run an unpaired t-test on different sets of records using filters. For example, if we want to test the difference
+in startup times before the upgrade between two different groups of nodes, we use the same field `startup_time_before` for both populations
+and select the two groups of nodes with term filters on the group name field:
+
+[source,console]
+--------------------------------------------------
+GET node_upgrade/_search
+{
+  "size" : 0,
+  "aggs" : {
+    "startup_time_ttest" : {
+      "t_test" : {
+        "a" : {
+          "field" : "startup_time_before",  <1>
+          "filter" : {
+            "term" : {
+              "group" : "A"                 <2>
+            }
+          }
+        },
+        "b" : {
+          "field" : "startup_time_before",  <3>
+          "filter" : {
+            "term" : {
+              "group" : "B"                 <4>
+            }
+          }
+        },
+        "type" : "heteroscedastic"          <5>
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[setup:node_upgrade]
+<1> The field `startup_time_before` must be a numeric field.
+<2> Any query that separates the two groups can be used here.
+<3> We are using the same field for both populations,
+<4> but we are using different filters to select the two groups of documents.
+<5> Since we have data from different nodes, we cannot use the paired t-test.
+
+
+[source,console-result]
+--------------------------------------------------
+{
+  ...
+
+  "aggregations": {
+    "startup_time_ttest": {
+      "value": 0.2981858007281437 <1>
+    }
+  }
+}
+--------------------------------------------------
+// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/]
+<1> The p-value.
+
+In this example, we are using the same field for both populations. However, this is not a requirement; different fields and even
+combinations of fields and scripts can be used. The populations don't have to be in the same index either. If the data sets are located
+in different indices, a term filter on the <> field can be used to select the populations.
+
 ==== Script
 
 The `t_test` metric supports scripting. For example, if we need to adjust our load times for the before values, we could use
@@ -108,7 +171,7 @@ GET node_upgrade/_search
 }
 --------------------------------------------------
 // TEST[setup:node_upgrade]
 <1> The `field` parameter is replaced with a `script` parameter, which uses the
-script to generate values which percentiles are calculated on
-<2> Scripting supports parameterized input just like any other script
-<3> We can mix scripts and fields
+script to generate values on which the t-test is calculated.
+<2> Scripting supports parameterized input just like any other script.
+<3> We can mix scripts and fields.
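For reference, the filtered example above can also be built programmatically with the classes this patch extends. The sketch below is illustrative only: it mirrors the REST request using `TTestAggregationBuilder` and the new `MultiValuesSourceFieldConfig.Builder#setFilter` hook from this change; the wrapper class name is made up for the example, the field and group names come from the docs setup, and the surrounding client and search plumbing is omitted.

// Sketch only: programmatic equivalent of the filtered docs example, using the builder
// API this patch extends. "startup_time_ttest", "startup_time_before", "group", "A" and
// "B" are the names from the docs setup; the class name is just for illustration.
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder;
import org.elasticsearch.xpack.analytics.ttest.TTestType;

class FilteredTTestExample {
    static TTestAggregationBuilder startupTimeTTest() {
        return new TTestAggregationBuilder("startup_time_ttest")
            .a(new MultiValuesSourceFieldConfig.Builder()
                .setFieldName("startup_time_before")               // same field for both populations
                .setFilter(QueryBuilders.termQuery("group", "A"))  // population "a": nodes in group A
                .build())
            .b(new MultiValuesSourceFieldConfig.Builder()
                .setFieldName("startup_time_before")
                .setFilter(QueryBuilders.termQuery("group", "B"))  // population "b": nodes in group B
                .build())
            .testType(TTestType.HETEROSCEDASTIC);                  // paired type rejects filters
    }
}

Using the same field for both populations without any filter is rejected in `innerBuild`, and combining filters with the `paired` type is rejected at aggregator creation time; both cases are exercised by the unit tests further down in this patch.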
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregationBuilder.java index 457cd2fb3755f..7d2749f6df050 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregationBuilder.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.AggregationBuilder; @@ -51,8 +52,8 @@ public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationB ObjectParser.fromBuilder(NAME, WeightedAvgAggregationBuilder::new); static { MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC); - MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false); - MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false); + MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false, false); + MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false, false); } public WeightedAvgAggregationBuilder(String name) { @@ -99,10 +100,11 @@ public BucketCardinality bucketCardinality() { @Override protected MultiValuesSourceAggregatorFactory innerBuild(QueryShardContext queryShardContext, - Map> configs, - DocValueFormat format, - AggregatorFactory parent, - Builder subFactoriesBuilder) throws IOException { + Map> configs, + Map filters, + DocValueFormat format, + AggregatorFactory parent, + Builder subFactoriesBuilder) throws IOException { return new WeightedAvgAggregatorFactory(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java index 989400f6a8a4f..c64b83b6a9223 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; @@ -168,13 +169,15 @@ protected final MultiValuesSourceAggregatorFactory doBuild(QueryShardContext ValueType finalValueType = this.valueType != null ? 
this.valueType : targetValueType; Map> configs = new HashMap<>(fields.size()); + Map filters = new HashMap<>(fields.size()); fields.forEach((key, value) -> { ValuesSourceConfig config = ValuesSourceConfig.resolve(queryShardContext, finalValueType, value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format); configs.put(key, config); + filters.put(key, value.getFilter()); }); DocValueFormat docValueFormat = resolveFormat(format, finalValueType); - return innerBuild(queryShardContext, configs, docValueFormat, parent, subFactoriesBuilder); + return innerBuild(queryShardContext, configs, filters, docValueFormat, parent, subFactoriesBuilder); } @@ -191,6 +194,7 @@ private static DocValueFormat resolveFormat(@Nullable String format, @Nullable V protected abstract MultiValuesSourceAggregatorFactory innerBuild(QueryShardContext queryShardContext, Map> configs, + Map filters, DocValueFormat format, AggregatorFactory parent, Builder subFactoriesBuilder) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java index 80af28e61bc66..35975463baff3 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -30,26 +30,30 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.script.Script; import java.io.IOException; import java.time.ZoneId; import java.time.ZoneOffset; import java.util.Objects; -import java.util.function.BiFunction; public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject { - private String fieldName; - private Object missing; - private Script script; - private ZoneId timeZone; + private final String fieldName; + private final Object missing; + private final Script script; + private final ZoneId timeZone; + private final QueryBuilder filter; private static final String NAME = "field_config"; - public static final BiFunction> PARSER - = (scriptable, timezoneAware) -> { + public static final ParseField FILTER = new ParseField("filter"); - ObjectParser parser + public static ObjectParser parserBuilder(boolean scriptable, boolean timezoneAware, + boolean filtered) { + + ObjectParser parser = new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new); parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD); @@ -71,14 +75,21 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject } }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); } + + if (filtered) { + parser.declareField(MultiValuesSourceFieldConfig.Builder::setFilter, + (p, context) -> AbstractQueryBuilder.parseInnerQueryBuilder(p), + FILTER, ObjectParser.ValueType.OBJECT); + } return parser; }; - private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone) { + protected MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone, QueryBuilder filter) { this.fieldName = fieldName; 
this.missing = missing; this.script = script; this.timeZone = timeZone; + this.filter = filter; } public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { @@ -94,6 +105,11 @@ public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { } else { this.timeZone = in.readOptionalZoneId(); } + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); + } else { + this.filter = null; + } } public Object getMissing() { @@ -112,6 +128,10 @@ public String getFieldName() { return fieldName; } + public QueryBuilder getFilter() { + return filter; + } + @Override public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_7_6_0)) { @@ -126,6 +146,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeOptionalZoneId(timeZone); } + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeOptionalNamedWriteable(filter); + } } @Override @@ -143,6 +166,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (timeZone != null) { builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone.getId()); } + if (filter != null) { + builder.field(FILTER.getPreferredName()); + filter.toXContent(builder, params); + } builder.endObject(); return builder; } @@ -155,12 +182,13 @@ public boolean equals(Object o) { return Objects.equals(fieldName, that.fieldName) && Objects.equals(missing, that.missing) && Objects.equals(script, that.script) - && Objects.equals(timeZone, that.timeZone); + && Objects.equals(timeZone, that.timeZone) + && Objects.equals(filter, that.filter); } @Override public int hashCode() { - return Objects.hash(fieldName, missing, script, timeZone); + return Objects.hash(fieldName, missing, script, timeZone, filter); } @Override @@ -173,6 +201,7 @@ public static class Builder { private Object missing = null; private Script script = null; private ZoneId timeZone = null; + private QueryBuilder filter = null; public String getFieldName() { return fieldName; @@ -210,6 +239,11 @@ public Builder setTimeZone(ZoneId timeZone) { return this; } + public Builder setFilter(QueryBuilder filter) { + this.filter = filter; + return this; + } + public MultiValuesSourceFieldConfig build() { if (Strings.isNullOrEmpty(fieldName) && script == null) { throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() @@ -223,7 +257,7 @@ public MultiValuesSourceFieldConfig build() { "Please specify one or the other."); } - return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone); + return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone, filter); } } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java index 4888495f9d8da..1e30341909704 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java @@ -50,10 +50,10 @@ public static void declareCommon( public static void declareField(String fieldName, AbstractObjectParser, T> objectParser, - boolean scriptable, boolean timezoneAware) { + boolean scriptable, boolean timezoneAware, boolean filterable) { objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()), - 
(p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null), + (p, c) -> MultiValuesSourceFieldConfig.parserBuilder(scriptable, timezoneAware, filterable).parse(p, null), new ParseField(fieldName), ObjectParser.ValueType.OBJECT); } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java index b929f222d94fe..d80d540d2255f 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java @@ -19,13 +19,20 @@ package org.elasticsearch.search.aggregations.support; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.script.Script; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; import java.time.ZoneId; +import java.util.Collections; import static org.hamcrest.Matchers.equalTo; @@ -33,7 +40,7 @@ public class MultiValuesSourceFieldConfigTests extends AbstractSerializingTestCa @Override protected MultiValuesSourceFieldConfig doParseInstance(XContentParser parser) throws IOException { - return MultiValuesSourceFieldConfig.PARSER.apply(true, true).apply(parser, null).build(); + return MultiValuesSourceFieldConfig.parserBuilder(true, true, true).apply(parser, null).build(); } @Override @@ -41,8 +48,9 @@ protected MultiValuesSourceFieldConfig createTestInstance() { String field = randomAlphaOfLength(10); Object missing = randomBoolean() ? randomAlphaOfLength(10) : null; ZoneId timeZone = randomBoolean() ? randomZone() : null; + QueryBuilder filter = randomBoolean() ? QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null; return new MultiValuesSourceFieldConfig.Builder() - .setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).build(); + .setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).setFilter(filter).build(); } @Override @@ -60,4 +68,14 @@ public void testBothFieldScript() { () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build()); assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. 
Please specify one or the other.")); } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + } } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java index 4d01b678f1121..693e38e6bf6ee 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java @@ -55,7 +55,8 @@ public class TopMetricsAggregationBuilder extends AbstractAggregationBuilder SortBuilder.fromXContent(p), SORT_FIELD, ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING); PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD); - ContextParser metricParser = MultiValuesSourceFieldConfig.PARSER.apply(true, false); + ContextParser metricParser = + MultiValuesSourceFieldConfig.parserBuilder(true, false, false); PARSER.declareObjectArray(constructorArg(), (p, n) -> metricParser.parse(p, null).build(), METRIC_FIELD); } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java index a03a34d6e3f1b..1c843b2c82f8e 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java @@ -12,11 +12,13 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.support.FieldContext; import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; @@ -41,11 +43,10 @@ public class TTestAggregationBuilder extends MultiValuesSourceAggregationBuilder static { MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC); - MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false); - MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false); + MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false, true); + MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false, true); PARSER.declareString(TTestAggregationBuilder::testType, TYPE_FIELD); 
PARSER.declareInt(TTestAggregationBuilder::tails, TAILS_FIELD); - } private TTestType testType = TTestType.HETEROSCEDASTIC; @@ -117,10 +118,26 @@ protected void innerWriteTo(StreamOutput out) throws IOException { protected MultiValuesSourceAggregatorFactory innerBuild( QueryShardContext queryShardContext, Map> configs, + Map filters, DocValueFormat format, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder) throws IOException { - return new TTestAggregatorFactory(name, configs, testType, tails, format, queryShardContext, parent, subFactoriesBuilder, metadata); + QueryBuilder filterA = filters.get(A_FIELD.getPreferredName()); + QueryBuilder filterB = filters.get(B_FIELD.getPreferredName()); + if (filterA == null && filterB == null) { + FieldContext fieldContextA = configs.get(A_FIELD.getPreferredName()).fieldContext(); + FieldContext fieldContextB = configs.get(B_FIELD.getPreferredName()).fieldContext(); + if (fieldContextA != null && fieldContextB != null) { + if (fieldContextA.field().equals(fieldContextB.field())) { + throw new IllegalArgumentException("The same field [" + fieldContextA.field() + + "] is used for both population but no filters are specified."); + } + } + } + + return new TTestAggregatorFactory(name, configs, testType, tails, + filterA, filterB, format, queryShardContext, parent, + subFactoriesBuilder, metadata); } @Override diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorFactory.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorFactory.java index 848122051dcda..2c25d2a1fb207 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorFactory.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorFactory.java @@ -6,8 +6,15 @@ package org.elasticsearch.xpack.analytics.ttest; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AggregationInitializationException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; @@ -24,14 +31,20 @@ class TTestAggregatorFactory extends MultiValuesSourceAggregatorFactory weights; TTestAggregatorFactory(String name, Map> configs, TTestType testType, int tails, + QueryBuilder filterA, QueryBuilder filterB, DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, Map metadata) throws IOException { super(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata); this.testType = testType; this.tails = tails; + this.filterA = filterA == null ? null : filterA.toQuery(queryShardContext); + this.filterB = filterB == null ? 
null : filterB.toQuery(queryShardContext);
     }
 
     @Override
@@ -42,9 +55,9 @@ protected Aggregator createUnmapped(SearchContext searchContext,
             case PAIRED:
                 return new PairedTTestAggregator(name, null, tails, format, searchContext, parent, metadata);
             case HOMOSCEDASTIC:
-                return new UnpairedTTestAggregator(name, null, tails, true, format, searchContext, parent, metadata);
+                return new UnpairedTTestAggregator(name, null, tails, true, this::getWeights, format, searchContext, parent, metadata);
             case HETEROSCEDASTIC:
-                return new UnpairedTTestAggregator(name, null, tails, false, format, searchContext, parent, metadata);
+                return new UnpairedTTestAggregator(name, null, tails, false, this::getWeights, format, searchContext, parent, metadata);
             default:
                 throw new IllegalArgumentException("Unsupported t-test type " + testType);
         }
@@ -64,13 +77,46 @@ protected Aggregator doCreateInternal(SearchContext searchContext,
         }
         switch (testType) {
             case PAIRED:
+                if (filterA != null || filterB != null) {
+                    throw new IllegalArgumentException("Paired t-test doesn't support filters");
+                }
                 return new PairedTTestAggregator(name, numericMultiVS, tails, format, searchContext, parent, metadata);
             case HOMOSCEDASTIC:
-                return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, format, searchContext, parent, metadata);
+                return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, this::getWeights, format, searchContext, parent,
+                    metadata);
             case HETEROSCEDASTIC:
-                return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, format, searchContext, parent, metadata);
+                return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, this::getWeights, format, searchContext,
+                    parent, metadata);
             default:
                 throw new IllegalArgumentException("Unsupported t-test type " + testType);
         }
     }
+
+    /**
+     * Returns the {@link Weight}s for these filters, creating them if
+     * necessary. This is done lazily so that the {@link Weight} is only created
+     * if the aggregation collects documents, reducing the overhead of the
+     * aggregation in the case where no documents are collected.
+     *
+     * Note that as aggregations are initialised and executed in a serial manner,
+     * no concurrency considerations are necessary here.
+ */ + public Tuple getWeights() { + if (weights == null) { + weights = new Tuple<>(getWeight(filterA), getWeight(filterB)); + } + return weights; + } + + public Weight getWeight(Query filter) { + if (filter != null) { + IndexSearcher contextSearcher = queryShardContext.searcher(); + try { + return contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f); + } catch (IOException e) { + throw new AggregationInitializationException("Failed to initialize filter", e); + } + } + return null; + } } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/UnpairedTTestAggregator.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/UnpairedTTestAggregator.java index a03c5a0d63f9b..95bf109dc6186 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/UnpairedTTestAggregator.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/UnpairedTTestAggregator.java @@ -7,7 +7,11 @@ package org.elasticsearch.xpack.analytics.ttest; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.search.DocValueFormat; @@ -20,6 +24,7 @@ import java.io.IOException; import java.util.Map; +import java.util.function.Supplier; import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.A_FIELD; import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.B_FIELD; @@ -28,14 +33,16 @@ public class UnpairedTTestAggregator extends TTestAggregator private final TTestStatsBuilder a; private final TTestStatsBuilder b; private final boolean homoscedastic; + private final Supplier> weightsSupplier; UnpairedTTestAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, int tails, boolean homoscedastic, - DocValueFormat format, SearchContext context, Aggregator parent, - Map metadata) throws IOException { + Supplier> weightsSupplier, DocValueFormat format, SearchContext context, + Aggregator parent, Map metadata) throws IOException { super(name, valuesSources, tails, format, context, parent, metadata); BigArrays bigArrays = context.bigArrays(); a = new TTestStatsBuilder(bigArrays); b = new TTestStatsBuilder(bigArrays); + this.weightsSupplier = weightsSupplier; this.homoscedastic = homoscedastic; } @@ -67,6 +74,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final CompensatedSum compSumOfSqrA = new CompensatedSum(0, 0); final CompensatedSum compSumB = new CompensatedSum(0, 0); final CompensatedSum compSumOfSqrB = new CompensatedSum(0, 0); + final Tuple weights = weightsSupplier.get(); + final Bits bitsA = getBits(ctx, weights.v1()); + final Bits bitsB = getBits(ctx, weights.v2()); return new LeafBucketCollectorBase(sub, docAValues) { @@ -82,14 +92,25 @@ private void processValues(int doc, long bucket, SortedNumericDoubleValues docVa @Override public void collect(int doc, long bucket) throws IOException { - a.grow(bigArrays, bucket + 1); - b.grow(bigArrays, bucket + 1); - processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a); - processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b); + if (bitsA == null || 
bitsA.get(doc)) { + a.grow(bigArrays, bucket + 1); + processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a); + } + if (bitsB == null || bitsB.get(doc)) { + processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b); + b.grow(bigArrays, bucket + 1); + } } }; } + private Bits getBits(LeafReaderContext ctx, Weight weight) throws IOException { + if (weight == null) { + return null; + } + return Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), weight.scorerSupplier(ctx)); + } + @Override public void doClose() { Releasables.close(a, b); diff --git a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilderTests.java b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilderTests.java index 69a228c49fe4d..32c2e3823f3eb 100644 --- a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilderTests.java +++ b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilderTests.java @@ -7,10 +7,14 @@ package org.elasticsearch.xpack.analytics.ttest; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.script.Script; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.BaseAggregationBuilder; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; @@ -18,8 +22,10 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; -import static java.util.Collections.singletonList; import static org.hamcrest.Matchers.hasSize; public class TTestAggregationBuilderTests extends AbstractSerializingTestCase { @@ -30,14 +36,6 @@ public void setupName() { aggregationName = randomAlphaOfLength(10); } - @Override - protected NamedXContentRegistry xContentRegistry() { - return new NamedXContentRegistry(singletonList(new NamedXContentRegistry.Entry( - BaseAggregationBuilder.class, - new ParseField(TTestAggregationBuilder.NAME), - (p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n)))); - } - @Override protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws IOException { assertSame(XContentParser.Token.START_OBJECT, parser.nextToken()); @@ -52,26 +50,33 @@ protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws @Override protected TTestAggregationBuilder createTestInstance() { - MultiValuesSourceFieldConfig aConfig; + MultiValuesSourceFieldConfig.Builder aConfig; + TTestType tTestType = randomFrom(TTestType.values()); if (randomBoolean()) { - aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field").build(); + aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field"); } else { - aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build(); + aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))); } - MultiValuesSourceFieldConfig bConfig; + MultiValuesSourceFieldConfig.Builder bConfig; if 
(randomBoolean()) { - bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field").build(); + bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field"); } else { - bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build(); + bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))); + } + if (tTestType != TTestType.PAIRED && randomBoolean()) { + aConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10))); + } + if (tTestType != TTestType.PAIRED && randomBoolean()) { + bConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10))); } TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder(aggregationName) - .a(aConfig) - .b(bConfig); + .a(aConfig.build()) + .b(bConfig.build()); if (randomBoolean()) { aggregationBuilder.tails(randomIntBetween(1, 2)); } - if (randomBoolean()) { - aggregationBuilder.testType(randomFrom(TTestType.values())); + if (tTestType != TTestType.HETEROSCEDASTIC || randomBoolean()) { + aggregationBuilder.testType(randomFrom(tTestType)); } return aggregationBuilder; } @@ -80,5 +85,21 @@ protected TTestAggregationBuilder createTestInstance() { protected Writeable.Reader instanceReader() { return TTestAggregationBuilder::new; } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.add(new NamedXContentRegistry.Entry( + BaseAggregationBuilder.class, + new ParseField(TTestAggregationBuilder.NAME), + (p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n))); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } } diff --git a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorTests.java b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorTests.java index c21f587594a88..f7414a3b558e6 100644 --- a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorTests.java +++ b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregatorTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.analytics.ttest; +import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.index.DirectoryReader; @@ -21,6 +22,7 @@ import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.script.MockScriptEngine; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptEngine; @@ -56,11 +58,15 @@ public class TTestAggregatorTests extends AggregatorTestCase { */ public static final String ADD_HALF_SCRIPT = "add_one"; + public static final String TERM_FILTERING = "term_filtering"; + @Override protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) { return new TTestAggregationBuilder("foo") - .a(new 
MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build()) - .b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build()); + .a(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName) + .setFilter(QueryBuilders.rangeQuery(fieldName).lt(10)).build()) + .b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName) + .setFilter(QueryBuilders.rangeQuery(fieldName).gte(10)).build()); } @Override @@ -71,11 +77,18 @@ protected ScriptService getMockScriptService() { LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc"); String fieldname = (String) vars.get("fieldname"); ScriptDocValues scriptDocValues = leafDocLookup.get(fieldname); - double val = ((Number) scriptDocValues.get(0)).doubleValue(); - if (val == 1) { - val += 0.0000001; + return ((Number) scriptDocValues.get(0)).doubleValue() + 0.5; + }); + + scripts.put(TERM_FILTERING, vars -> { + LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc"); + int term = (Integer) vars.get("term"); + ScriptDocValues termDocValues = leafDocLookup.get("term"); + int currentTerm = ((Number) termDocValues.get(0)).intValue(); + if (currentTerm == term) { + return ((Number) leafDocLookup.get("field").get(0)).doubleValue(); } - return val + 0.5; + return null; }); MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, @@ -134,6 +147,26 @@ public void testMultiplePairedValues() { ex.getMessage()); } + public void testSameFieldAndNoFilters() { + TTestType tTestType = randomFrom(TTestType.values()); + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType.setName("field"); + TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test") + .a(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build()) + .b(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build()) + .testType(tTestType); + + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> + testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new SortedNumericDocValuesField("field", 102))); + iw.addDocument(singleton(new SortedNumericDocValuesField("field", 99))); + }, tTest -> fail("Should have thrown exception"), fieldType) + ); + assertEquals( + "The same field [field] is used for both population but no filters are specified.", + ex.getMessage()); + } + public void testMultipleUnpairedValues() throws IOException { TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC); testCase(new MatchAllDocsQuery(), tTestType, iw -> { @@ -143,6 +176,15 @@ public void testMultipleUnpairedValues() throws IOException { }, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001)); } + public void testUnpairedValuesWithFilters() throws IOException { + TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC); + testCase(new MatchAllDocsQuery(), tTestType, iw -> { + iw.addDocument(asList(new SortedNumericDocValuesField("a", 102), new SortedNumericDocValuesField("a", 103), + new SortedNumericDocValuesField("b", 89))); + iw.addDocument(asList(new SortedNumericDocValuesField("a", 99), new SortedNumericDocValuesField("b", 93))); + }, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 
0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001)); + } + public void testMissingValues() throws IOException { TTestType tTestType = randomFrom(TTestType.values()); testCase(new MatchAllDocsQuery(), tTestType, iw -> { @@ -426,12 +468,12 @@ public void testScript() throws IOException { a(fieldInA ? a : b).b(fieldInA ? b : a).testType(tTestType); testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> { - iw.addDocument(singleton(new NumericDocValuesField("field", 1))); - iw.addDocument(singleton(new NumericDocValuesField("field", 2))); - iw.addDocument(singleton(new NumericDocValuesField("field", 3))); - }, (Consumer) tTest -> { - assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001); - }, fieldType); + iw.addDocument(singleton(new NumericDocValuesField("field", 1))); + iw.addDocument(singleton(new NumericDocValuesField("field", 2))); + iw.addDocument(singleton(new NumericDocValuesField("field", 3))); + }, (Consumer) tTest -> { + assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001); + }, fieldType); } public void testPaired() throws IOException { @@ -484,7 +526,6 @@ public void testHomoscedastic() throws IOException { }, fieldType1, fieldType2); } - public void testHeteroscedastic() throws IOException { MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); fieldType1.setName("a"); @@ -512,6 +553,93 @@ public void testHeteroscedastic() throws IOException { }, fieldType1, fieldType2); } + public void testFiltered() throws IOException { + TTestType tTestType = randomFrom(TTestType.values()); + MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType1.setName("a"); + MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType2.setName("b"); + TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test") + .a(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 1)).build()) + .b(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 2)).build()) + .testType(tTestType); + int tails = randomIntBetween(1, 2); + if (tails == 1 || randomBoolean()) { + aggregationBuilder.tails(tails); + } + CheckedConsumer buildIndex = iw -> { + iw.addDocument(asList(new NumericDocValuesField("a", 102), new IntPoint("b", 1))); + iw.addDocument(asList(new NumericDocValuesField("a", 99), new IntPoint("b", 1))); + iw.addDocument(asList(new NumericDocValuesField("a", 111), new IntPoint("b", 1))); + iw.addDocument(asList(new NumericDocValuesField("a", 97), new IntPoint("b", 1))); + iw.addDocument(asList(new NumericDocValuesField("a", 101), new IntPoint("b", 1))); + iw.addDocument(asList(new NumericDocValuesField("a", 99), new IntPoint("b", 1))); + + iw.addDocument(asList(new NumericDocValuesField("a", 89), new IntPoint("b", 2))); + iw.addDocument(asList(new NumericDocValuesField("a", 93), new IntPoint("b", 2))); + iw.addDocument(asList(new NumericDocValuesField("a", 72), new IntPoint("b", 2))); + iw.addDocument(asList(new NumericDocValuesField("a", 98), new IntPoint("b", 2))); + iw.addDocument(asList(new NumericDocValuesField("a", 102), new IntPoint("b", 2))); + iw.addDocument(asList(new NumericDocValuesField("a", 98), new IntPoint("b", 2))); + + iw.addDocument(asList(new NumericDocValuesField("a", 189), new IntPoint("b", 3))); + iw.addDocument(asList(new 
NumericDocValuesField("a", 193), new IntPoint("b", 3))); + iw.addDocument(asList(new NumericDocValuesField("a", 172), new IntPoint("b", 3))); + iw.addDocument(asList(new NumericDocValuesField("a", 198), new IntPoint("b", 3))); + iw.addDocument(asList(new NumericDocValuesField("a", 1102), new IntPoint("b", 3))); + iw.addDocument(asList(new NumericDocValuesField("a", 198), new IntPoint("b", 3))); + }; + if (tTestType == TTestType.PAIRED) { + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> + testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, tTest -> fail("Should have thrown exception"), + fieldType1, fieldType2) + ); + assertEquals("Paired t-test doesn't support filters", ex.getMessage()); + } else { + testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, (Consumer) ttest -> { + if (tTestType == TTestType.HOMOSCEDASTIC) { + assertEquals(0.03928288693 * tails, ttest.getValue(), 0.00001); + } else { + assertEquals(0.04538666214 * tails, ttest.getValue(), 0.00001); + } + }, fieldType1, fieldType2); + } + } + + public void testFilterByFilterOrScript() throws IOException { + boolean fieldInA = randomBoolean(); + TTestType tTestType = randomFrom(TTestType.HOMOSCEDASTIC, TTestType.HETEROSCEDASTIC); + + MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType1.setName("field"); + MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER); + fieldType2.setName("term"); + + boolean filterTermOne = randomBoolean(); + + MultiValuesSourceFieldConfig.Builder a = new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setFilter( + QueryBuilders.termQuery("term", filterTermOne? 1 : 2) + ); + MultiValuesSourceFieldConfig.Builder b = new MultiValuesSourceFieldConfig.Builder().setScript( + new Script(ScriptType.INLINE, MockScriptEngine.NAME, TERM_FILTERING, Collections.singletonMap("term", filterTermOne? 2 : 1)) + ); + + TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder("t_test"). + a(fieldInA ? a.build() : b.build()).b(fieldInA ? 
b.build() : a.build()).testType(tTestType); + + testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(asList(new NumericDocValuesField("field", 1), new IntPoint("term", 1), new NumericDocValuesField("term", 1))); + iw.addDocument(asList(new NumericDocValuesField("field", 2), new IntPoint("term", 1), new NumericDocValuesField("term", 1))); + iw.addDocument(asList(new NumericDocValuesField("field", 3), new IntPoint("term", 1), new NumericDocValuesField("term", 1))); + + iw.addDocument(asList(new NumericDocValuesField("field", 4), new IntPoint("term", 2), new NumericDocValuesField("term", 2))); + iw.addDocument(asList(new NumericDocValuesField("field", 5), new IntPoint("term", 2), new NumericDocValuesField("term", 2))); + iw.addDocument(asList(new NumericDocValuesField("field", 6), new IntPoint("term", 2), new NumericDocValuesField("term", 2))); + }, (Consumer) tTest -> { + assertEquals(0.02131164113, tTest.getValue(), 0.000001); + }, fieldType1, fieldType2); + } + private void testCase(Query query, TTestType type, CheckedConsumer buildIndex, Consumer verify) throws IOException { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/analytics/t_test.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/analytics/t_test.yml new file mode 100644 index 0000000000000..081c6e8977c04 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/analytics/t_test.yml @@ -0,0 +1,72 @@ +--- +setup: + - do: + bulk: + index: test + refresh: true + body: + - '{"index": {}}' + - '{"v1": 15.2, "v2": 15.9, "str": "a"}' + - '{"index": {}}' + - '{"v1": 15.3, "v2": 15.9, "str": "a"}' + - '{"index": {}}' + - '{"v1": 16.0, "v2": 15.2, "str": "b"}' + - '{"index": {}}' + - '{"v1": 15.1, "v2": 15.5, "str": "b"}' +--- +"heteroscedastic t-test": + - do: + search: + size: 0 + index: "test" + body: + aggs: + ttest: + t_test: + a: + field: v1 + b: + field: v2 + - match: { aggregations.ttest.value: 0.43066659210472646 } + +--- +"paired t-test": + - do: + search: + size: 0 + index: "test" + body: + aggs: + ttest: + t_test: + a: + field: v1 + b: + field: v2 + type: paired + + - match: { aggregations.ttest.value: 0.5632529432617406 } + +--- +"homoscedastic t-test with filters": + - do: + search: + size: 0 + index: "test" + body: + aggs: + ttest: + t_test: + a: + field: v1 + filter: + term: + str.keyword: a + b: + field: v1 + filter: + term: + str.keyword: b + type: homoscedastic + + - match: { aggregations.ttest.value: 0.5757355806262943 }