
Commit
[ML-DataFrame] Add support for (date) histogram pivots (#38725)
* [FEATURE][DATA_FRAME] Adding (date) histogram group_by support for pivot

* adjusting format for merge

* Update DataFramePivotRestIT.java
benwtrent authored Feb 11, 2019
1 parent cd7292c commit cedd78c
Showing 8 changed files with 469 additions and 7 deletions.
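For orientation, a minimal sketch of the transform configuration this commit enables, assembled from the REST tests below (the endpoint placeholder, index, field, and group names come from the test fixtures, not from reference documentation):

    PUT {DATAFRAME_ENDPOINT}simpleDateHistogramPivot
    {
      "source": "reviews",
      "dest": "pivot_reviews_via_date_histogram",
      "pivot": {
        "group_by": [
          { "by_day": { "date_histogram": { "field": "timestamp", "interval": "1d", "format": "yyyy-MM-DD" } } }
        ],
        "aggregations": {
          "avg_rating": { "avg": { "field": "stars" } }
        }
      }
    }

A numeric histogram group_by takes the same shape, with "histogram" in place of "date_histogram" and a numeric "interval", as testHistogramPivot below shows.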
DataFramePivotRestIT.java
@@ -75,6 +75,44 @@ public void testSimplePivotWithQuery() throws Exception {
assertOnePivotValue(dataFrameIndex + "/_search?q=reviewer:user_26", 3.918918918);
}

public void testHistogramPivot() throws Exception {
String transformId = "simpleHistogramPivot";
String dataFrameIndex = "pivot_reviews_via_histogram";

final Request createDataframeTransformRequest = new Request("PUT", DATAFRAME_ENDPOINT + transformId);

String config = "{"
+ " \"source\": \"reviews\","
+ " \"dest\": \"" + dataFrameIndex + "\",";


config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"every_2\": {"
+ " \"histogram\": {"
+ " \"interval\": 2,\"field\":\"stars\""
+ " } } } ],"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } } } }"
+ "}";


createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
assertTrue(indexExists(dataFrameIndex));

startAndWaitForTransform(transformId, dataFrameIndex);

// we expect 3 documents, since there are 5 distinct star values and we bucket every 2 starting at 0
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(3, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
assertOnePivotValue(dataFrameIndex + "/_search?q=every_2:0.0", 1.0);
}
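
To make the arithmetic behind that assertion explicit, a minimal sketch of the standard histogram key function (bucketKey is a hypothetical helper for illustration, not part of this commit), assuming the fixture's stars values cover 1 through 5:

    // hypothetical helper: key = floor(value / interval) * interval
    static long bucketKey(long value, long interval) {
        return Math.floorDiv(value, interval) * interval;
    }
    // stars 1 -> 0; stars 2, 3 -> 2; stars 4, 5 -> 4
    // => three distinct keys {0, 2, 4}, matching the three documents asserted above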

public void testBiggerPivot() throws Exception {
String transformId = "biggerPivot";
String dataFrameIndex = "bigger_pivot_reviews";
@@ -149,6 +187,43 @@ public void testBiggerPivot() throws Exception {
assertEquals(41, actual.longValue());
}

public void testDateHistogramPivot() throws Exception {
String transformId = "simpleDateHistogramPivot";
String dataFrameIndex = "pivot_reviews_via_date_histogram";

final Request createDataframeTransformRequest = new Request("PUT", DATAFRAME_ENDPOINT + transformId);

String config = "{"
+ " \"source\": \"reviews\","
+ " \"dest\": \"" + dataFrameIndex + "\",";


config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"by_day\": {"
+ " \"date_histogram\": {"
+ " \"interval\": \"1d\",\"field\":\"timestamp\",\"format\":\"yyyy-MM-DD\""
+ " } } } ],"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
+ " \"field\": \"stars\""
+ " } } } }"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
assertTrue(indexExists(dataFrameIndex));

startAndWaitForTransform(transformId, dataFrameIndex);

// we expect 21 documents, since the fixture data spans 21 days
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(21, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
assertOnePivotValue(dataFrameIndex + "/_search?q=by_day:2017-01-15", 3.82);
}
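
One caveat worth flagging in the format string above: in both Joda-Time and java.time patterns, "D" is day-of-year while "d" is day-of-month, and the two coincide only in January. The fixture data is confined to 2017-01, so the by_day:2017-01-15 query still matches; "yyyy-MM-dd" would be the conventional pattern. A java.time illustration (used here purely for demonstration):

    import java.time.LocalDate;
    import java.time.format.DateTimeFormatter;

    LocalDate feb1 = LocalDate.of(2017, 2, 1);
    DateTimeFormatter.ofPattern("yyyy-MM-DD").format(feb1); // "2017-02-32" (day-of-year 32)
    DateTimeFormatter.ofPattern("yyyy-MM-dd").format(feb1); // "2017-02-01"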

private void startAndWaitForTransform(String transformId, String dataFrameIndex) throws IOException, Exception {
// start the transform
final Request startTransformRequest = new Request("POST", DATAFRAME_ENDPOINT + transformId + "/_start");
@@ -160,8 +235,6 @@ private void startAndWaitForTransform(String transformId, String dataFrameIndex)
refreshIndex(dataFrameIndex);
}



private void waitForDataFrameGeneration(String transformId) throws Exception {
assertBusy(() -> {
long generation = getDataFrameGeneration(transformId);
DataFrameRestTestCase.java
@@ -45,6 +45,9 @@ protected void createReviewsIndex() throws IOException {
{
builder.startObject("mappings")
.startObject("properties")
.startObject("timestamp")
.field("type", "date")
.endObject()
.startObject("user_id")
.field("type", "keyword")
.endObject()
@@ -66,11 +69,17 @@ protected void createReviewsIndex() throws IOException {

// create index
final StringBuilder bulk = new StringBuilder();
int day = 10;
for (int i = 0; i < numDocs; i++) {
bulk.append("{\"index\":{\"_index\":\"reviews\"}}\n");
long user = Math.round(Math.pow(i * 31 % 1000, distributionTable[i % distributionTable.length]) % 27);
int stars = distributionTable[(i * 33) % distributionTable.length];
long business = Math.round(Math.pow(user * stars, distributionTable[i % distributionTable.length]) % 13);
int hour = randomIntBetween(10, 20);
int min = randomIntBetween(30, 59);
int sec = randomIntBetween(30, 59);

String date_string = "2017-01-" + day + "T" + hour + ":" + min + ":" + sec + "Z";
bulk.append("{\"user_id\":\"")
.append("user_")
.append(user)
@@ -79,7 +88,9 @@ protected void createReviewsIndex() throws IOException {
.append(business)
.append("\",\"stars\":")
.append(stars)
.append("}\n");
.append(",\"timestamp\":\"")
.append(date_string)
.append("\"}\n");

if (i % 50 == 0) {
bulk.append("\r\n");
@@ -89,6 +100,7 @@ protected void createReviewsIndex() throws IOException {
client().performRequest(bulkRequest);
// clear the builder
bulk.setLength(0);
day += 1;
}
}
bulk.append("\r\n");
@@ -209,4 +221,4 @@ protected static void wipeIndices() throws IOException {
}
}
}
}
}
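
For reference, one bulk entry the amended fixture now emits; the concrete values are illustrative, since user, business, stars, and the time-of-day components are derived from the distribution table and randomIntBetween calls above:

    {"index":{"_index":"reviews"}}
    {"user_id":"user_11","business_id":"business_6","stars":4,"timestamp":"2017-01-10T14:45:51Z"}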
DateHistogramGroupSource.java (new file)
@@ -0,0 +1,177 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.dataframe.transforms.pivot;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;

import java.io.IOException;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Objects;


public class DateHistogramGroupSource extends SingleGroupSource<DateHistogramGroupSource> {

private static final String NAME = "data_frame_date_histogram_group";
private static final ParseField TIME_ZONE = new ParseField("time_zone");
private static final ParseField FORMAT = new ParseField("format");

private static final ConstructingObjectParser<DateHistogramGroupSource, Void> STRICT_PARSER = createParser(false);
private static final ConstructingObjectParser<DateHistogramGroupSource, Void> LENIENT_PARSER = createParser(true);
private long interval = 0;
private DateHistogramInterval dateHistogramInterval;
private String format;
private ZoneId timeZone;

public DateHistogramGroupSource(String field) {
super(field);
}

public DateHistogramGroupSource(StreamInput in) throws IOException {
super(in);
this.interval = in.readLong();
this.dateHistogramInterval = in.readOptionalWriteable(DateHistogramInterval::new);
this.timeZone = in.readOptionalZoneId();
this.format = in.readOptionalString();
}

private static ConstructingObjectParser<DateHistogramGroupSource, Void> createParser(boolean lenient) {
ConstructingObjectParser<DateHistogramGroupSource, Void> parser = new ConstructingObjectParser<>(NAME, lenient, (args) -> {
String field = (String) args[0];
return new DateHistogramGroupSource(field);
});

SingleGroupSource.declareValuesSourceFields(parser, null);

parser.declareField((histogram, interval) -> {
if (interval instanceof Long) {
histogram.setInterval((long) interval);
} else {
histogram.setDateHistogramInterval((DateHistogramInterval) interval);
}
}, p -> {
if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.longValue();
} else {
return new DateHistogramInterval(p.text());
}
}, HistogramGroupSource.INTERVAL, ObjectParser.ValueType.LONG);

parser.declareField(DateHistogramGroupSource::setTimeZone, p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return ZoneId.of(p.text());
} else {
return ZoneOffset.ofHours(p.intValue());
}
}, TIME_ZONE, ObjectParser.ValueType.LONG);

parser.declareString(DateHistogramGroupSource::setFormat, FORMAT);
return parser;
}

public static DateHistogramGroupSource fromXContent(final XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}

public long getInterval() {
return interval;
}

public void setInterval(long interval) {
if (interval < 1) {
throw new IllegalArgumentException("[interval] must be greater than or equal to 1.");
}
this.interval = interval;
}

public DateHistogramInterval getDateHistogramInterval() {
return dateHistogramInterval;
}

public void setDateHistogramInterval(DateHistogramInterval dateHistogramInterval) {
if (dateHistogramInterval == null) {
throw new IllegalArgumentException("[dateHistogramInterval] must not be null");
}
this.dateHistogramInterval = dateHistogramInterval;
}

public String getFormat() {
return format;
}

public void setFormat(String format) {
this.format = format;
}

public ZoneId getTimeZone() {
return timeZone;
}

public void setTimeZone(ZoneId timeZone) {
this.timeZone = timeZone;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(field);
out.writeLong(interval);
out.writeOptionalWriteable(dateHistogramInterval);
out.writeOptionalZoneId(timeZone);
out.writeOptionalString(format);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (field != null) {
builder.field(FIELD.getPreferredName(), field);
}
if (dateHistogramInterval == null) {
builder.field(HistogramGroupSource.INTERVAL.getPreferredName(), interval);
} else {
builder.field(HistogramGroupSource.INTERVAL.getPreferredName(), dateHistogramInterval.toString());
}
if (timeZone != null) {
builder.field(TIME_ZONE.getPreferredName(), timeZone.toString());
}
if (format != null) {
builder.field(FORMAT.getPreferredName(), format);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}

if (other == null || getClass() != other.getClass()) {
return false;
}

final DateHistogramGroupSource that = (DateHistogramGroupSource) other;

return Objects.equals(this.field, that.field) &&
Objects.equals(interval, that.interval) &&
Objects.equals(dateHistogramInterval, that.dateHistogramInterval) &&
Objects.equals(timeZone, that.timeZone) &&
Objects.equals(format, that.format);
}

@Override
public int hashCode() {
return Objects.hash(field, interval, dateHistogramInterval, timeZone, format);
}
}
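
A minimal usage sketch of the new class, using only the setters defined above; the expected output in the trailing comment follows the toXContent logic, assuming ZoneOffset.UTC prints as "Z":

    DateHistogramGroupSource byDay = new DateHistogramGroupSource("timestamp");
    byDay.setDateHistogramInterval(new DateHistogramInterval("1d"));
    byDay.setFormat("yyyy-MM-dd");
    byDay.setTimeZone(ZoneOffset.UTC);
    // toXContent -> {"field":"timestamp","interval":"1d","time_zone":"Z","format":"yyyy-MM-dd"}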
GroupConfig.java
@@ -43,6 +43,9 @@ public GroupConfig(StreamInput in) throws IOException {
case TERMS:
groupSource = in.readOptionalWriteable(TermsGroupSource::new);
break;
case HISTOGRAM:
groupSource = in.readOptionalWriteable(HistogramGroupSource::new);
break;
case DATE_HISTOGRAM:
groupSource = in.readOptionalWriteable(DateHistogramGroupSource::new);
break;
default:
throw new IOException("Unknown group type");
}
@@ -126,6 +132,12 @@ public static GroupConfig fromXContent(final XContentParser parser, boolean leni
case TERMS:
groupSource = TermsGroupSource.fromXContent(parser, lenient);
break;
case HISTOGRAM:
groupSource = HistogramGroupSource.fromXContent(parser, lenient);
break;
case DATE_HISTOGRAM:
groupSource = DateHistogramGroupSource.fromXContent(parser, lenient);
break;
default:
throw new ParsingException(parser.getTokenLocation(), "invalid grouping type: " + groupType);
}
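
Note that the two new cases must stay in sync across both switches: the constructor reads the wire format (StreamInput) while fromXContent drives the JSON parser, and both dispatch on the same group type. The two variants differ only in interval semantics, as the REST tests earlier in the commit show:

    "every_2": { "histogram":      { "field": "stars",     "interval": 2    } }  // numeric bucket width
    "by_day":  { "date_histogram": { "field": "timestamp", "interval": "1d" } }  // calendar interval; optional format and time_zone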