From 1b9b48b751188574f8b8fcfe599824d716ed45e2 Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Thu, 30 May 2019 09:48:29 -0500
Subject: [PATCH] [ML] [Data Frame] add support for weighted_avg agg (#42646)

Allow `weighted_avg` to be used as a pivot aggregation in data frame
transforms.

* Aggregations: register WEIGHTED_AVG with the dynamic target mapping.
* SchemaUtil: when deducing mappings, record the aggregation type of
  MultiValuesSourceAggregationBuilder aggregations in the same branch as
  scripted_metric (only the type is recorded, no source field name).
* Tests: cover target-mapping resolution, the pivot aggregation config
  and an end-to-end pivot over the reviews index.
---
 .../integration/DataFramePivotRestIT.java     | 44 +++++++++++++++++++
 .../transforms/pivot/Aggregations.java        |  1 +
 .../transforms/pivot/SchemaUtil.java          |  3 +-
 .../transforms/pivot/AggregationsTests.java   |  4 ++
 .../transforms/pivot/PivotTests.java          | 10 +++++
 5 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java
index a0bec6ec13c34..3c661a0f4aca4 100644
--- a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java
+++ b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java
@@ -473,6 +473,50 @@ public void testPivotWithGeoCentroidAgg() throws Exception {
         assertEquals((4 + 15), Double.valueOf(latlon[1]), 0.000001);
     }
 
+    /**
+     * Creates and starts a pivot transform grouped by {@code user_id} whose only aggregation is a
+     * {@code weighted_avg}, then verifies the value indexed for {@code user_4}.
+     */
+    public void testPivotWithWeightedAvgAgg() throws Exception {
+        String transformId = "weightedAvgAggTransform";
+        String dataFrameIndex = "weighted_avg_pivot_reviews";
+        setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);
+
+        final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
+            BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
+
+        String config = "{"
+            + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+            + " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";
+
+        // "stars" is both the value and the weight, so the expected result is sum(stars^2) / sum(stars).
+        config += " \"pivot\": {"
+            + " \"group_by\": {"
+            + " \"reviewer\": {"
+            + " \"terms\": {"
+            + " \"field\": \"user_id\""
+            + " } } },"
+            + " \"aggregations\": {"
+            + " \"avg_rating\": {"
+            + " \"weighted_avg\": {"
+            + " \"value\": {\"field\": \"stars\"},"
+            + " \"weight\": {\"field\": \"stars\"}"
+            + "} } } }"
+            + "}";
+
+        createDataframeTransformRequest.setJsonEntity(config);
+        Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
+        assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
+
+        startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
+        assertTrue(indexExists(dataFrameIndex));
+
+        Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
+        assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
+        Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
+        assertEquals(4.47169811, actual.doubleValue(), 0.000001);
+    }
+
     private void assertOnePivotValue(String query, double expected) throws IOException {
         Map searchResult = getAsMap(query);
 
diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java
index 615c9b2e8d2e6..4e74f9085e3a6 100644
--- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java
+++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Aggregations.java
@@ -37,6 +37,7 @@ enum AggregationType {
         SUM("sum", SOURCE),
         GEO_CENTROID("geo_centroid", "geo_point"),
         SCRIPTED_METRIC("scripted_metric", DYNAMIC),
+        WEIGHTED_AVG("weighted_avg", DYNAMIC),
         BUCKET_SCRIPT("bucket_script", DYNAMIC);
 
         private final String aggregationType;
diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java
index 304f35b8c4c8a..4ac77c38f7d5f 100644
--- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java
+++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/SchemaUtil.java
@@ -17,6 +17,7 @@
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
+import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
@@ -77,7 +78,7 @@ public static void deduceMappings(final Client client,
                 ValuesSourceAggregationBuilder valueSourceAggregation = (ValuesSourceAggregationBuilder) agg;
                 aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
                 aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
-            } else if(agg instanceof ScriptedMetricAggregationBuilder) {
+            } else if(agg instanceof ScriptedMetricAggregationBuilder || agg instanceof MultiValuesSourceAggregationBuilder) {
                 aggregationTypes.put(agg.getName(), agg.getType());
             } else {
                 // execution should not reach this point
diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java
index 8443699430a2a..ace42cb65fcaf 100644
--- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java
+++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationsTests.java
@@ -49,5 +49,9 @@ public void testResolveTargetMapping() {
         // bucket_script
         assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
         assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));
+
+        // weighted_avg
+        assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", null));
+        assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", "double"));
     }
 }
diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java
index 20ea84502ed82..d54cbad97f726 100644
--- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java
+++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/pivot/PivotTests.java
@@ -215,6 +215,16 @@ private AggregationConfig getAggregationConfig(String agg) throws IOException {
                     "\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
                     "\"script\":\"return params.param_1\"}}}");
         }
+        if (agg.equals(AggregationType.WEIGHTED_AVG.getName())) {
+            return parseAggregations("{\n" +
+                "\"pivot_weighted_avg\": {\n" +
+                " \"weighted_avg\": {\n" +
+                " \"value\": {\"field\": \"values\"},\n" +
+                " \"weight\": {\"field\": \"weights\"}\n" +
+                " }\n" +
+                "}\n" +
+                "}");
+        }
         return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
             + " }\n" + " }" + "}");
     }