Skip to content

Commit

Permalink
[ML] [Data Frame] add support for weighted_avg agg (#42646) (#42714)
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent authored May 30, 2019
1 parent 711de2f commit b5527b3
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,45 @@ public void testPivotWithGeoCentroidAgg() throws Exception {
assertEquals((4 + 15), Double.valueOf(latlon[1]), 0.000001);
}

public void testPivotWithWeightedAvgAgg() throws Exception {
String transformId = "weightedAvgAggTransform";
String dataFrameIndex = "weighted_avg_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";

config += " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"weighted_avg\": {"
+ " \"value\": {\"field\": \"stars\"},"
+ " \"weight\": {\"field\": \"stars\"}"
+ "} } } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
assertTrue(indexExists(dataFrameIndex));

Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
assertEquals(4.47169811, actual.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum AggregationType {
SUM("sum", SOURCE),
GEO_CENTROID("geo_centroid", "geo_point"),
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
WEIGHTED_AVG("weighted_avg", DYNAMIC),
BUCKET_SCRIPT("bucket_script", DYNAMIC);

private final String aggregationType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
Expand Down Expand Up @@ -77,7 +78,7 @@ public static void deduceMappings(final Client client,
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
} else if(agg instanceof ScriptedMetricAggregationBuilder || agg instanceof MultiValuesSourceAggregationBuilder) {
aggregationTypes.put(agg.getName(), agg.getType());
} else {
// execution should not reach this point
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,9 @@ public void testResolveTargetMapping() {
// bucket_script
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));

// weighted_avg
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", "double"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ private AggregationConfig getAggregationConfig(String agg) throws IOException {
"\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
"\"script\":\"return params.param_1\"}}}");
}
if (agg.equals(AggregationType.WEIGHTED_AVG.getName())) {
return parseAggregations("{\n" +
"\"pivot_weighted_avg\": {\n" +
" \"weighted_avg\": {\n" +
" \"value\": {\"field\": \"values\"},\n" +
" \"weight\": {\"field\": \"weights\"}\n" +
" }\n" +
"}\n" +
"}");
}
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
+ " }\n" + " }" + "}");
}
Expand Down

0 comments on commit b5527b3

Please sign in to comment.