Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WeightedAvg metric aggregation #31037

Merged
merged 15 commits into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ As a formula, a weighted average is the `∑(value * weight) / ∑(weight)`

A regular average can be thought of as a weighted average where every value has an implicit weight of `1`.

[NOTE]
======
While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters
a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception.
If you have this situation, you will need to specify a `script` for the weight field, and use the script
to combine the multiple values into a single value to be used.

This single weight will be applied independently to each value extracted from the `value` field.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if its worth having an example of a single weight being applied to each value independantly to help solidify what we mean?

======


.`weighted_avg` Parameters
|===
Expand All @@ -29,7 +39,6 @@ The `value` and `weight` objects have per-field specific configuration:
|Parameter Name |Description |Required |Default Value
|`field` | The field that values should be extracted from |Required |
|`missing` | A value to use if the field is missing entirely |Optional |
|`multi` | If a document has multiple values for the field, how should the values be combined |Optional | `avg`
|`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional
|===

Expand All @@ -38,7 +47,6 @@ The `value` and `weight` objects have per-field specific configuration:
|Parameter Name |Description |Required |Default Value
|`field` | The field that weights should be extracted from |Required |
|`missing` | A weight to use if the field is missing entirely |Optional |
|`multi` | If a document has multiple weights for the field, how should the weights be combined |Optional | `avg`
|`script` | A script which provides the weights for the document. This is mutually exclusive with `field` |Optional
|===

Expand Down Expand Up @@ -148,38 +156,3 @@ POST /exams/_search
// CONSOLE
// TEST[setup:exams]

==== Multi-value mode

If a document has multiple values, you can configure the `multi` mode of both `value` and `weight`. This controls
how the multiple values should be combined when calculating the average. Acceptable values are:

- `avg`: average the multiple values together
- `min`: use the minimum value
- `max`: use the maximum value
- `sum`: sum all the values together

The default if unspecified is `avg`.

[source,js]
--------------------------------------------------
POST /exams/_search
{
"size": 0,
"aggs" : {
"weighted_grade": {
"weighted_avg": {
"value": {
"field": "grade",
"multi": "avg"
},
"weight": {
"field": "weight",
"multi": "min"
}
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[setup:exams]
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
Expand Down Expand Up @@ -99,10 +99,10 @@ protected void innerWriteTo(StreamOutput out) {

@Override
protected MultiValuesSourceAggregatorFactory<Numeric, ?> innerBuild(SearchContext context,
MultiValuesSourceConfig<Numeric> configs,
DocValueFormat format,
AggregatorFactory<?> parent,
Builder subFactoriesBuilder) throws IOException {
Map<String, ValuesSourceConfig<Numeric>> configs,
DocValueFormat format,
AggregatorFactory<?> parent,
Builder subFactoriesBuilder) throws IOException {
return new WeightedAvgAggregatorFactory(name, configs, format, context, parent, subFactoriesBuilder, metaData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
Expand Down Expand Up @@ -77,8 +78,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
final NumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
final NumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), 1.0, ctx);
final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);

return new LeafBucketCollectorBase(sub, docValues) {
@Override
Expand All @@ -88,13 +89,23 @@ public void collect(int doc, long bucket) throws IOException {
sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);

if (docValues.advanceExact(doc)) {
boolean advanced = docWeights.advanceExact(doc);
assert advanced;
final double weight = docWeights.doubleValue();

kahanSum(docValues.doubleValue() * weight, sums, sumCompensations, bucket);
kahanSum(weight, weights, weightCompensations, bucket);
if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
if (docWeights.docValueCount() > 1) {
throw new AggregationExecutionException("Encountered more than one weight for a " +
"single document. Use a script to combine multiple weights-per-doc into a single value.");
}
// There should always be one weight if advanceExact lands us here, either
// a real weight or a `missing` value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing value -> missing weight

assert docWeights.docValueCount() == 1;
final double weight = docWeights.nextValue();

final int numValues = docValues.docValueCount();
assert numValues > 0;

for (int i = 0; i < numValues; i++) {
kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
kahanSum(weight, weights, weightCompensations, bucket);
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.MultiValuesSource;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
Expand All @@ -36,9 +36,9 @@

public class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory<Numeric, WeightedAvgAggregatorFactory> {

public WeightedAvgAggregatorFactory(String name, MultiValuesSourceConfig<Numeric> configs,
DocValueFormat format,
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
public WeightedAvgAggregatorFactory(String name, Map<String, ValuesSourceConfig<Numeric>> configs,
DocValueFormat format, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, configs, format, context, parent, subFactoriesBuilder, metaData);
}
Expand All @@ -50,8 +50,9 @@ protected Aggregator createUnmapped(Aggregator parent, List<PipelineAggregator>
}

@Override
protected Aggregator doCreateInternal(MultiValuesSourceConfig<Numeric> configs, DocValueFormat format, Aggregator parent,
boolean collectsFromSingleBucket, List<PipelineAggregator> pipelineAggregators,
protected Aggregator doCreateInternal(Map<String, ValuesSourceConfig<Numeric>> configs, DocValueFormat format,
Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
MultiValuesSource.NumericMultiValuesSource numericMultiVS
= new MultiValuesSource.NumericMultiValuesSource(configs, context.getQueryShardContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,107 +19,75 @@
package org.elasticsearch.search.aggregations.support;

import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.index.fielddata.FieldData;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.search.MultiValueMode;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
* Class to encapsulate a set of ValuesSource objects labeled by field name
*/
public abstract class MultiValuesSource <VS extends ValuesSource> {

public static class Wrapper<VS> {
private MultiValueMode multiValueMode;
private VS valueSource;

public Wrapper(MultiValueMode multiValueMode, VS value) {
this.multiValueMode = multiValueMode;
this.valueSource = value;
}

public MultiValueMode getMultiValueMode() {
return multiValueMode;
}

public VS getValueSource() {
return valueSource;
}
}

protected Map<String, Wrapper<VS>> values;
protected Map<String, VS> values;

public static class NumericMultiValuesSource extends MultiValuesSource<ValuesSource.Numeric> {
public NumericMultiValuesSource(MultiValuesSourceConfig<ValuesSource.Numeric> valuesSourceConfigs,
public NumericMultiValuesSource(Map<String, ValuesSourceConfig<ValuesSource.Numeric>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.getMap().size());
for (Map.Entry<String, MultiValuesSourceConfig.Wrapper<ValuesSource.Numeric>> entry : valuesSourceConfigs.getMap().entrySet()) {
values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(),
entry.getValue().getConfig().toValuesSource(context)));
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.Numeric>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}

public NumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException {
Wrapper<ValuesSource.Numeric> wrapper = values.get(fieldName);
if (wrapper == null) {
public SortedNumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException {
ValuesSource.Numeric value = values.get(fieldName);
if (value == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx));
}

public NumericDoubleValues getField(String fieldName, double defaultValue, LeafReaderContext ctx) throws IOException {
Wrapper<ValuesSource.Numeric> wrapper = values.get(fieldName);
if (wrapper == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return FieldData.replaceMissing(wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx)), defaultValue);
return value.doubleValues(ctx);
}
}

public static class BytesMultiValuesSource extends MultiValuesSource<ValuesSource.Bytes> {
public BytesMultiValuesSource(MultiValuesSourceConfig<ValuesSource.Bytes> valuesSourceConfigs,
public BytesMultiValuesSource(Map<String, ValuesSourceConfig<ValuesSource.Bytes>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.getMap().size());
for (Map.Entry<String, MultiValuesSourceConfig.Wrapper<ValuesSource.Bytes>> entry : valuesSourceConfigs.getMap().entrySet()) {
values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(),
entry.getValue().getConfig().toValuesSource(context)));
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.Bytes>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}

public Object getField(String fieldName, LeafReaderContext ctx) throws IOException {
Wrapper<ValuesSource.Bytes> wrapper = values.get(fieldName);
if (wrapper == null) {
ValuesSource.Bytes value = values.get(fieldName);
if (value == null) {
throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource");
}
return wrapper.getValueSource().bytesValues(ctx);
return value.bytesValues(ctx);
}
}

public static class GeoPointValuesSource extends MultiValuesSource<ValuesSource.GeoPoint> {
public GeoPointValuesSource(MultiValuesSourceConfig<ValuesSource.GeoPoint> valuesSourceConfigs,
public GeoPointValuesSource(Map<String, ValuesSourceConfig<ValuesSource.GeoPoint>> valuesSourceConfigs,
QueryShardContext context) throws IOException {
values = new HashMap<>(valuesSourceConfigs.getMap().size());
for (Map.Entry<String, MultiValuesSourceConfig.Wrapper<ValuesSource.GeoPoint>> entry : valuesSourceConfigs.getMap().entrySet()){
values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(),
entry.getValue().getConfig().toValuesSource(context)));
values = new HashMap<>(valuesSourceConfigs.size());
for (Map.Entry<String, ValuesSourceConfig<ValuesSource.GeoPoint>> entry : valuesSourceConfigs.entrySet()) {
values.put(entry.getKey(), entry.getValue().toValuesSource(context));
}
}
}


public boolean needsScores() {
return values.values().stream().anyMatch(vsWrapper -> vsWrapper.getValueSource().needsScores());
return values.values().stream().anyMatch(ValuesSource::needsScores);
}

public String[] fieldNames() {
return values.keySet().toArray(new String[0]);
}

public boolean areValuesSourcesEmpty() {
return values.values().stream().allMatch(vsWrapper -> vsWrapper.getValueSource() == null);
return values.values().stream().allMatch(Objects::isNull);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ public String format() {
AggregatorFactories.Builder subFactoriesBuilder) throws IOException {
ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType;

MultiValuesSourceConfig<VS> configs = new MultiValuesSourceConfig<>();
Map<String, ValuesSourceConfig<VS>> configs = new HashMap<>(fields.size());
fields.forEach((key, value) -> {
ValuesSourceConfig<VS> config = ValuesSourceConfig.resolve(context.getQueryShardContext(), finalValueType,
value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format);
configs.addField(key, config, value.getMulti());
configs.put(key, config);
});
DocValueFormat docValueFormat = resolveFormat(format, finalValueType);
return innerBuild(context, configs, docValueFormat, parent, subFactoriesBuilder);
Expand All @@ -209,7 +209,7 @@ private static DocValueFormat resolveFormat(@Nullable String format, @Nullable V
}

protected abstract MultiValuesSourceAggregatorFactory<VS, ?> innerBuild(SearchContext context,
MultiValuesSourceConfig<VS> configs, DocValueFormat format, AggregatorFactory<?> parent,
Map<String, ValuesSourceConfig<VS>> configs, DocValueFormat format, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder) throws IOException;


Expand Down
Loading