From ed0c0e66b0397ba7e23544fd6a8b5a479c8fc602 Mon Sep 17 00:00:00 2001 From: Igor Motov Date: Tue, 2 Apr 2019 14:01:07 -0400 Subject: [PATCH] SQL: Convert ST_Distance into query when possible (#40595) * SQL: Convert ST_Distance into query when possible Adds additional optimization logic to convert ST_Distance function calls into geo_distance query when it is called in WHERE clauses. --- .../function/scalar/geo/GeoShape.java | 5 ++ .../function/scalar/geo/StDistance.java | 32 +++----- .../scalar/geo/StDistanceFunction.java | 27 +++++++ .../xpack/sql/planner/QueryTranslator.java | 26 +++++++ .../sql/querydsl/query/GeoDistanceQuery.java | 77 +++++++++++++++++++ .../xpack/sql/optimizer/OptimizerTests.java | 10 +++ .../sql/planner/QueryTranslatorTests.java | 38 ++++++--- 7 files changed, 186 insertions(+), 29 deletions(-) create mode 100644 x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistanceFunction.java create mode 100644 x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/GeoDistanceQuery.java diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/GeoShape.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/GeoShape.java index 582b84be52425..9ca0e1248e4da 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/GeoShape.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/GeoShape.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.geo.geometry.Geometry; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import java.io.IOException; @@ -58,6 +59,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.value(shapeBuilder.toWKT()); } + public Geometry toGeometry() { + return shapeBuilder.buildGeometry(); + } + public static double distance(GeoShape shape1, GeoShape shape2) { if (shape1.shapeBuilder instanceof PointBuilder == false) { throw new SqlIllegalArgumentException("distance calculation is only supported for points; received [{}]", shape1); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistance.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistance.java index 51d5e5ee02bef..fd14e90dd9d93 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistance.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistance.java @@ -9,24 +9,25 @@ import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; -import org.elasticsearch.xpack.sql.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; +import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.tree.Source; import org.elasticsearch.xpack.sql.type.DataType; import static org.elasticsearch.xpack.sql.expression.TypeResolutions.isGeo; -import static org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistanceProcessor.process; import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder; /** * Calculates the distance between two points */ -public class StDistance extends BinaryScalarFunction { +public class StDistance extends BinaryOperator { + + private static final StDistanceFunction FUNCTION = new StDistanceFunction(); public StDistance(Source source, Expression source1, Expression source2) { - super(source, source1, source2); + super(source, source1, source2, FUNCTION); } @Override @@ -34,20 +35,6 @@ protected StDistance replaceChildren(Expression newLeft, Expression newRight) { return new StDistance(source(), newLeft, newRight); } - @Override - protected TypeResolution resolveType() { - if (!childrenResolved()) { - return new TypeResolution("Unresolved children"); - } - - TypeResolution resolution = isGeo(left(), functionName(), Expressions.ParamOrdinal.FIRST); - if (resolution.unresolved()) { - return resolution; - } - - return isGeo(right(), functionName(), Expressions.ParamOrdinal.SECOND); - } - @Override public DataType dataType() { return DataType.DOUBLE; @@ -66,8 +53,13 @@ public ScriptTemplate scriptWithField(FieldAttribute field) { } @Override - public Object fold() { - return process(left().fold(), right().fold()); + protected TypeResolution resolveInputType(Expression e, Expressions.ParamOrdinal paramOrdinal) { + return isGeo(e, sourceText(), paramOrdinal); + } + + @Override + public StDistance swapLeftAndRight() { + return new StDistance(source(), right(), left()); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistanceFunction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistanceFunction.java new file mode 100644 index 0000000000000..d1c15c1e2a1b2 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/geo/StDistanceFunction.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.sql.expression.function.scalar.geo; + +import org.elasticsearch.xpack.sql.expression.predicate.PredicateBiFunction; + +class StDistanceFunction implements PredicateBiFunction { + + @Override + public String name() { + return "ST_DISTANCE"; + } + + @Override + public String symbol() { + return "ST_DISTANCE"; + } + + @Override + public Double doApply(Object s1, Object s2) { + return StDistanceProcessor.process(s1, s2); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java index 1ad5f812777b2..e34d94e187649 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.sql.planner; +import org.elasticsearch.geo.geometry.Geometry; +import org.elasticsearch.geo.geometry.Point; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.expression.Attribute; @@ -38,6 +40,8 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction; +import org.elasticsearch.xpack.sql.expression.function.scalar.geo.GeoShape; +import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.expression.literal.Intervals; import org.elasticsearch.xpack.sql.expression.predicate.Range; @@ -85,6 +89,7 @@ import org.elasticsearch.xpack.sql.querydsl.agg.TopHitsAgg; import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery; import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery; +import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery; import org.elasticsearch.xpack.sql.querydsl.query.MatchQuery; import org.elasticsearch.xpack.sql.querydsl.query.MultiMatchQuery; import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery; @@ -656,6 +661,24 @@ private static Query translateQuery(BinaryComparison bc) { Object value = valueOf(bc.right()); String format = dateFormat(bc.left()); + // Possible geo optimization + if (bc.left() instanceof StDistance && value instanceof Number) { + if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { + // Special case for ST_Distance translatable into geo_distance query + StDistance stDistance = (StDistance) bc.left(); + if (stDistance.left() instanceof FieldAttribute && stDistance.right().foldable()) { + Object geoShape = valueOf(stDistance.right()); + if (geoShape instanceof GeoShape) { + Geometry geometry = ((GeoShape) geoShape).toGeometry(); + if (geometry instanceof Point) { + String field = nameOf(stDistance.left()); + return new GeoDistanceQuery(source, field, ((Number) value).doubleValue(), + ((Point) geometry).getLat(), ((Point) geometry).getLon()); + } + } + } + } + } if (bc instanceof GreaterThan) { return new RangeQuery(source, name, value, false, null, false, format); } @@ -954,6 +977,9 @@ public QueryTranslation translate(Expression exp, boolean onAggs) { protected static Query handleQuery(ScalarFunction sf, Expression field, Supplier query) { Query q = query.get(); + if (field instanceof StDistance && q instanceof GeoDistanceQuery) { + return wrapIfNested(q, ((StDistance) field).left()); + } if (field instanceof FieldAttribute) { return wrapIfNested(q, field); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/GeoDistanceQuery.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/GeoDistanceQuery.java new file mode 100644 index 0000000000000..dd1a1171c1603 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/GeoDistanceQuery.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.querydsl.query; + +import org.elasticsearch.common.unit.DistanceUnit; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.xpack.sql.tree.Source; + +import java.util.Objects; + +public class GeoDistanceQuery extends LeafQuery { + + private final String field; + private final double lat; + private final double lon; + private final double distance; + + public GeoDistanceQuery(Source source, String field, double distance, double lat, double lon) { + super(source); + this.field = field; + this.distance = distance; + this.lat = lat; + this.lon = lon; + } + + public String field() { + return field; + } + + public double lat() { + return lat; + } + + public double lon() { + return lon; + } + + public double distance() { + return distance; + } + + @Override + public QueryBuilder asBuilder() { + return QueryBuilders.geoDistanceQuery(field).distance(distance, DistanceUnit.METERS).point(lat, lon); + } + + @Override + public int hashCode() { + return Objects.hash(field, distance, lat, lon); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + GeoDistanceQuery other = (GeoDistanceQuery) obj; + return Objects.equals(field, other.field) && + Objects.equals(distance, other.distance) && + Objects.equals(lat, other.lat) && + Objects.equals(lon, other.lon); + } + + @Override + protected String innerToString() { + return field + ":" + "(" + distance + "," + "(" + lat + ", " + lon + "))"; + } +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 7c01fd8ff1560..2235cacab6385 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.IsoWeekOfYear; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MonthOfYear; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year; +import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance; import org.elasticsearch.xpack.sql.expression.function.scalar.math.ACos; import org.elasticsearch.xpack.sql.expression.function.scalar.math.ASin; import org.elasticsearch.xpack.sql.expression.function.scalar.math.ATan; @@ -622,6 +623,15 @@ public void testLiteralsOnTheRight() { assertEquals(FIVE, nullEquals.right()); } + public void testLiteralsOnTheRightInStDistance() { + Alias a = new Alias(EMPTY, "a", L(10)); + Expression result = new BooleanLiteralsOnTheRight().rule(new StDistance(EMPTY, FIVE, a)); + assertTrue(result instanceof StDistance); + StDistance sd = (StDistance) result; + assertEquals(a, sd.left()); + assertEquals(FIVE, sd.right()); + } + public void testBoolSimplifyNotIsNullAndNotIsNotNull() { BooleanSimplification simplification = new BooleanSimplification(); assertTrue(simplification.rule(new Not(EMPTY, new IsNull(EMPTY, ONE))) instanceof IsNotNull); diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index 5ddbaa85c23e1..b83fa33ab8d56 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram; import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery; import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery; +import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery; import org.elasticsearch.xpack.sql.querydsl.query.NotQuery; import org.elasticsearch.xpack.sql.querydsl.query.Query; import org.elasticsearch.xpack.sql.querydsl.query.RangeQuery; @@ -614,22 +615,41 @@ public void testTranslateStWktToSql() { assertEquals("[{v=keyword}, {v=point (10.0 20.0)}]", aggFilter.scriptTemplate().params().toString()); } - public void testTranslateStDistance() { - LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) > 20"); + public void testTranslateStDistanceToScript() { + String operator = randomFrom(">", ">="); + String operatorFunction = operator.equalsIgnoreCase(">") ? "gt" : "gte"; + LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 20"); assertThat(p, instanceOf(Project.class)); assertThat(p.children().get(0), instanceOf(Filter.class)); Expression condition = ((Filter) p.children().get(0)).condition(); assertFalse(condition.foldable()); - QueryTranslation translation = QueryTranslator.toQuery(condition, true); - assertNull(translation.query); - AggFilter aggFilter = translation.aggFilter; - + QueryTranslation translation = QueryTranslator.toQuery(condition, false); + assertNull(translation.aggFilter); + assertTrue(translation.query instanceof ScriptQuery); + ScriptQuery sc = (ScriptQuery) translation.query; assertEquals("InternalSqlScriptUtils.nullSafeFilter(" + - "InternalSqlScriptUtils.gt(" + + "InternalSqlScriptUtils." + operatorFunction + "(" + "InternalSqlScriptUtils.stDistance(" + "InternalSqlScriptUtils.geoDocValue(doc,params.v0),InternalSqlScriptUtils.stWktToSql(params.v1)),params.v2))", - aggFilter.scriptTemplate().toString()); - assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", aggFilter.scriptTemplate().params().toString()); + sc.script().toString()); + assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", sc.script().params().toString()); + } + + public void testTranslateStDistanceToQuery() { + String operator = randomFrom("<", "<="); + LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 25"); + assertThat(p, instanceOf(Project.class)); + assertThat(p.children().get(0), instanceOf(Filter.class)); + Expression condition = ((Filter) p.children().get(0)).condition(); + assertFalse(condition.foldable()); + QueryTranslation translation = QueryTranslator.toQuery(condition, false); + assertNull(translation.aggFilter); + assertTrue(translation.query instanceof GeoDistanceQuery); + GeoDistanceQuery gq = (GeoDistanceQuery) translation.query; + assertEquals("shape", gq.field()); + assertEquals(20.0, gq.lat(), 0.00001); + assertEquals(10.0, gq.lon(), 0.00001); + assertEquals(25.0, gq.distance(), 0.00001); } public void testTranslateCoalesce_GroupBy_Painless() {