Skip to content

Commit

Permalink
[ML] Changes default destination index field mapping and adds scripte…
Browse files Browse the repository at this point in the history
…d_metric agg (elastic#40750)

* [ML] Allowing destination index mappings to have dynamic types, adds script_metric agg

* Making dynamic|source mapping explicit
  • Loading branch information
benwtrent authored and Gurkan Kaymak committed May 27, 2019
1 parent 39e173b commit 0f85c54
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,59 @@ public void testPivotWithMaxOnDateField() throws Exception {
assertThat(actual, containsString("2017-01-15T"));
}

public void testPivotWithScriptedMetricAgg() throws Exception {
String transformId = "scriptedMetricPivot";
String dataFrameIndex = "scripted_metric_pivot_reviews";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": {\"index\":\"" + REVIEWS_INDEX_NAME + "\"},"
+ " \"dest\": {\"index\":\"" + dataFrameIndex + "\"},";

config += " \"pivot\": {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } },"
+ " \"squared_sum\": {"
+ " \"scripted_metric\": {"
+ " \"init_script\": \"state.reviews_sqrd = []\","
+ " \"map_script\": \"state.reviews_sqrd.add(doc.stars.value * doc.stars.value)\","
+ " \"combine_script\": \"state.reviews_sqrd\","
+ " \"reduce_script\": \"def sum = 0.0; for(l in states){ for(a in l) { sum += a}} return sum\""
+ " } }"
+ " } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
assertTrue(indexExists(dataFrameIndex));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

// we expect 27 documents as there shall be 27 user_id's
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(27, XContentMapValues.extractValue("_all.total.docs.count", indexStats));

// get and check some users
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=reviewer:user_4");
assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.avg_rating", searchResult)).get(0);
assertEquals(3.878048780, actual.doubleValue(), 0.000001);
actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.squared_sum", searchResult)).get(0);
assertEquals(711.0, actual.doubleValue(), 0.000001);
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,5 @@ private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> l
},
listener::onFailure
));

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetric;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig;
Expand Down Expand Up @@ -73,6 +74,8 @@ public static Stream<Map<String, Object>> extractCompositeAggregationResults(Com
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
}
} else if (aggResult instanceof ScriptedMetric) {
document.put(aggName, ((ScriptedMetric) aggResult).aggregation());
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import java.util.stream.Stream;

public final class Aggregations {

// the field mapping should not explicitly be set and allow ES to dynamically determine mapping via the data.
private static final String DYNAMIC = "_dynamic";
// the field mapping should be determined explicitly from the source field mapping if possible.
private static final String SOURCE = "_source";
private Aggregations() {}

/**
Expand All @@ -27,9 +32,10 @@ enum AggregationType {
AVG("avg", "double"),
CARDINALITY("cardinality", "long"),
VALUE_COUNT("value_count", "long"),
MAX("max", null),
MIN("min", null),
SUM("sum", null);
MAX("max", SOURCE),
MIN("min", SOURCE),
SUM("sum", SOURCE),
SCRIPTED_METRIC("scripted_metric", DYNAMIC);

private final String aggregationType;
private final String targetMapping;
Expand All @@ -55,8 +61,12 @@ public static boolean isSupportedByDataframe(String aggregationType) {
return aggregationSupported.contains(aggregationType.toUpperCase(Locale.ROOT));
}

public static boolean isDynamicMapping(String targetMapping) {
return DYNAMIC.equals(targetMapping);
}

public static String resolveTargetMapping(String aggregationType, String sourceType) {
AggregationType agg = AggregationType.valueOf(aggregationType.toUpperCase(Locale.ROOT));
return agg.getTargetMapping() == null ? sourceType : agg.getTargetMapping();
return agg.getTargetMapping().equals(SOURCE) ? sourceType : agg.getTargetMapping();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig;
Expand Down Expand Up @@ -75,6 +76,8 @@ public static void deduceMappings(final Client client,
ValuesSourceAggregationBuilder<?, ?> valueSourceAggregation = (ValuesSourceAggregationBuilder<?, ?>) agg;
aggregationSourceFieldNames.put(valueSourceAggregation.getName(), valueSourceAggregation.field());
aggregationTypes.put(valueSourceAggregation.getName(), valueSourceAggregation.getType());
} else if(agg instanceof ScriptedMetricAggregationBuilder) {
aggregationTypes.put(agg.getName(), agg.getType());
} else {
// execution should not reach this point
listener.onFailure(new RuntimeException("Unsupported aggregation type [" + agg.getType() + "]"));
Expand Down Expand Up @@ -127,15 +130,17 @@ private static Map<String, String> resolveMappings(Map<String, String> aggregati

aggregationTypes.forEach((targetFieldName, aggregationName) -> {
String sourceFieldName = aggregationSourceFieldNames.get(targetFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMappings.get(sourceFieldName));
String sourceMapping = sourceFieldName == null ? null : sourceMappings.get(sourceFieldName);
String destinationMapping = Aggregations.resolveTargetMapping(aggregationName, sourceMapping);

logger.debug(
"Deduced mapping for: [" + targetFieldName + "], agg type [" + aggregationName + "] to [" + destinationMapping + "]");
if (destinationMapping != null) {
if (Aggregations.isDynamicMapping(destinationMapping)) {
logger.info("Dynamic target mapping set for field ["+ targetFieldName +"] and aggregation [" + aggregationName +"]");
} else if (destinationMapping != null) {
targetMapping.put(targetFieldName, destinationMapping);
} else {
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to double.");
targetMapping.put(targetFieldName, "double");
logger.warn("Failed to deduce mapping for [" + targetFieldName + "], fall back to dynamic mapping.");
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
import org.elasticsearch.search.aggregations.metrics.ParsedExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import org.elasticsearch.search.aggregations.metrics.ParsedScriptedMetric;
import org.elasticsearch.search.aggregations.metrics.ParsedStats;
import org.elasticsearch.search.aggregations.metrics.ParsedSum;
import org.elasticsearch.search.aggregations.metrics.ParsedValueCount;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ValueCountAggregationBuilder;
Expand Down Expand Up @@ -76,6 +78,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
map.put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c));
map.put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c));
map.put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c));
map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c));
map.put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c));
map.put(StatsAggregationBuilder.NAME, (p, c) -> ParsedStats.fromXContent(p, (String) c));
map.put(StatsBucketPipelineAggregationBuilder.NAME, (p, c) -> ParsedStatsBucket.fromXContent(p, (String) c));
Expand Down Expand Up @@ -409,6 +412,92 @@ aggTypedName2, asMap(
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

public void testExtractCompositeAggregationResultsWithDynamicType() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";

GroupConfig groupBy = parseGroupConfig("{"
+ "\"" + targetField + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } },"
+ "\"" + targetField2 + "\" : {"
+ " \"terms\" : {"
+ " \"field\" : \"doesn't_matter_for_this_test\""
+ " } }"
+ "}");

String aggName = randomAlphaOfLengthBetween(5, 10);
String aggTypedName = "scripted_metric#" + aggName;

Collection<AggregationBuilder> aggregationBuilders = asList(AggregationBuilders.scriptedMetric(aggName));

Map<String, Object> input = asMap(
"buckets",
asList(
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 123.0)),
DOC_COUNT, 1),
asMap(
KEY, asMap(
targetField, "ID1",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 1.0)),
DOC_COUNT, 2),
asMap(
KEY, asMap(
targetField, "ID2",
targetField2, "ID1_2"
),
aggTypedName, asMap(
"value", asMap("field", 2.13)),
DOC_COUNT, 3),
asMap(
KEY, asMap(
targetField, "ID3",
targetField2, "ID2_2"
),
aggTypedName, asMap(
"value", asMap("field", 12.0)),
DOC_COUNT, 4)
));

List<Map<String, Object>> expected = asList(
asMap(
targetField, "ID1",
targetField2, "ID1_2",
aggName, asMap("field", 123.0)
),
asMap(
targetField, "ID1",
targetField2, "ID2_2",
aggName, asMap("field", 1.0)
),
asMap(
targetField, "ID2",
targetField2, "ID1_2",
aggName, asMap("field", 2.13)
),
asMap(
targetField, "ID3",
targetField2, "ID2_2",
aggName, asMap("field", 12.0)
)
);
Map<String, String> fieldTypeMap = asStringMap(
targetField, "keyword",
targetField2, "keyword"
);
executeTest(groupBy, aggregationBuilders, input, fieldTypeMap, expected, 10);
}

public void testExtractCompositeAggregationResultsDocIDs() throws IOException {
String targetField = randomAlphaOfLengthBetween(5, 10);
String targetField2 = randomAlphaOfLengthBetween(5, 10) + "_2";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,31 @@ public void testResolveTargetMapping() {
assertEquals("double", Aggregations.resolveTargetMapping("avg", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("avg", "double"));

// cardinality
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("cardinality", "double"));

// value_count
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "int"));
assertEquals("long", Aggregations.resolveTargetMapping("value_count", "double"));

// max
assertEquals("int", Aggregations.resolveTargetMapping("max", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("max", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("max", "half_float"));

// min
assertEquals("int", Aggregations.resolveTargetMapping("min", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("min", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("min", "half_float"));

// sum
assertEquals("int", Aggregations.resolveTargetMapping("sum", "int"));
assertEquals("double", Aggregations.resolveTargetMapping("sum", "double"));
assertEquals("half_float", Aggregations.resolveTargetMapping("sum", "half_float"));

// scripted_metric
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", null));
assertEquals("_dynamic", Aggregations.resolveTargetMapping("scripted_metric", "int"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -176,14 +174,20 @@ private AggregationConfig getValidAggregationConfig() throws IOException {
}

private AggregationConfig getAggregationConfig(String agg) throws IOException {
if (agg.equals(AggregationType.SCRIPTED_METRIC.getName())) {
return parseAggregations("{\"pivot_scripted_metric\": {\n" +
"\"scripted_metric\": {\n" +
" \"init_script\" : \"state.transactions = []\",\n" +
" \"map_script\" : \"state.transactions.add(doc.type.value == 'sale' ? doc.amount.value : -1 * doc.amount.value)\", \n" +
" \"combine_script\" : \"double profit = 0; for (t in state.transactions) { profit += t } return profit\",\n" +
" \"reduce_script\" : \"double profit = 0; for (a in states) { profit += a } return profit\"\n" +
" }\n" +
"}}");
}
return parseAggregations("{\n" + " \"pivot_" + agg + "\": {\n" + " \"" + agg + "\": {\n" + " \"field\": \"values\"\n"
+ " }\n" + " }" + "}");
}

private Map<String, String> getFieldMappings() {
return Collections.singletonMap("values", "double");
}

private AggregationConfig parseAggregations(String json) throws IOException {
final XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json);
Expand Down

0 comments on commit 0f85c54

Please sign in to comment.