Skip to content

Commit

Permalink
SQL: Improve painless script generated from IN (#35055)
Browse files Browse the repository at this point in the history
Replace standard `||` and `==` painless operators with
new `in` method introduced in `InternalSqlScriptUtils`.
This allows the list of values to become a script variable
which is replaced each time with the list of  values provided
by the user.

Move In to the same package as InPipe & InProcessor

Follow up to #34750

Co-authored-by: Costin Leau <[email protected]>
  • Loading branch information
2 people authored and matriv committed Nov 1, 2018
1 parent d1038d6 commit dab6665
Show file tree
Hide file tree
Showing 19 changed files with 79 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.Score;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,12 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.add(new Entry(Processor.class, SubstringFunctionProcessor.NAME, SubstringFunctionProcessor::new));
return entries;
}
}

public static List<Object> process(List<Processor> processors, Object input) {
List<Object> values = new ArrayList<>(processors.size());
for (Processor p : processors) {
values.add(p.process(input));
}
return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.BinaryArithmeticProcessor.BinaryArithmeticOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.UnaryArithmeticProcessor.UnaryArithmeticOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.regex.RegexProcessor.RegexOperation;
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -113,6 +115,10 @@ public static Boolean notNull(Object expression) {
return IsNotNullProcessor.apply(expression);
}

public static Boolean in(Object value, List<Object> values) {
return InProcessor.apply(value, values);
}

//
// Regex
//
Expand Down Expand Up @@ -375,4 +381,4 @@ public static String substring(String s, Number start, Number length) {
public static String ucase(String s) {
return (String) StringOperation.UCASE.apply(s);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left
.build(),
dataType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
Expand All @@ -30,7 +27,6 @@
import java.util.StringJoiner;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

public class In extends NamedExpression implements ScriptWeaver {
Expand Down Expand Up @@ -84,24 +80,12 @@ public boolean foldable() {

@Override
public Boolean fold() {
if (value.dataType() == DataType.NULL) {
// Optimization for early return and Query folding to LocalExec
if (value.dataType() == DataType.NULL ||
list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return null;
}
if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return false;
}

Object foldedLeftValue = value.fold();
Boolean result = false;
for (Expression rightValue : list) {
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
if (compResult == null) {
result = null;
} else if (compResult) {
return true;
}
}
return result;
return InProcessor.apply(value.fold(), Foldables.valuesOf(list, value.dataType()));
}

@Override
Expand All @@ -122,34 +106,18 @@ public Attribute toAttribute() {

@Override
public ScriptTemplate asScript() {
StringJoiner sj = new StringJoiner(" || ");
ScriptTemplate leftScript = asScript(value);
List<Params> rightParams = new ArrayList<>();
String scriptPrefix = leftScript + "==";
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
for (Object valueFromList : values) {
// if checked against null => false
if (valueFromList != null) {
if (valueFromList instanceof Expression) {
ScriptTemplate rightScript = asScript((Expression) valueFromList);
sj.add(scriptPrefix + rightScript.template());
rightParams.add(rightScript.params());
} else {
if (valueFromList instanceof String) {
sj.add(scriptPrefix + '"' + valueFromList + '"');
} else {
sj.add(scriptPrefix + valueFromList.toString());
}
}
}
}

ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
for (Params p : rightParams) {
paramsBuilder = paramsBuilder.script(p);
}

return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
// remove duplicates
List<Object> values = new ArrayList<>(new LinkedHashSet<>(Foldables.valuesOf(list, value.dataType())));
values.removeIf(Objects::isNull);

return new ScriptTemplate(
formatTemplate(String.format(Locale.ROOT, "{sql}.in(%s, {})", leftScript.template())),
paramsBuilder()
.script(leftScript.params())
.variable(values)
.build(),
dataType());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.sql.expression.function.scalar.Processors;
import org.elasticsearch.xpack.sql.expression.gen.processor.Processor;

import java.io.IOException;
Expand All @@ -19,7 +20,7 @@ public class InProcessor implements Processor {

private final List<Processor> processsors;

public InProcessor(List<Processor> processors) {
InProcessor(List<Processor> processors) {
this.processsors = processors;
}

Expand All @@ -40,14 +41,17 @@ public final void writeTo(StreamOutput out) throws IOException {
@Override
public Object process(Object input) {
Object leftValue = processsors.get(processsors.size() - 1).process(input);
Boolean result = false;
return apply(leftValue, Processors.process(processsors.subList(0, processsors.size() - 1), leftValue));
}

for (int i = 0; i < processsors.size() - 1; i++) {
Boolean compResult = Comparisons.eq(leftValue, processsors.get(i).process(input));
public static Boolean apply(Object input, List<Object> values) {
Boolean result = Boolean.FALSE;
for (Object v : values) {
Boolean compResult = Comparisons.eq(input, v);
if (compResult == null) {
result = null;
} else if (compResult) {
return true;
} else if (compResult == Boolean.TRUE) {
return Boolean.TRUE;
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator.Negateable;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Predicates;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
Expand All @@ -50,6 +49,7 @@
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.UnaryPipe;
import org.elasticsearch.xpack.sql.expression.gen.processor.Processor;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.sql.plan.physical.FilterExec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;

import java.util.Collections;
import java.util.LinkedHashSet;
Expand All @@ -27,7 +27,7 @@ public class TermsQuery extends LeafQuery {
public TermsQuery(Location location, String term, List<Expression> values) {
super(location);
this.term = term;
values.removeIf(e -> e.dataType() == DataType.NULL);
values.removeIf(e -> DataTypes.isNull(e.dataType()));
if (values.isEmpty()) {
this.values = Collections.emptySet();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public static Conversion conversionFor(DataType from, DataType to) {
if (to == DataType.NULL) {
return Conversion.NULL;
}
if (from == DataType.NULL) {
return Conversion.NULL;
}

Conversion conversion = conversion(from, to);
if (conversion == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
Boolean lte(Object, Object)
Boolean gt(Object, Object)
Boolean gte(Object, Object)
Boolean in(Object, java.util.List)

#
# Logical
Expand Down Expand Up @@ -107,4 +108,4 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
String space(Number)
String substring(String, Number, Number)
String ucase(String)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.sql.expression.Literal;
import org.elasticsearch.xpack.sql.expression.function.scalar.Processors;
import org.elasticsearch.xpack.sql.expression.gen.processor.ConstantProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InProcessor;

import java.util.Arrays;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.predicate;
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.expression.Literal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.string.Ascii;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.Repeat;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull;
import org.elasticsearch.xpack.sql.expression.predicate.Range;
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
Expand Down Expand Up @@ -345,7 +345,7 @@ public void testConstantFoldingIn_LeftValueNotFoldable() {
public void testConstantFoldingIn_RightValueIsNull() {
In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL));
Literal result= (Literal) new ConstantFolding().rule(in);
assertEquals(false, result.value());
assertNull(result.value());
}

public void testConstantFoldingIn_LeftValueIsNull() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Map;
import java.util.TimeZone;

import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.core.StringStartsWith.startsWith;

public class QueryTranslatorTests extends ESTestCase {
Expand Down Expand Up @@ -208,12 +209,11 @@ public void testTranslateInExpression_WhereClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertNull(translation.aggFilter);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(" +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==10 || " +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==20)",
sq.script().toString());
assertEquals("[{v=int}, {v=2}]", sq.script().params().toString());
ScriptQuery sc = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(" +
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1), params.v2))",
sc.script().toString());
assertEquals("[{v=int}, {v=2}, {v=[10.0, 20.0]}]", sc.script().params().toString());
}

public void testTranslateInExpression_HavingClause_Painless() {
Expand All @@ -225,9 +225,10 @@ public void testTranslateInExpression_HavingClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20]}]"));
}

public void testTranslateInExpression_HavingClause_PainlessOneArg() {
Expand All @@ -239,9 +240,10 @@ public void testTranslateInExpression_HavingClause_PainlessOneArg() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10]}]"));

}

Expand All @@ -254,8 +256,9 @@ public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() {
QueryTranslation translation = QueryTranslator.toQuery(condition, true);
assertNull(translation.query);
AggFilter aggFilter = translation.aggFilter;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20 || params.a0==30)",
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
aggFilter.scriptTemplate().toString());
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20, 30]}]"));
}
}
Loading

0 comments on commit dab6665

Please sign in to comment.