[ML] Adding support for geo_shape, geo_centroid, geo_point in datafeeds (elastic#42969) (elastic#43069)

* [ML] Adding support for geo_shape, geo_centroid, geo_point in datafeeds

* only supporting doc_values for geo_point fields

* moving validation into GeoPointField ctor
benwtrent authored Jun 11, 2019
1 parent eadfe05 commit 7905205
Showing 8 changed files with 252 additions and 34 deletions.
38 changes: 7 additions & 31 deletions docs/reference/ml/functions/geo.asciidoc
@@ -54,8 +54,9 @@ detects anomalies where the geographic location of a credit card transaction is
unusual for a particular customer’s credit card. An anomaly might indicate fraud.

IMPORTANT: The `field_name` that you supply must be a single string that contains
two comma-separated numbers of the form `latitude,longitude`. The `latitude` and
`longitude` must be in the range -180 to 180 and represent a point on the
two comma-separated numbers of the form `latitude,longitude`, a `geo_point` field,
a `geo_shape` field that contains point values, or a `geo_centroid` aggregation.
The `latitude` and `longitude` must be in the range -180 to 180 and represent a point on the
surface of the Earth.

For example, JSON data might contain the following transaction coordinates:
@@ -71,34 +72,9 @@ For example, JSON data might contain the following transaction coordinates:
// NOTCONSOLE

In {es}, location data is likely to be stored in `geo_point` fields. For more
information, see {ref}/geo-point.html[Geo-point datatype]. This data type is not
supported natively in {ml-features}. You can, however, use Painless scripts
in `script_fields` in your {dfeed} to transform the data into an appropriate
format. For example, the following Painless script transforms
`"coords": {"lat" : 41.44, "lon":90.5}` into `"lat-lon": "41.44,90.5"`:

[source,js]
--------------------------------------------------
PUT _ml/datafeeds/datafeed-test2
{
"job_id": "farequote",
"indices": ["farequote"],
"query": {
"match_all": {
"boost": 1
}
},
"script_fields": {
"lat-lon": {
"script": {
"source": "doc['coords'].lat + ',' + doc['coords'].lon",
"lang": "painless"
}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[skip:setup:farequote_job]
information, see {ref}/geo-point.html[Geo-point datatype]. This data type is
supported natively in {ml-features}. Specifically, when a {dfeed} pulls data from
a `geo_point` field, it transforms the data into the appropriate `lat,lon` string
format before sending it to the {ml} job.
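
For example, if a hypothetical `transactions` index maps its `coords` field as
`geo_point`, a `lat_long` detector can reference `coords` directly and the
{dfeed} needs no `script_fields` transformation at all. A minimal sketch follows;
the datafeed, job, and index names are illustrative:

[source,js]
--------------------------------------------------
PUT _ml/datafeeds/datafeed-transactions
{
  "job_id": "transactions-fraud",
  "indices": ["transactions"]
}
--------------------------------------------------
// NOTCONSOLE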

For more information, see <<ml-configuring-transform>>.
@@ -321,6 +321,71 @@ public void testLookbackOnlyWithNestedFields() throws Exception {
assertThat(jobStatsResponseAsString, containsString("\"missing_field_count\":0"));
}

public void testLookbackWithGeo() throws Exception {
String jobId = "test-lookback-only-with-geo";
Request createJobRequest = new Request("PUT", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId);
createJobRequest.setJsonEntity("{\n"
+ " \"description\": \"lat_long with geo_point\",\n"
+ " \"analysis_config\": {\n"
+ " \"bucket_span\": \"15m\",\n"
+ " \"detectors\": [\n"
+ " {\n"
+ " \"function\": \"lat_long\",\n"
+ " \"field_name\": \"location\"\n"
+ " }\n"
+ " ]\n"
+ " },"
+ " \"data_description\": {\"time_field\": \"time\"}\n"
+ "}");
client().performRequest(createJobRequest);
String datafeedId = jobId + "-datafeed";
new DatafeedBuilder(datafeedId, jobId, "geo-data").build();

StringBuilder bulk = new StringBuilder();

Request createGeoData = new Request("PUT", "/geo-data");
createGeoData.setJsonEntity("{"
+ " \"mappings\": {"
+ " \"properties\": {"
+ " \"time\": { \"type\":\"date\"},"
+ " \"location\": { \"type\":\"geo_point\"}"
+ " }"
+ " }"
+ "}");
client().performRequest(createGeoData);

bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 1}}\n");
bulk.append("{\"time\":\"2016-06-01T00:00:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 2}}\n");
bulk.append("{\"time\":\"2016-06-01T00:05:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 3}}\n");
bulk.append("{\"time\":\"2016-06-01T00:10:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 4}}\n");
bulk.append("{\"time\":\"2016-06-01T00:15:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 5}}\n");
bulk.append("{\"time\":\"2016-06-01T00:20:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 6}}\n");
bulk.append("{\"time\":\"2016-06-01T00:25:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 7}}\n");
bulk.append("{\"time\":\"2016-06-01T00:30:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 8}}\n");
bulk.append("{\"time\":\"2016-06-01T00:40:00Z\",\"location\":{\"lat\":90.0,\"lon\":-77.03653}}\n");
bulk.append("{\"index\": {\"_index\": \"geo-data\", \"_id\": 9}}\n");
bulk.append("{\"time\":\"2016-06-01T00:41:00Z\",\"location\":{\"lat\":38.897676,\"lon\":-77.03653}}\n");
bulkIndex(bulk.toString());

openJob(client(), jobId);

startDatafeedAndWaitUntilStopped(datafeedId);
waitUntilJobIsClosed(jobId);
Response jobStatsResponse = client().performRequest(
new Request("GET", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats"));
String jobStatsResponseAsString = EntityUtils.toString(jobStatsResponse.getEntity());
assertThat(jobStatsResponseAsString, containsString("\"input_record_count\":9"));
assertThat(jobStatsResponseAsString, containsString("\"processed_record_count\":9"));
assertThat(jobStatsResponseAsString, containsString("\"missing_field_count\":0"));
}

public void testLookbackOnlyGivenEmptyIndex() throws Exception {
new LookbackOnlyTestHelper("test-lookback-only-given-empty-index", "airline-data-empty")
.setShouldSucceedInput(false).setShouldSucceedProcessing(false).execute();
@@ -15,6 +15,7 @@
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.Percentile;
@@ -275,14 +276,16 @@ private void processBucket(MultiBucketsAggregation bucketAgg, boolean addField)
}

/**
* Adds a leaf key-value. It returns the name of the key added or {@code null} when nothing was added.
* Adds a leaf key-value. It returns {@code true} if a key was added or {@code false} when nothing was added.
* Non-finite metric values are not added.
*/
private boolean processLeaf(Aggregation agg) throws IOException {
if (agg instanceof NumericMetricsAggregation.SingleValue) {
return processSingleValue((NumericMetricsAggregation.SingleValue) agg);
} else if (agg instanceof Percentiles) {
return processPercentiles((Percentiles) agg);
} else if (agg instanceof GeoCentroid) {
return processGeoCentroid((GeoCentroid) agg);
} else {
throw new IllegalArgumentException("Unsupported aggregation type [" + agg.getName() + "]");
}
@@ -300,6 +303,14 @@ private boolean addMetricIfFinite(String key, double value) {
return false;
}

private boolean processGeoCentroid(GeoCentroid agg) {
if (agg.count() > 0) {
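// Emit the centroid in the same "lat,lon" string format produced for geo_point and geo_shape fields.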
keyValuePairs.put(agg.getName(), agg.centroid().getLat() + "," + agg.centroid().getLon());
return true;
}
return false;
}

private boolean processPercentiles(Percentiles percentiles) throws IOException {
Iterator<Percentile> percentileIterator = percentiles.iterator();
boolean aggregationAdded = addMetricIfFinite(percentiles.getName(), percentileIterator.next().getValue());
@@ -6,9 +6,17 @@
package org.elasticsearch.xpack.ml.datafeed.extractor.fields;

import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.geo.geometry.Geometry;
import org.elasticsearch.geo.geometry.Point;
import org.elasticsearch.geo.geometry.ShapeType;
import org.elasticsearch.geo.utils.WellKnownText;
import org.elasticsearch.search.SearchHit;

import java.io.IOException;
import java.text.ParseException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

@@ -61,6 +69,14 @@ public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) {
return new TimeField(name, extractionMethod);
}

public static ExtractedField newGeoShapeField(String alias, String name, ExtractionMethod extractionMethod) {
return new GeoShapeField(alias, name, extractionMethod);
}

public static ExtractedField newGeoPointField(String alias, String name, ExtractionMethod extractionMethod) {
return new GeoPointField(alias, name, extractionMethod);
}

public static ExtractedField newField(String name, ExtractionMethod extractionMethod) {
return newField(name, name, extractionMethod);
}
@@ -88,12 +104,102 @@ public Object[] value(SearchHit hit) {
DocumentField keyValue = hit.field(name);
if (keyValue != null) {
List<Object> values = keyValue.getValues();
return values.toArray(new Object[values.size()]);
return values.toArray(new Object[0]);
}
return new Object[0];
}
}

private static class GeoShapeField extends FromSource {
private static final WellKnownText wkt = new WellKnownText();

GeoShapeField(String alias, String name, ExtractionMethod extractionMethod) {
super(alias, name, extractionMethod);
}

@Override
public Object[] value(SearchHit hit) {
Object[] value = super.value(hit);
if (value.length != 1) {
throw new IllegalStateException("Unexpected values for a geo_shape field: " + Arrays.toString(value));
}
if (value[0] instanceof String) {
value[0] = handleString((String) value[0]);
} else if (value[0] instanceof Map<?, ?>) {
@SuppressWarnings("unchecked")
Map<String, Object> geoObject = (Map<String, Object>) value[0];
value[0] = handleObject(geoObject);
} else {
throw new IllegalStateException("Unexpected value type for a geo_shape field: " + value[0].getClass());
}
return value;
}

private String handleString(String geoString) {
try {
if (geoString.startsWith("POINT")) { // Entry is of the form "POINT (-77.03653 38.897676)"
Geometry geometry = wkt.fromWKT(geoString);
if (geometry.type() != ShapeType.POINT) {
throw new IllegalArgumentException("Unexpected non-point geo_shape type: " + geometry.type().name());
}
Point pt = (Point) geometry;
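// WKT writes points as "(lon lat)"; the ML job expects the "lat,lon" string form.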
return pt.getLat() + "," + pt.getLon();
} else {
throw new IllegalArgumentException("Unexpected value for a geo_shape field: " + geoString);
}
} catch (IOException | ParseException ex) {
throw new IllegalArgumentException("Unexpected value for a geo_shape field: " + geoString);
}
}

private String handleObject(Map<String, Object> geoObject) {
String geoType = (String) geoObject.get("type");
if (geoType != null && "point".equals(geoType.toLowerCase(Locale.ROOT))) {
@SuppressWarnings("unchecked")
List<Double> coordinates = (List<Double>) geoObject.get("coordinates");
if (coordinates == null || coordinates.size() != 2) {
throw new IllegalArgumentException("Invalid coordinates for geo_shape point: " + geoObject);
}
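// GeoJSON coordinates are ordered [lon, lat], so swap them into the "lat,lon" form the job expects.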
return coordinates.get(1) + "," + coordinates.get(0);
} else {
throw new IllegalArgumentException("Unexpected value for a geo_shape field: " + geoObject);
}
}

}

private static class GeoPointField extends FromFields {

GeoPointField(String alias, String name, ExtractionMethod extractionMethod) {
super(alias, name, extractionMethod);
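// geo_point source formats vary widely (object, string, array, geohash), so only
// doc values, which normalize to a "lat, lon" string, are supported.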
if (extractionMethod != ExtractionMethod.DOC_VALUE) {
throw new IllegalArgumentException("cannot use [geo_point] field with disabled doc values");
}
}

@Override
public Object[] value(SearchHit hit) {
Object[] value = super.value(hit);
if (value.length != 1) {
throw new IllegalStateException("Unexpected values for a geo_point field: " + Arrays.toString(value));
}
if (value[0] instanceof String) {
value[0] = handleString((String) value[0]);
} else {
throw new IllegalStateException("Unexpected value type for a geo_point field: " + value[0].getClass());
}
return value;
}

private String handleString(String geoString) {
if (geoString.contains(",")) { // Entry is of the form "38.897676, -77.03653"
return geoString.replace(" ", "");
} else {
throw new IllegalArgumentException("Unexpected value for a geo_point field: " + geoString);
}
}
}

private static class TimeField extends FromFields {

private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
@@ -145,7 +251,7 @@ public Object[] value(SearchHit hit) {
if (values instanceof List<?>) {
@SuppressWarnings("unchecked")
List<Object> asList = (List<Object>) values;
return asList.toArray(new Object[asList.size()]);
return asList.toArray(new Object[0]);
} else {
return new Object[]{values};
}
@@ -87,6 +87,12 @@ protected ExtractedField detect(String field) {
: ExtractedField.ExtractionMethod.SOURCE;
}
}
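// Geo fields need specialized extraction and formatting, so detect them before the generic fallback.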
if (isFieldOfType(field, "geo_point")) {
return ExtractedField.newGeoPointField(field, internalField, method);
}
if (isFieldOfType(field, "geo_shape")) {
return ExtractedField.newGeoShapeField(field, internalField, method);
}
return ExtractedField.newField(field, internalField, method);
}

@@ -5,12 +5,14 @@
*/
package org.elasticsearch.xpack.ml.datafeed.extractor.aggregation;

import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.Percentile;
@@ -70,6 +72,15 @@ static Max createMax(String name, double value) {
return max;
}

static GeoCentroid createGeoCentroid(String name, long count, double lat, double lon) {
GeoCentroid centroid = mock(GeoCentroid.class);
when(centroid.count()).thenReturn(count);
when(centroid.getName()).thenReturn(name);
GeoPoint point = count > 0 ? new GeoPoint(lat, lon) : null;
when(centroid.centroid()).thenReturn(point);
return centroid;
}

static NumericMetricsAggregation.SingleValue createSingleValue(String name, double value) {
NumericMetricsAggregation.SingleValue singleValue = mock(NumericMetricsAggregation.SingleValue.class);
when(singleValue.getName()).thenReturn(name);
@@ -27,6 +27,7 @@

import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.Term;
import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.createAggs;
import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.createGeoCentroid;
import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.createHistogramAggregation;
import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.createHistogramBucket;
import static org.elasticsearch.xpack.ml.datafeed.extractor.aggregation.AggregationTestUtils.createMax;
@@ -472,6 +473,20 @@ public void testSingleBucketAgg_failureWithSubMultiBucket() throws IOException {
() -> aggToString(Sets.newHashSet("my_field"), histogramBuckets));
}

public void testGeoCentroidAgg() throws IOException {
List<Histogram.Bucket> histogramBuckets = Arrays.asList(
createHistogramBucket(1000L, 4, Arrays.asList(
createMax("time", 1000),
createGeoCentroid("geo_field", 4, 92.1, 93.1))),
createHistogramBucket(2000L, 7, Arrays.asList(
createMax("time", 2000),
createGeoCentroid("geo_field", 0, -1, -1))));
String json = aggToString(Sets.newHashSet("geo_field"), histogramBuckets);

assertThat(json, equalTo("{\"time\":1000,\"geo_field\":\"92.1,93.1\",\"doc_count\":4}" +
" {\"time\":2000,\"doc_count\":7}"));
}

private String aggToString(Set<String> fields, Histogram.Bucket bucket) throws IOException {
return aggToString(fields, Collections.singletonList(bucket));
}
