Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Changes default destination index field mapping and adds scripted_metric agg #40750

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,59 @@ public void testPivotWithMaxOnDateField() throws Exception {
assertThat(actual, containsString("2017-01-15T"));
}

/**
 * Pivots the reviews index grouped by reviewer, combining a plain {@code avg}
 * aggregation with a {@code scripted_metric} aggregation (sum of squared star
 * ratings), then spot-checks the resulting destination documents.
 */
public void testPivotWithScriptedMetricAgg() throws Exception {
    String id = "scriptedMetricPivot";
    String destIndex = "scripted_metric_pivot_reviews";
    setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, destIndex);

    final Request putTransform = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + id,
        BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

    // Full transform configuration: source/dest plus a pivot with a terms
    // group_by and two aggregations (avg + scripted_metric).
    String body = "{"
        + " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
        + " \"dest\": {\"index\":\"" + destIndex + "\"},"
        + " \"pivot\": {"
        + " \"group_by\": {"
        + " \"reviewer\": {"
        + " \"terms\": {"
        + " \"field\": \"user_id\""
        + " } } },"
        + " \"aggregations\": {"
        + " \"avg_rating\": {"
        + " \"avg\": {"
        + " \"field\": \"stars\""
        + " } },"
        + " \"squared_sum\": {"
        + " \"scripted_metric\": {"
        + " \"init_script\": \"state.reviews_sqrd = []\","
        + " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
        + " \"combine_script\": \"state.reviews_sqrd\","
        + " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
        + " } }"
        + " } }"
        + "}";

    putTransform.setJsonEntity(body);
    Map<String, Object> putResponse = entityAsMap(client().performRequest(putTransform));
    assertThat(putResponse.get("acknowledged"), equalTo(Boolean.TRUE));
    assertTrue(indexExists(destIndex));

    startAndWaitForTransform(id, destIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

    // The test data contains 27 distinct user_id values, so the pivot must
    // produce exactly 27 destination documents.
    Map<String, Object> stats = getAsMap(destIndex + "/_stats");
    assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", stats));

    // Spot-check the aggregated values for a single reviewer.
    Map<String, Object> search = getAsMap(destIndex + "/_search?q=reviewer:user_4");
    assertEquals(1, XContentMapValues.extractValue("hits.total.value", search));
    Number value = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", search)).get(0);
    assertEquals(3.878048780, value.doubleValue(), 0.000001);
    value = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", search)).get(0);
    assertEquals(711.0, value.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,5 @@ private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> l
},
listener::onFailure
));

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
Expand Down Expand Up @@ -73,6 +74,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ enum AggregationType {
VALUE_COUNT("value_count", "long"),
MAX("max", null),
MIN("min", null),
SUM("sum", null);
SUM("sum", null),
SCRIPTED_METRIC("scripted_metric", null);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 1st argument is the name known by aggregations; the 2nd is the output type, which is either the concrete type or null if the source type should be used. So null is a magic value, and it actually should not be null for scripted_metric — but that's only because we lack a second magic value.

null is a bad magic anyway and the types are well defined and limited. So I suggest to change this to:

enum AggregationType {
        AVG("avg", "double"),
        CARDINALITY("cardinality", "long"),
        VALUE_COUNT("value_count", "long"),
        MAX("max", "source"),
        MIN("min", "source"),
        SUM("sum", "source"),
        SCRIPTED_METRIC("scripted_metric", "dynamic");

There should never be an actual field type called "source" or "dynamic", so these sentinel values are safe; alternatively we could use "_source" and "_dynamic" to make the special meaning more visible.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method resolveTargetMapping needs to be changed, it should return the concrete source type or dynamic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, nice, for sure! Let me crank that out today. It would be better than implicitly having different behavior for null between max and scripted_metric.


private final String aggregationType;
private final String targetMapping;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
Expand Down Expand Up @@ -75,6 +76,8 @@ public static void deduceMappings(final Client client,
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
// do nothing

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to record the field name with a "dynamic" mapping — but only in aggregationTypes, not in aggregationSourceFieldNames

(I am not sure what happens further down, might require additional checks)

} else {
// execution should not reach this point
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
Expand Down Expand Up @@ -134,8 +137,7 @@ private static Map<String, String> resolveMappings(Map<String, String> aggregati
if (destinationMapping != null) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs an additional else if to check for dynamic, would be good to do a logger.info for this case

targetMapping.put(targetFieldName, destinationMapping);
} else {
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
targetMapping.put(targetFieldName, "double");
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
Expand Down Expand Up @@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
Expand Down Expand Up @@ -409,6 +412,92 @@ aggTypedName2, asMap(
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

// Verifies that a scripted_metric aggregation result — an arbitrary object
// (here a Map) rather than a numeric single value — is copied verbatim into
// the pivot output document, since its mapping is deduced dynamically.
public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

// Two terms group_by fields; the actual source field name is irrelevant here.
GroupConfig groupBy = parseGroupConfig("{"
+ "\"" + targetField + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } },"
+ "\"" + targetField2 + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } }"
+ "}");

String aggName = randomAlphaOfLengthBetween(5, 10);
// Typed agg names in composite responses are "<type>#<name>".
String aggTypedName = "scripted_metric#" + aggName;

Collection<AggregationBuilder> aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName));

// Simulated composite aggregation response: each bucket carries a
// scripted_metric "value" that is itself a map, not a plain number.
Map<String, Object> input = asMap(
"buckets",
asList(
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 123.0)),
DOC_COUNT, 1),
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 1.0)),
DOC_COUNT, 2),
asMap(
KEY, asMap(
targetField, "ID2",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 2.13)),
DOC_COUNT, 3),
asMap(
KEY, asMap(
targetField, "ID3",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 12.0)),
DOC_COUNT, 4)
));

// Expected extraction: the inner map appears unchanged under the agg name.
List<Map<String, Object>> expected = asList(
asMap(
targetField, "ID1",
targetField2, "ID1_2",
aggName, asMap("field", 123.0)
),
asMap(
targetField, "ID1",
targetField2, "ID2_2",
aggName, asMap("field", 1.0)
),
asMap(
targetField, "ID2",
targetField2, "ID1_2",
aggName, asMap("field", 2.13)
),
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, asMap("field", 12.0)
)
);
// Only the group_by fields have concrete mappings; the scripted_metric
// output is intentionally absent from the field type map (dynamic).
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -176,14 +174,20 @@ private AggregationConfig getValidAggregationConfig() throws IOException {
}

/**
 * Builds an {@link AggregationConfig} for the given aggregation type name.
 * scripted_metric needs a full script-based definition; every other type is
 * a simple single-field aggregation over the "values" field.
 */
private AggregationConfig getAggregationConfig(String agg) throws IOException {
    final String json;
    if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
        json = "{\"pivot_scripted_metric\": {\n" +
            "\"scripted_metric\": {\n" +
            " \"init_script\" : \"state.transactions = []\",\n" +
            " \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
            " \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
            " \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
            " }\n" +
            "}}";
    } else {
        json = "{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
            + " }\n" + " }" + "}";
    }
    return parseAggregations(json);
}

// Mapping for the single numeric source field referenced by the
// non-scripted aggregation configs built in getAggregationConfig.
private Map<String, String> getFieldMappings() {
return Collections.singletonMap("values", "double");
}

private AggregationConfig parseAggregations(String json) throws IOException {
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);
Expand Down