Skip to content

Commit

Permalink
Showing 14 changed files with 510 additions and 83 deletions.
14 changes: 7 additions & 7 deletions docs/build.gradle
Original file line number Diff line number Diff line change
@@ -548,7 +548,7 @@ buildRestTests.setups['node_upgrade'] = '''
number_of_replicas: 1
mappings:
properties:
name:
group:
type: keyword
startup_time_before:
type: long
@@ -560,17 +560,17 @@ buildRestTests.setups['node_upgrade'] = '''
refresh: true
body: |
{"index":{}}
{"name": "A", "startup_time_before": 102, "startup_time_after": 89}
{"group": "A", "startup_time_before": 102, "startup_time_after": 89}
{"index":{}}
{"name": "B", "startup_time_before": 99, "startup_time_after": 93}
{"group": "A", "startup_time_before": 99, "startup_time_after": 93}
{"index":{}}
{"name": "C", "startup_time_before": 111, "startup_time_after": 72}
{"group": "A", "startup_time_before": 111, "startup_time_after": 72}
{"index":{}}
{"name": "D", "startup_time_before": 97, "startup_time_after": 98}
{"group": "B", "startup_time_before": 97, "startup_time_after": 98}
{"index":{}}
{"name": "E", "startup_time_before": 101, "startup_time_after": 102}
{"group": "B", "startup_time_before": 101, "startup_time_after": 102}
{"index":{}}
{"name": "F", "startup_time_before": 99, "startup_time_after": 98}'''
{"group": "B", "startup_time_before": 99, "startup_time_after": 98}'''

// Used by iprange agg
buildRestTests.setups['iprange'] = '''
75 changes: 69 additions & 6 deletions docs/reference/aggregations/metrics/t-test-aggregation.asciidoc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[role="xpack"]
[testenv="basic"]
[[search-aggregations-metrics-ttest-aggregation]]
=== TTest Aggregation
=== T-Test Aggregation

A `t_test` metrics aggregation that performs a statistical hypothesis test in which the test statistic follows a Student's t-distribution
under the null hypothesis on numeric values extracted from the aggregated documents or generated by provided scripts. In practice, this
@@ -43,8 +43,8 @@ GET node_upgrade/_search
}
--------------------------------------------------
// TEST[setup:node_upgrade]
<1> The field `startup_time_before` must be a numeric field
<2> The field `startup_time_after` must be a numeric field
<1> The field `startup_time_before` must be a numeric field.
<2> The field `startup_time_after` must be a numeric field.
<3> Since we have data from the same nodes, we are using paired t-test.

The response will return the p-value or probability value for the test. It is the probability of obtaining results at least as extreme as
@@ -74,6 +74,69 @@ The `t_test` aggregation supports unpaired and paired two-sample t-tests. The ty
`"type": "homoscedastic"`:: performs two-sample equal variance test
`"type": "heteroscedastic"`:: performs two-sample unequal variance test (this is default)

==== Filters

It is also possible to run an unpaired t-test on different sets of records using filters. For example, if we want to test the difference
in startup times before the upgrade between two different groups of nodes, we use the same field `startup_time_before`, but separate the
two groups of nodes using term filters on the group name field:

[source,console]
--------------------------------------------------
GET node_upgrade/_search
{
"size" : 0,
"aggs" : {
"startup_time_ttest" : {
"t_test" : {
"a" : {
"field" : "startup_time_before", <1>
"filter" : {
"term" : {
"group" : "A" <2>
}
}
},
"b" : {
"field" : "startup_time_before", <3>
"filter" : {
"term" : {
"group" : "B" <4>
}
}
},
"type" : "heteroscedastic" <5>
}
}
}
}
--------------------------------------------------
// TEST[setup:node_upgrade]
<1> The field `startup_time_before` must be a numeric field.
<2> Any query that separates two groups can be used here.
<3> We are using the same field
<4> but we are using different filters.
<5> Since we have data from different nodes, we cannot use paired t-test.


[source,console-result]
--------------------------------------------------
{
...
"aggregations": {
"startup_time_ttest": {
"value": 0.2981858007281437 <1>
}
}
}
--------------------------------------------------
// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/]
<1> The p-value.

In this example, we are using the same fields for both populations. However, this is not a requirement, and different fields or even
combinations of fields and scripts can be used. Populations don't have to be in the same index either. If data sets are located in different
indices, a term filter on the <<mapping-index-field,`_index`>> field can be used to select populations.

==== Script

The `t_test` metric supports scripting. For example, if we need to adjust our load times for the before values, we could use
@@ -108,7 +171,7 @@ GET node_upgrade/_search
// TEST[setup:node_upgrade]

<1> The `field` parameter is replaced with a `script` parameter, which uses the
script to generate values which percentiles are calculated on
<2> Scripting supports parameterized input just like any other script
<3> We can mix scripts and fields
script to generate the values on which the t-test is calculated.
<2> Scripting supports parameterized input just like any other script.
<3> We can mix scripts and fields.

Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationBuilder;
@@ -51,8 +52,8 @@ public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationB
ObjectParser.fromBuilder(NAME, WeightedAvgAggregationBuilder::new);
static {
MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false, false);
MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false, false);
}

public WeightedAvgAggregationBuilder(String name) {
@@ -99,10 +100,11 @@ public BucketCardinality bucketCardinality() {

@Override
protected MultiValuesSourceAggregatorFactory<Numeric> innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig<Numeric>> configs,
DocValueFormat format,
AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
Map<String, ValuesSourceConfig<Numeric>> configs,
Map<String, QueryBuilder> filters,
DocValueFormat format,
AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException {
return new WeightedAvgAggregatorFactory(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata);
}

Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
@@ -168,13 +169,15 @@ protected final MultiValuesSourceAggregatorFactory<VS> doBuild(QueryShardContext
ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType;

Map<String, ValuesSourceConfig<VS>> configs = new HashMap<>(fields.size());
Map<String, QueryBuilder> filters = new HashMap<>(fields.size());
fields.forEach((key, value) -> {
ValuesSourceConfig<VS> config = ValuesSourceConfig.resolve(queryShardContext, finalValueType,
value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format);
configs.put(key, config);
filters.put(key, value.getFilter());
});
DocValueFormat docValueFormat = resolveFormat(format, finalValueType);
return innerBuild(queryShardContext, configs, docValueFormat, parent, subFactoriesBuilder);
return innerBuild(queryShardContext, configs, filters, docValueFormat, parent, subFactoriesBuilder);
}


@@ -191,6 +194,7 @@ private static DocValueFormat resolveFormat(@Nullable String format, @Nullable V

protected abstract MultiValuesSourceAggregatorFactory<VS> innerBuild(QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig<VS>> configs,
Map<String, QueryBuilder> filters,
DocValueFormat format, AggregatorFactory parent,
Builder subFactoriesBuilder) throws IOException;

Original file line number Diff line number Diff line change
@@ -30,26 +30,30 @@
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.Script;

import java.io.IOException;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Objects;
import java.util.function.BiFunction;

public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject {
private String fieldName;
private Object missing;
private Script script;
private ZoneId timeZone;
private final String fieldName;
private final Object missing;
private final Script script;
private final ZoneId timeZone;
private final QueryBuilder filter;

private static final String NAME = "field_config";

public static final BiFunction<Boolean, Boolean, ObjectParser<MultiValuesSourceFieldConfig.Builder, Void>> PARSER
= (scriptable, timezoneAware) -> {
public static final ParseField FILTER = new ParseField("filter");

ObjectParser<MultiValuesSourceFieldConfig.Builder, Void> parser
public static <C> ObjectParser<MultiValuesSourceFieldConfig.Builder, C> parserBuilder(boolean scriptable, boolean timezoneAware,
boolean filtered) {

ObjectParser<MultiValuesSourceFieldConfig.Builder, C> parser
= new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new);

parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD);
@@ -71,14 +75,21 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentObject
}
}, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG);
}

if (filtered) {
parser.declareField(MultiValuesSourceFieldConfig.Builder::setFilter,
(p, context) -> AbstractQueryBuilder.parseInnerQueryBuilder(p),
FILTER, ObjectParser.ValueType.OBJECT);
}
return parser;
};

private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone) {
protected MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, ZoneId timeZone, QueryBuilder filter) {
this.fieldName = fieldName;
this.missing = missing;
this.script = script;
this.timeZone = timeZone;
this.filter = filter;
}

public MultiValuesSourceFieldConfig(StreamInput in) throws IOException {
@@ -94,6 +105,11 @@ public MultiValuesSourceFieldConfig(StreamInput in) throws IOException {
} else {
this.timeZone = in.readOptionalZoneId();
}
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
} else {
this.filter = null;
}
}

public Object getMissing() {
@@ -112,6 +128,10 @@ public String getFieldName() {
return fieldName;
}

public QueryBuilder getFilter() {
return filter;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
@@ -126,6 +146,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeOptionalZoneId(timeZone);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeOptionalNamedWriteable(filter);
}
}

@Override
@@ -143,6 +166,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (timeZone != null) {
builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone.getId());
}
if (filter != null) {
builder.field(FILTER.getPreferredName());
filter.toXContent(builder, params);
}
builder.endObject();
return builder;
}
@@ -155,12 +182,13 @@ public boolean equals(Object o) {
return Objects.equals(fieldName, that.fieldName)
&& Objects.equals(missing, that.missing)
&& Objects.equals(script, that.script)
&& Objects.equals(timeZone, that.timeZone);
&& Objects.equals(timeZone, that.timeZone)
&& Objects.equals(filter, that.filter);
}

@Override
public int hashCode() {
return Objects.hash(fieldName, missing, script, timeZone);
return Objects.hash(fieldName, missing, script, timeZone, filter);
}

@Override
@@ -173,6 +201,7 @@ public static class Builder {
private Object missing = null;
private Script script = null;
private ZoneId timeZone = null;
private QueryBuilder filter = null;

public String getFieldName() {
return fieldName;
@@ -210,6 +239,11 @@ public Builder setTimeZone(ZoneId timeZone) {
return this;
}

public Builder setFilter(QueryBuilder filter) {
this.filter = filter;
return this;
}

public MultiValuesSourceFieldConfig build() {
if (Strings.isNullOrEmpty(fieldName) && script == null) {
throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName()
@@ -223,7 +257,7 @@ public MultiValuesSourceFieldConfig build() {
"Please specify one or the other.");
}

return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone);
return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone, filter);
}
}
}
Original file line number Diff line number Diff line change
@@ -50,10 +50,10 @@ public static <VS extends ValuesSource, T> void declareCommon(

public static <VS extends ValuesSource, T> void declareField(String fieldName,
AbstractObjectParser<? extends MultiValuesSourceAggregationBuilder<VS, ?>, T> objectParser,
boolean scriptable, boolean timezoneAware) {
boolean scriptable, boolean timezoneAware, boolean filterable) {

objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()),
(p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null),
(p, c) -> MultiValuesSourceFieldConfig.parserBuilder(scriptable, timezoneAware, filterable).parse(p, null),
new ParseField(fieldName), ObjectParser.ValueType.OBJECT);
}
}
Original file line number Diff line number Diff line change
@@ -19,30 +19,38 @@

package org.elasticsearch.search.aggregations.support;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
import java.time.ZoneId;
import java.util.Collections;

import static org.hamcrest.Matchers.equalTo;

public class MultiValuesSourceFieldConfigTests extends AbstractSerializingTestCase<MultiValuesSourceFieldConfig> {

@Override
protected MultiValuesSourceFieldConfig doParseInstance(XContentParser parser) throws IOException {
return MultiValuesSourceFieldConfig.PARSER.apply(true, true).apply(parser, null).build();
return MultiValuesSourceFieldConfig.parserBuilder(true, true, true).apply(parser, null).build();
}

@Override
protected MultiValuesSourceFieldConfig createTestInstance() {
String field = randomAlphaOfLength(10);
Object missing = randomBoolean() ? randomAlphaOfLength(10) : null;
ZoneId timeZone = randomBoolean() ? randomZone() : null;
QueryBuilder filter = randomBoolean() ? QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null;
return new MultiValuesSourceFieldConfig.Builder()
.setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).build();
.setFieldName(field).setMissing(missing).setScript(null).setTimeZone(timeZone).setFilter(filter).build();
}

@Override
@@ -60,4 +68,14 @@ public void testBothFieldScript() {
() -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build());
assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other."));
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
}
}
Original file line number Diff line number Diff line change
@@ -55,7 +55,8 @@ public class TopMetricsAggregationBuilder extends AbstractAggregationBuilder<Top
PARSER.declareField(constructorArg(), (p, n) -> SortBuilder.fromXContent(p), SORT_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING);
PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD);
ContextParser<Void, MultiValuesSourceFieldConfig.Builder> metricParser = MultiValuesSourceFieldConfig.PARSER.apply(true, false);
ContextParser<Void, MultiValuesSourceFieldConfig.Builder> metricParser =
MultiValuesSourceFieldConfig.parserBuilder(true, false, false);
PARSER.declareObjectArray(constructorArg(), (p, n) -> metricParser.parse(p, null).build(), METRIC_FIELD);
}

Original file line number Diff line number Diff line change
@@ -12,11 +12,13 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.FieldContext;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
@@ -41,11 +43,10 @@ public class TTestAggregationBuilder extends MultiValuesSourceAggregationBuilder

static {
MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC);
MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false);
MultiValuesSourceParseHelper.declareField(A_FIELD.getPreferredName(), PARSER, true, false, true);
MultiValuesSourceParseHelper.declareField(B_FIELD.getPreferredName(), PARSER, true, false, true);
PARSER.declareString(TTestAggregationBuilder::testType, TYPE_FIELD);
PARSER.declareInt(TTestAggregationBuilder::tails, TAILS_FIELD);

}

private TTestType testType = TTestType.HETEROSCEDASTIC;
@@ -117,10 +118,26 @@ protected void innerWriteTo(StreamOutput out) throws IOException {
protected MultiValuesSourceAggregatorFactory<ValuesSource.Numeric> innerBuild(
QueryShardContext queryShardContext,
Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs,
Map<String, QueryBuilder> filters,
DocValueFormat format,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
return new TTestAggregatorFactory(name, configs, testType, tails, format, queryShardContext, parent, subFactoriesBuilder, metadata);
QueryBuilder filterA = filters.get(A_FIELD.getPreferredName());
QueryBuilder filterB = filters.get(B_FIELD.getPreferredName());
if (filterA == null && filterB == null) {
FieldContext fieldContextA = configs.get(A_FIELD.getPreferredName()).fieldContext();
FieldContext fieldContextB = configs.get(B_FIELD.getPreferredName()).fieldContext();
if (fieldContextA != null && fieldContextB != null) {
if (fieldContextA.field().equals(fieldContextB.field())) {
throw new IllegalArgumentException("The same field [" + fieldContextA.field() +
"] is used for both population but no filters are specified.");
}
}
}

return new TTestAggregatorFactory(name, configs, testType, tails,
filterA, filterB, format, queryShardContext, parent,
subFactoriesBuilder, metadata);
}

@Override
Original file line number Diff line number Diff line change
@@ -6,8 +6,15 @@

package org.elasticsearch.xpack.analytics.ttest;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationInitializationException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -24,14 +31,20 @@ class TTestAggregatorFactory extends MultiValuesSourceAggregatorFactory<ValuesSo

private final TTestType testType;
private final int tails;
private final Query filterA;
private final Query filterB;
private Tuple<Weight, Weight> weights;

TTestAggregatorFactory(String name, Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, TTestType testType, int tails,
QueryBuilder filterA, QueryBuilder filterB,
DocValueFormat format, QueryShardContext queryShardContext, AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metadata) throws IOException {
super(name, configs, format, queryShardContext, parent, subFactoriesBuilder, metadata);
this.testType = testType;
this.tails = tails;
this.filterA = filterA == null ? null : filterA.toQuery(queryShardContext);
this.filterB = filterB == null ? null : filterB.toQuery(queryShardContext);
}

@Override
@@ -42,9 +55,9 @@ protected Aggregator createUnmapped(SearchContext searchContext,
case PAIRED:
return new PairedTTestAggregator(name, null, tails, format, searchContext, parent, metadata);
case HOMOSCEDASTIC:
return new UnpairedTTestAggregator(name, null, tails, true, format, searchContext, parent, metadata);
return new UnpairedTTestAggregator(name, null, tails, true, this::getWeights, format, searchContext, parent, metadata);
case HETEROSCEDASTIC:
return new UnpairedTTestAggregator(name, null, tails, false, format, searchContext, parent, metadata);
return new UnpairedTTestAggregator(name, null, tails, false, this::getWeights, format, searchContext, parent, metadata);
default:
throw new IllegalArgumentException("Unsupported t-test type " + testType);
}
@@ -64,13 +77,46 @@ protected Aggregator doCreateInternal(SearchContext searchContext,
}
switch (testType) {
case PAIRED:
if (filterA != null || filterB != null) {
throw new IllegalArgumentException("Paired t-test doesn't support filters");
}
return new PairedTTestAggregator(name, numericMultiVS, tails, format, searchContext, parent, metadata);
case HOMOSCEDASTIC:
return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, format, searchContext, parent, metadata);
return new UnpairedTTestAggregator(name, numericMultiVS, tails, true, this::getWeights, format, searchContext, parent,
metadata);
case HETEROSCEDASTIC:
return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, format, searchContext, parent, metadata);
return new UnpairedTTestAggregator(name, numericMultiVS, tails, false, this::getWeights, format, searchContext,
parent, metadata);
default:
throw new IllegalArgumentException("Unsupported t-test type " + testType);
}
}

/**
* Returns the {@link Weight}s for these filters, creating them if
* necessary. This is done lazily so that the {@link Weight}s are only created
* if the aggregation collects documents, reducing the overhead of the
* aggregation in the case where no documents are collected.
*
* Note that as aggregations are initialised and executed in a serial manner,
* no concurrency considerations are necessary here.
*/
public Tuple<Weight, Weight> getWeights() {
if (weights == null) {
weights = new Tuple<>(getWeight(filterA), getWeight(filterB));
}
return weights;
}

public Weight getWeight(Query filter) {
if (filter != null) {
IndexSearcher contextSearcher = queryShardContext.searcher();
try {
return contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f);
} catch (IOException e) {
throw new AggregationInitializationException("Failed to initialize filter", e);
}
}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -7,7 +7,11 @@
package org.elasticsearch.xpack.analytics.ttest;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
@@ -20,6 +24,7 @@

import java.io.IOException;
import java.util.Map;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.A_FIELD;
import static org.elasticsearch.xpack.analytics.ttest.TTestAggregationBuilder.B_FIELD;
@@ -28,14 +33,16 @@ public class UnpairedTTestAggregator extends TTestAggregator<UnpairedTTestState>
private final TTestStatsBuilder a;
private final TTestStatsBuilder b;
private final boolean homoscedastic;
private final Supplier<Tuple<Weight, Weight>> weightsSupplier;

UnpairedTTestAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, int tails, boolean homoscedastic,
DocValueFormat format, SearchContext context, Aggregator parent,
Map<String, Object> metadata) throws IOException {
Supplier<Tuple<Weight, Weight>> weightsSupplier, DocValueFormat format, SearchContext context,
Aggregator parent, Map<String, Object> metadata) throws IOException {
super(name, valuesSources, tails, format, context, parent, metadata);
BigArrays bigArrays = context.bigArrays();
a = new TTestStatsBuilder(bigArrays);
b = new TTestStatsBuilder(bigArrays);
this.weightsSupplier = weightsSupplier;
this.homoscedastic = homoscedastic;
}

@@ -67,6 +74,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final CompensatedSum compSumOfSqrA = new CompensatedSum(0, 0);
final CompensatedSum compSumB = new CompensatedSum(0, 0);
final CompensatedSum compSumOfSqrB = new CompensatedSum(0, 0);
final Tuple<Weight, Weight> weights = weightsSupplier.get();
final Bits bitsA = getBits(ctx, weights.v1());
final Bits bitsB = getBits(ctx, weights.v2());

return new LeafBucketCollectorBase(sub, docAValues) {

@@ -82,14 +92,25 @@ private void processValues(int doc, long bucket, SortedNumericDoubleValues docVa

@Override
public void collect(int doc, long bucket) throws IOException {
a.grow(bigArrays, bucket + 1);
b.grow(bigArrays, bucket + 1);
processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a);
processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b);
if (bitsA == null || bitsA.get(doc)) {
a.grow(bigArrays, bucket + 1);
processValues(doc, bucket, docAValues, compSumA, compSumOfSqrA, a);
}
if (bitsB == null || bitsB.get(doc)) {
processValues(doc, bucket, docBValues, compSumB, compSumOfSqrB, b);
b.grow(bigArrays, bucket + 1);
}
}
};
}

private Bits getBits(LeafReaderContext ctx, Weight weight) throws IOException {
if (weight == null) {
return null;
}
return Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), weight.scorerSupplier(ctx));
}

@Override
public void doClose() {
Releasables.close(a, b);
Original file line number Diff line number Diff line change
@@ -7,19 +7,25 @@
package org.elasticsearch.xpack.analytics.ttest;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.BaseAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static java.util.Collections.singletonList;
import static org.hamcrest.Matchers.hasSize;

public class TTestAggregationBuilderTests extends AbstractSerializingTestCase<TTestAggregationBuilder> {
@@ -30,14 +36,6 @@ public void setupName() {
aggregationName = randomAlphaOfLength(10);
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(singletonList(new NamedXContentRegistry.Entry(
BaseAggregationBuilder.class,
new ParseField(TTestAggregationBuilder.NAME),
(p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n))));
}

@Override
protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws IOException {
assertSame(XContentParser.Token.START_OBJECT, parser.nextToken());
@@ -52,26 +50,33 @@ protected TTestAggregationBuilder doParseInstance(XContentParser parser) throws

@Override
protected TTestAggregationBuilder createTestInstance() {
MultiValuesSourceFieldConfig aConfig;
MultiValuesSourceFieldConfig.Builder aConfig;
TTestType tTestType = randomFrom(TTestType.values());
if (randomBoolean()) {
aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field").build();
aConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("a_field");
} else {
aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build();
aConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10)));
}
MultiValuesSourceFieldConfig bConfig;
MultiValuesSourceFieldConfig.Builder bConfig;
if (randomBoolean()) {
bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field").build();
bConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("b_field");
} else {
bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10))).build();
bConfig = new MultiValuesSourceFieldConfig.Builder().setScript(new Script(randomAlphaOfLength(10)));
}
if (tTestType != TTestType.PAIRED && randomBoolean()) {
aConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10)));
}
if (tTestType != TTestType.PAIRED && randomBoolean()) {
bConfig.setFilter(QueryBuilders.queryStringQuery(randomAlphaOfLength(10)));
}
TTestAggregationBuilder aggregationBuilder = new TTestAggregationBuilder(aggregationName)
.a(aConfig)
.b(bConfig);
.a(aConfig.build())
.b(bConfig.build());
if (randomBoolean()) {
aggregationBuilder.tails(randomIntBetween(1, 2));
}
if (randomBoolean()) {
aggregationBuilder.testType(randomFrom(TTestType.values()));
if (tTestType != TTestType.HETEROSCEDASTIC || randomBoolean()) {
aggregationBuilder.testType(randomFrom(tTestType));
}
return aggregationBuilder;
}
@@ -80,5 +85,21 @@ protected TTestAggregationBuilder createTestInstance() {
protected Writeable.Reader<TTestAggregationBuilder> instanceReader() {
return TTestAggregationBuilder::new;
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
    // The default search module registers all writeables this aggregation needs.
    SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
    return new NamedWriteableRegistry(searchModule.getNamedWriteables());
}

@Override
protected NamedXContentRegistry xContentRegistry() {
    // Register the t_test parser on top of the default search xcontent entries.
    NamedXContentRegistry.Entry tTestEntry = new NamedXContentRegistry.Entry(
        BaseAggregationBuilder.class,
        new ParseField(TTestAggregationBuilder.NAME),
        (p, n) -> TTestAggregationBuilder.PARSER.apply(p, (String) n));
    List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
    entries.add(tTestEntry);
    entries.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
    return new NamedXContentRegistry(entries);
}
}

Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

package org.elasticsearch.xpack.analytics.ttest;

import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
@@ -21,6 +22,7 @@
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
@@ -56,11 +58,15 @@ public class TTestAggregatorTests extends AggregatorTestCase {
*/
public static final String ADD_HALF_SCRIPT = "add_one";

public static final String TERM_FILTERING = "term_filtering";

@Override
protected AggregationBuilder createAggBuilderForTypeTest(MappedFieldType fieldType, String fieldName) {
return new TTestAggregationBuilder("foo")
.a(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build())
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName).build());
.a(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName)
.setFilter(QueryBuilders.rangeQuery(fieldName).lt(10)).build())
.b(new MultiValuesSourceFieldConfig.Builder().setFieldName(fieldName)
.setFilter(QueryBuilders.rangeQuery(fieldName).gte(10)).build());
}

@Override
@@ -71,11 +77,18 @@ protected ScriptService getMockScriptService() {
LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc");
String fieldname = (String) vars.get("fieldname");
ScriptDocValues<?> scriptDocValues = leafDocLookup.get(fieldname);
double val = ((Number) scriptDocValues.get(0)).doubleValue();
if (val == 1) {
val += 0.0000001;
return ((Number) scriptDocValues.get(0)).doubleValue() + 0.5;
});

scripts.put(TERM_FILTERING, vars -> {
LeafDocLookup leafDocLookup = (LeafDocLookup) vars.get("doc");
int term = (Integer) vars.get("term");
ScriptDocValues<?> termDocValues = leafDocLookup.get("term");
int currentTerm = ((Number) termDocValues.get(0)).intValue();
if (currentTerm == term) {
return ((Number) leafDocLookup.get("field").get(0)).doubleValue();
}
return val + 0.5;
return null;
});

MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
@@ -134,6 +147,26 @@ public void testMultiplePairedValues() {
ex.getMessage());
}

/**
 * Using the same field for both populations without any filters must be
 * rejected with an IllegalArgumentException, whatever the t-test type.
 */
public void testSameFieldAndNoFilters() {
    TTestType type = randomFrom(TTestType.values());
    MappedFieldType numberField = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
    numberField.setName("field");
    TTestAggregationBuilder builder = new TTestAggregationBuilder("t_test")
        .a(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build())
        .b(new MultiValuesSourceFieldConfig.Builder().setFieldName("field").setMissing(100).build())
        .testType(type);

    IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () ->
        testCase(builder, new MatchAllDocsQuery(), writer -> {
            writer.addDocument(singleton(new SortedNumericDocValuesField("field", 102)));
            writer.addDocument(singleton(new SortedNumericDocValuesField("field", 99)));
        }, tTest -> fail("Should have thrown exception"), numberField)
    );
    assertEquals(
        "The same field [field] is used for both population but no filters are specified.",
        exception.getMessage());
}

public void testMultipleUnpairedValues() throws IOException {
TTestType tTestType = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC);
testCase(new MatchAllDocsQuery(), tTestType, iw -> {
@@ -143,6 +176,15 @@ public void testMultipleUnpairedValues() throws IOException {
}, tTest -> assertEquals(tTestType == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671, tTest.getValue(), 0.000001));
}

/**
 * Unpaired test types (homo-/heteroscedastic) tolerate a document carrying
 * multiple values for one population field; expected p-values match
 * {@code testMultipleUnpairedValues}.
 */
public void testUnpairedValuesWithFilters() throws IOException {
    TTestType type = randomFrom(TTestType.HETEROSCEDASTIC, TTestType.HOMOSCEDASTIC);
    double expected = type == TTestType.HETEROSCEDASTIC ? 0.0607303911 : 0.01718374671;
    testCase(new MatchAllDocsQuery(), type, writer -> {
        writer.addDocument(asList(new SortedNumericDocValuesField("a", 102), new SortedNumericDocValuesField("a", 103),
            new SortedNumericDocValuesField("b", 89)));
        writer.addDocument(asList(new SortedNumericDocValuesField("a", 99), new SortedNumericDocValuesField("b", 93)));
    }, tTest -> assertEquals(expected, tTest.getValue(), 0.000001));
}

public void testMissingValues() throws IOException {
TTestType tTestType = randomFrom(TTestType.values());
testCase(new MatchAllDocsQuery(), tTestType, iw -> {
@@ -426,12 +468,12 @@ public void testScript() throws IOException {
a(fieldInA ? a : b).b(fieldInA ? b : a).testType(tTestType);

testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
iw.addDocument(singleton(new NumericDocValuesField("field", 1)));
iw.addDocument(singleton(new NumericDocValuesField("field", 2)));
iw.addDocument(singleton(new NumericDocValuesField("field", 3)));
}, (Consumer<InternalTTest>) tTest -> {
assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001);
}, fieldType);
iw.addDocument(singleton(new NumericDocValuesField("field", 1)));
iw.addDocument(singleton(new NumericDocValuesField("field", 2)));
iw.addDocument(singleton(new NumericDocValuesField("field", 3)));
}, (Consumer<InternalTTest>) tTest -> {
assertEquals(tTestType == TTestType.PAIRED ? 0 : 0.5733922538, tTest.getValue(), 0.000001);
}, fieldType);
}

public void testPaired() throws IOException {
@@ -484,7 +526,6 @@ public void testHomoscedastic() throws IOException {
}, fieldType1, fieldType2);
}


public void testHeteroscedastic() throws IOException {
MappedFieldType fieldType1 = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType1.setName("a");
@@ -512,6 +553,93 @@ public void testHeteroscedastic() throws IOException {
}, fieldType1, fieldType2);
}

/**
 * Splits a single field "a" into two populations via term filters on "b"
 * (b == 1 vs b == 2); documents with b == 3 fall outside both filters.
 * Paired t-tests must reject filters outright.
 */
public void testFiltered() throws IOException {
    TTestType type = randomFrom(TTestType.values());
    MappedFieldType aField = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
    aField.setName("a");
    MappedFieldType bField = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
    bField.setName("b");
    TTestAggregationBuilder builder = new TTestAggregationBuilder("t_test")
        .a(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 1)).build())
        .b(new MultiValuesSourceFieldConfig.Builder().setFieldName("a").setFilter(QueryBuilders.termQuery("b", 2)).build())
        .testType(type);
    int tails = randomIntBetween(1, 2);
    if (tails == 1 || randomBoolean()) {
        builder.tails(tails);
    }
    CheckedConsumer<RandomIndexWriter, IOException> indexer = writer -> {
        int[][] valuesByTerm = {
            {102, 99, 111, 97, 101, 99},    // b == 1: population a
            {89, 93, 72, 98, 102, 98},      // b == 2: population b
            {189, 193, 172, 198, 1102, 198} // b == 3: excluded by both filters
        };
        for (int term = 0; term < valuesByTerm.length; term++) {
            for (int value : valuesByTerm[term]) {
                writer.addDocument(asList(new NumericDocValuesField("a", value), new IntPoint("b", term + 1)));
            }
        }
    };
    if (type == TTestType.PAIRED) {
        IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () ->
            testCase(builder, new MatchAllDocsQuery(), indexer, tTest -> fail("Should have thrown exception"),
                aField, bField)
        );
        assertEquals("Paired t-test doesn't support filters", exception.getMessage());
    } else {
        double expected = (type == TTestType.HOMOSCEDASTIC ? 0.03928288693 : 0.04538666214) * tails;
        testCase(builder, new MatchAllDocsQuery(), indexer, (Consumer<InternalTTest>) result ->
            assertEquals(expected, result.getValue(), 0.00001), aField, bField);
    }
}

/**
 * One population selected by a term filter, the other by the
 * {@code TERM_FILTERING} mock script; which side gets which is randomized,
 * as is the term each mechanism selects. The p-value must be the same
 * either way.
 */
public void testFilterByFilterOrScript() throws IOException {
    boolean filterConfigInA = randomBoolean();
    TTestType type = randomFrom(TTestType.HOMOSCEDASTIC, TTestType.HETEROSCEDASTIC);

    MappedFieldType valueField = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
    valueField.setName("field");
    MappedFieldType termField = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
    termField.setName("term");

    boolean filterSelectsTermOne = randomBoolean();

    MultiValuesSourceFieldConfig filtered = new MultiValuesSourceFieldConfig.Builder()
        .setFieldName("field")
        .setFilter(QueryBuilders.termQuery("term", filterSelectsTermOne ? 1 : 2))
        .build();
    // The script selects the complementary term, so together they cover all docs.
    MultiValuesSourceFieldConfig scripted = new MultiValuesSourceFieldConfig.Builder()
        .setScript(new Script(ScriptType.INLINE, MockScriptEngine.NAME, TERM_FILTERING,
            Collections.singletonMap("term", filterSelectsTermOne ? 2 : 1)))
        .build();

    TTestAggregationBuilder builder = new TTestAggregationBuilder("t_test")
        .a(filterConfigInA ? filtered : scripted)
        .b(filterConfigInA ? scripted : filtered)
        .testType(type);

    testCase(builder, new MatchAllDocsQuery(), writer -> {
        // Values 1..3 belong to term 1, values 4..6 to term 2.
        for (int value = 1; value <= 6; value++) {
            int term = value <= 3 ? 1 : 2;
            writer.addDocument(asList(new NumericDocValuesField("field", value),
                new IntPoint("term", term), new NumericDocValuesField("term", term)));
        }
    }, (Consumer<InternalTTest>) result ->
        assertEquals(0.02131164113, result.getValue(), 0.000001), valueField, termField);
}

private void testCase(Query query, TTestType type,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalTTest> verify) throws IOException {
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# REST tests for the t_test aggregation: heteroscedastic (default), paired,
# and homoscedastic-with-filters variants over a small fixed data set.
---
setup:
- do:
bulk:
index: test
refresh: true
# Four documents: paired observations (v1, v2) plus a keyword "str"
# splitting them into groups "a" and "b" for the filtered test below.
body:
- '{"index": {}}'
- '{"v1": 15.2, "v2": 15.9, "str": "a"}'
- '{"index": {}}'
- '{"v1": 15.3, "v2": 15.9, "str": "a"}'
- '{"index": {}}'
- '{"v1": 16.0, "v2": 15.2, "str": "b"}'
- '{"index": {}}'
- '{"v1": 15.1, "v2": 15.5, "str": "b"}'
---
# No explicit "type": exercises the default (heteroscedastic) t-test.
"heteroscedastic t-test":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
b:
field: v2
- match: { aggregations.ttest.value: 0.43066659210472646 }

---
# Paired variant: compares v1 and v2 within the same documents.
"paired t-test":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
b:
field: v2
type: paired

- match: { aggregations.ttest.value: 0.5632529432617406 }

---
# Both populations read the same field (v1); term filters on str.keyword
# split the documents into the two groups being compared.
"homoscedastic t-test with filters":
- do:
search:
size: 0
index: "test"
body:
aggs:
ttest:
t_test:
a:
field: v1
filter:
term:
str.keyword: a
b:
field: v1
filter:
term:
str.keyword: b
type: homoscedastic

- match: { aggregations.ttest.value: 0.5757355806262943 }

0 comments on commit 51c6f69

Please sign in to comment.