Skip to content

Commit

Permalink
SQL: Convert ST_Distance into query when possible (#40595)
Browse files Browse the repository at this point in the history
* SQL: Convert ST_Distance into query when possible

Adds additional optimization logic to convert ST_Distance function
calls into geo_distance query when it is called in WHERE clauses.
  • Loading branch information
imotov authored Apr 2, 2019
1 parent b4231c9 commit ed0c0e6
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.geo.geometry.Geometry;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;

import java.io.IOException;
Expand Down Expand Up @@ -58,6 +59,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder.value(shapeBuilder.toWKT());
}

public Geometry toGeometry() {
return shapeBuilder.buildGeometry();
}

public static double distance(GeoShape shape1, GeoShape shape2) {
if (shape1.shapeBuilder instanceof PointBuilder == false) {
throw new SqlIllegalArgumentException("distance calculation is only supported for points; received [{}]", shape1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,32 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;

import static org.elasticsearch.xpack.sql.expression.TypeResolutions.isGeo;
import static org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistanceProcessor.process;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

/**
* Calculates the distance between two points
*/
public class StDistance extends BinaryScalarFunction {
public class StDistance extends BinaryOperator<Object, Object, Double, StDistanceFunction> {

private static final StDistanceFunction FUNCTION = new StDistanceFunction();

public StDistance(Source source, Expression source1, Expression source2) {
super(source, source1, source2);
super(source, source1, source2, FUNCTION);
}

@Override
protected StDistance replaceChildren(Expression newLeft, Expression newRight) {
return new StDistance(source(), newLeft, newRight);
}

@Override
protected TypeResolution resolveType() {
if (!childrenResolved()) {
return new TypeResolution("Unresolved children");
}

TypeResolution resolution = isGeo(left(), functionName(), Expressions.ParamOrdinal.FIRST);
if (resolution.unresolved()) {
return resolution;
}

return isGeo(right(), functionName(), Expressions.ParamOrdinal.SECOND);
}

@Override
public DataType dataType() {
return DataType.DOUBLE;
Expand All @@ -66,8 +53,13 @@ public ScriptTemplate scriptWithField(FieldAttribute field) {
}

@Override
public Object fold() {
return process(left().fold(), right().fold());
protected TypeResolution resolveInputType(Expression e, Expressions.ParamOrdinal paramOrdinal) {
return isGeo(e, sourceText(), paramOrdinal);
}

@Override
public StDistance swapLeftAndRight() {
return new StDistance(source(), right(), left());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.sql.expression.function.scalar.geo;

import org.elasticsearch.xpack.sql.expression.predicate.PredicateBiFunction;

class StDistanceFunction implements PredicateBiFunction<Object, Object, Double> {

@Override
public String name() {
return "ST_DISTANCE";
}

@Override
public String symbol() {
return "ST_DISTANCE";
}

@Override
public Double doApply(Object s1, Object s2) {
return StDistanceProcessor.process(s1, s2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.sql.planner;

import org.elasticsearch.geo.geometry.Geometry;
import org.elasticsearch.geo.geometry.Point;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.Attribute;
Expand Down Expand Up @@ -38,6 +40,8 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.GeoShape;
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.literal.Intervals;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
Expand Down Expand Up @@ -85,6 +89,7 @@
import org.elasticsearch.xpack.sql.querydsl.agg.TopHitsAgg;
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery;
import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery;
import org.elasticsearch.xpack.sql.querydsl.query.MatchQuery;
import org.elasticsearch.xpack.sql.querydsl.query.MultiMatchQuery;
import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery;
Expand Down Expand Up @@ -656,6 +661,24 @@ private static Query translateQuery(BinaryComparison bc) {
Object value = valueOf(bc.right());
String format = dateFormat(bc.left());

// Possible geo optimization
if (bc.left() instanceof StDistance && value instanceof Number) {
if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
// Special case for ST_Distance translatable into geo_distance query
StDistance stDistance = (StDistance) bc.left();
if (stDistance.left() instanceof FieldAttribute && stDistance.right().foldable()) {
Object geoShape = valueOf(stDistance.right());
if (geoShape instanceof GeoShape) {
Geometry geometry = ((GeoShape) geoShape).toGeometry();
if (geometry instanceof Point) {
String field = nameOf(stDistance.left());
return new GeoDistanceQuery(source, field, ((Number) value).doubleValue(),
((Point) geometry).getLat(), ((Point) geometry).getLon());
}
}
}
}
}
if (bc instanceof GreaterThan) {
return new RangeQuery(source, name, value, false, null, false, format);
}
Expand Down Expand Up @@ -954,6 +977,9 @@ public QueryTranslation translate(Expression exp, boolean onAggs) {

protected static Query handleQuery(ScalarFunction sf, Expression field, Supplier<Query> query) {
Query q = query.get();
if (field instanceof StDistance && q instanceof GeoDistanceQuery) {
return wrapIfNested(q, ((StDistance) field).left());
}
if (field instanceof FieldAttribute) {
return wrapIfNested(q, field);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.query;

import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.sql.tree.Source;

import java.util.Objects;

public class GeoDistanceQuery extends LeafQuery {

private final String field;
private final double lat;
private final double lon;
private final double distance;

public GeoDistanceQuery(Source source, String field, double distance, double lat, double lon) {
super(source);
this.field = field;
this.distance = distance;
this.lat = lat;
this.lon = lon;
}

public String field() {
return field;
}

public double lat() {
return lat;
}

public double lon() {
return lon;
}

public double distance() {
return distance;
}

@Override
public QueryBuilder asBuilder() {
return QueryBuilders.geoDistanceQuery(field).distance(distance, DistanceUnit.METERS).point(lat, lon);
}

@Override
public int hashCode() {
return Objects.hash(field, distance, lat, lon);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (obj == null || getClass() != obj.getClass()) {
return false;
}

GeoDistanceQuery other = (GeoDistanceQuery) obj;
return Objects.equals(field, other.field) &&
Objects.equals(distance, other.distance) &&
Objects.equals(lat, other.lat) &&
Objects.equals(lon, other.lon);
}

@Override
protected String innerToString() {
return field + ":" + "(" + distance + "," + "(" + lat + ", " + lon + "))";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.IsoWeekOfYear;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.MonthOfYear;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.Year;
import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance;
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ACos;
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ASin;
import org.elasticsearch.xpack.sql.expression.function.scalar.math.ATan;
Expand Down Expand Up @@ -622,6 +623,15 @@ public void testLiteralsOnTheRight() {
assertEquals(FIVE, nullEquals.right());
}

public void testLiteralsOnTheRightInStDistance() {
Alias a = new Alias(EMPTY, "a", L(10));
Expression result = new BooleanLiteralsOnTheRight().rule(new StDistance(EMPTY, FIVE, a));
assertTrue(result instanceof StDistance);
StDistance sd = (StDistance) result;
assertEquals(a, sd.left());
assertEquals(FIVE, sd.right());
}

public void testBoolSimplifyNotIsNullAndNotIsNotNull() {
BooleanSimplification simplification = new BooleanSimplification();
assertTrue(simplification.rule(new Not(EMPTY, new IsNull(EMPTY, ONE))) instanceof IsNotNull);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram;
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
import org.elasticsearch.xpack.sql.querydsl.query.ExistsQuery;
import org.elasticsearch.xpack.sql.querydsl.query.GeoDistanceQuery;
import org.elasticsearch.xpack.sql.querydsl.query.NotQuery;
import org.elasticsearch.xpack.sql.querydsl.query.Query;
import org.elasticsearch.xpack.sql.querydsl.query.RangeQuery;
Expand Down Expand Up @@ -614,22 +615,41 @@ public void testTranslateStWktToSql() {
assertEquals("[{v=keyword}, {v=point (10.0 20.0)}]", aggFilter.scriptTemplate().params().toString());
}

public void testTranslateStDistance() {
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) > 20");
public void testTranslateStDistanceToScript() {
String operator = randomFrom(">", ">=");
String operatorFunction = operator.equalsIgnoreCase(">") ? "gt" : "gte";
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 20");
assertThat(p, instanceOf(Project.class));
assertThat(p.children().get(0), instanceOf(Filter.class));
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;

QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertNull(translation.aggFilter);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sc = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(" +
"InternalSqlScriptUtils.gt(" +
"InternalSqlScriptUtils." + operatorFunction + "(" +
"InternalSqlScriptUtils.stDistance(" +
"InternalSqlScriptUtils.geoDocValue(doc,params.v0),InternalSqlScriptUtils.stWktToSql(params.v1)),params.v2))",
aggFilter.scriptTemplate().toString());
assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", aggFilter.scriptTemplate().params().toString());
sc.script().toString());
assertEquals("[{v=shape}, {v=point (10.0 20.0)}, {v=20}]", sc.script().params().toString());
}

public void testTranslateStDistanceToQuery() {
String operator = randomFrom("<", "<=");
LogicalPlan p = plan("SELECT shape FROM test WHERE ST_Distance(shape, ST_WKTToSQL('point (10 20)')) " + operator + " 25");
assertThat(p, instanceOf(Project.class));
assertThat(p.children().get(0), instanceOf(Filter.class));
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertNull(translation.aggFilter);
assertTrue(translation.query instanceof GeoDistanceQuery);
GeoDistanceQuery gq = (GeoDistanceQuery) translation.query;
assertEquals("shape", gq.field());
assertEquals(20.0, gq.lat(), 0.00001);
assertEquals(10.0, gq.lon(), 0.00001);
assertEquals(25.0, gq.distance(), 0.00001);
}

public void testTranslateCoalesce_GroupBy_Painless() {
Expand Down

0 comments on commit ed0c0e6

Please sign in to comment.