Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] [Data Frame] add support for weighted_avg agg #42646

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,45 @@ public void testPivotWithGeoCentroidAgg() throws Exception {
assertEquals((4 + 15), Double.valueOf(latlon[1]), 0.000001);
}

public void testPivotWithWeightedAvgAgg() throws Exception {
String transformId = "weightedAvgAggTransform";
String dataFrameIndex = "weighted_avg_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";

config += " \"pivot\": {"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to do concatenation, can be done in declaration of "String config".

+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"weighted_avg\": {"
+ " \"value\": {\"field\": \"stars\"},"
+ " \"weight\": {\"field\": \"stars\"}"
+ "} } } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);
assertTrue(indexExists(dataFrameIndex));

Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
assertEquals(4.47169811, actual.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum AggregationType {
SUM("sum", SOURCE),
GEO_CENTROID("geo_centroid", "geo_point"),
SCRIPTED_METRIC("scripted_metric", DYNAMIC),
WEIGHTED_AVG("weighted_avg", DYNAMIC),
BUCKET_SCRIPT("bucket_script", DYNAMIC);

private final String aggregationType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
Expand Down Expand Up @@ -77,7 +78,7 @@ public static void deduceMappings(final Client client,
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
} else if(agg instanceof ScriptedMetricAggregationBuilder || agg instanceof MultiValuesSourceAggregationBuilder) {
aggregationTypes.put(agg.getName(), agg.getType());
} else {
// execution should not reach this point
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,9 @@ public void testResolveTargetMapping() {
// bucket_script
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("bucket_script", "int"));

// weighted_avg
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("weighted_avg", "double"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ private AggregationConfig getAggregationConfig(String agg) throws IOException {
"\"buckets_path\":{\"param_1\":\"other_bucket\"}," +
"\"script\":\"return params.param_1\"}}}");
}
if (agg.equals(AggregationType.WEIGHTED_AVG.getName())) {
return parseAggregations("{\n" +
"\"pivot_weighted_avg\": {\n" +
" \"weighted_avg\": {\n" +
" \"value\": {\"field\": \"values\"},\n" +
" \"weight\": {\"field\": \"weights\"}\n" +
" }\n" +
"}\n" +
"}");
}
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
+ " }\n" + " }" + "}");
}
Expand Down