Skip to content

Commit

Permalink
SQL: Implement IN(value1, value2, ...) expression. (#34581)
Browse files Browse the repository at this point in the history
Implement the functionality to translate the
`field IN (value1, value2,...)` expressions to proper Lucene queries
or painless script or local processors depending on the use case.

The `IN` expression can be used in SELECT, WHERE and HAVING clauses.

Closes: #32955
  • Loading branch information
Marios Trivyzas authored Oct 23, 2018
1 parent 123f784 commit 4a8386f
Show file tree
Hide file tree
Showing 20 changed files with 727 additions and 80 deletions.
9 changes: 8 additions & 1 deletion docs/reference/sql/functions/operators.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[[sql-operators]]
=== Comparison Operators

Boolean operator for comparing one or two expressions.
Boolean operator for comparing against one or multiple expressions.

* Equality (`=`)

Expand Down Expand Up @@ -40,6 +40,13 @@ include-tagged::{sql-specs}/filter.sql-spec[whereBetween]
include-tagged::{sql-specs}/filter.sql-spec[whereIsNotNullAndIsNull]
--------------------------------------------------

* `IN (<value1>, <value2>, ...)`

["source","sql",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{sql-specs}/filter.sql-spec[whereWithInAndMultipleValues]
--------------------------------------------------

[[sql-operators-logical]]
=== Logical Operators

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,16 @@ public static DataType fromODBCType(String odbcType) {
public static DataType fromEsType(String esType) {
return DataType.valueOf(esType.toUpperCase(Locale.ROOT));
}

public boolean isCompatibleWith(DataType other) {
if (this == other) {
return true;
} else if (isString() && other.isString()) {
return true;
} else if (isNumeric() && other.isNumeric()) {
return true;
} else {
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.Score;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
Expand All @@ -40,7 +41,9 @@

import static java.lang.String.format;

abstract class Verifier {
final class Verifier {

private Verifier() {}

static class Failure {
private final Node<?> source;
Expand Down Expand Up @@ -188,6 +191,8 @@ static Collection<Failure> verify(LogicalPlan plan) {

Set<Failure> localFailures = new LinkedHashSet<>();

validateInExpression(p, localFailures);

if (!groupingFailures.contains(p)) {
checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures);
}
Expand Down Expand Up @@ -488,4 +493,19 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure>
fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
}
}
}

private static void validateInExpression(LogicalPlan p, Set<Failure> localFailures) {
p.forEachExpressions(e ->
e.forEachUp((In in) -> {
DataType dt = in.value().dataType();
for (Expression value : in.list()) {
if (!in.value().dataType().isCompatibleWith(value.dataType())) {
localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]",
dt, value.dataType()));
return;
}
}
},
In.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ public static boolean nullable(List<? extends Expression> exps) {
return true;
}

public static boolean foldable(List<? extends Expression> exps) {
for (Expression exp : exps) {
if (!exp.foldable()) {
return false;
}
}
return true;
}

public static AttributeSet references(List<? extends Expression> exps) {
if (exps.isEmpty()) {
return AttributeSet.EMPTY;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.sql.expression.gen.pipeline;

import org.elasticsearch.xpack.sql.capabilities.Resolvable;
import org.elasticsearch.xpack.sql.execution.search.FieldExtraction;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
Expand All @@ -24,7 +25,7 @@
* Is an {@code Add} operator with left {@code ABS} over an aggregate (MAX), and
* right being a {@code CAST} function.
*/
public abstract class Pipe extends Node<Pipe> implements FieldExtraction {
public abstract class Pipe extends Node<Pipe> implements FieldExtraction, Resolvable {

private final Expression expression;

Expand All @@ -37,8 +38,6 @@ public Expression expression() {
return expression;
}

public abstract boolean resolved();

public abstract Processor asProcessor();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,55 @@
*/
package org.elasticsearch.xpack.sql.expression.predicate;

import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.NodeInfo;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.CollectionUtils;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.stream.Collectors;

public class In extends NamedExpression {
import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

public class In extends NamedExpression implements ScriptWeaver {

private final Expression value;
private final List<Expression> list;
private final boolean nullable, foldable;
private Attribute lazyAttribute;

public In(Location location, Expression value, List<Expression> list) {
super(location, null, CollectionUtils.combine(list, value), null);
this.value = value;
this.list = list;

this.nullable = children().stream().anyMatch(Expression::nullable);
this.foldable = children().stream().allMatch(Expression::foldable);
this.list = list.stream().distinct().collect(Collectors.toList());
}

@Override
protected NodeInfo<In> info() {
return NodeInfo.create(this, In::new, value(), list());
return NodeInfo.create(this, In::new, value, list);
}

@Override
public Expression replaceChildren(List<Expression> newChildren) {
if (newChildren.size() < 1) {
throw new IllegalArgumentException("expected one or more children but received [" + newChildren.size() + "]");
if (newChildren.size() < 2) {
throw new IllegalArgumentException("expected at least [2] children but received [" + newChildren.size() + "]");
}
return new In(location(), newChildren.get(newChildren.size() - 1), newChildren.subList(0, newChildren.size() - 1));
}
Expand All @@ -61,22 +73,75 @@ public DataType dataType() {

@Override
public boolean nullable() {
return nullable;
return Expressions.nullable(children());
}

@Override
public boolean foldable() {
return foldable;
return Expressions.foldable(children());
}

@Override
public Object fold() {
Object foldedLeftValue = value.fold();

for (Expression rightValue : list) {
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
if (compResult != null && compResult) {
return true;
}
}
return false;
}

@Override
public String name() {
StringJoiner sj = new StringJoiner(", ", " IN(", ")");
list.forEach(e -> sj.add(Expressions.name(e)));
return Expressions.name(value) + sj.toString();
}

@Override
public Attribute toAttribute() {
throw new SqlIllegalArgumentException("not implemented yet");
if (lazyAttribute == null) {
lazyAttribute = new ScalarFunctionAttribute(location(), name(), dataType(), null,
false, id(), false, "IN", asScript(), null, asPipe());
}
return lazyAttribute;
}

@Override
public ScriptTemplate asScript() {
throw new SqlIllegalArgumentException("not implemented yet");
StringJoiner sj = new StringJoiner(" || ");
ScriptTemplate leftScript = asScript(value);
List<Params> rightParams = new ArrayList<>();
String scriptPrefix = leftScript + "==";
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
for (Object valueFromList : values) {
if (valueFromList instanceof Expression) {
ScriptTemplate rightScript = asScript((Expression) valueFromList);
sj.add(scriptPrefix + rightScript.template());
rightParams.add(rightScript.params());
} else {
if (valueFromList instanceof String) {
sj.add(scriptPrefix + '"' + valueFromList + '"');
} else {
sj.add(scriptPrefix + valueFromList.toString());
}
}
}

ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
for (Params p : rightParams) {
paramsBuilder = paramsBuilder.script(p);
}

return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
}

@Override
protected Pipe makePipe() {
return new InPipe(location(), this, children().stream().map(Expressions::pipe).collect(Collectors.toList()));
}

@Override
Expand All @@ -97,4 +162,4 @@ public boolean equals(Object obj) {
return Objects.equals(value, other.value)
&& Objects.equals(list, other.list);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
*/
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;

import java.util.Set;

/**
* Comparison utilities.
*/
abstract class Comparisons {
public final class Comparisons {

private Comparisons() {}

static Boolean eq(Object l, Object r) {
public static Boolean eq(Object l, Object r) {
Integer i = compare(l, r);
return i == null ? null : i.intValue() == 0;
}
Expand All @@ -35,6 +39,10 @@ static Boolean gte(Object l, Object r) {
return i == null ? null : i.intValue() >= 0;
}

static Boolean in(Object l, Set<Object> r) {
return r.contains(l);
}

/**
* Compares two expression arguments (typically Numbers), if possible.
* Otherwise returns null (the arguments are not comparable or at least
Expand Down Expand Up @@ -73,4 +81,4 @@ private static Integer compare(Number l, Number r) {

return Integer.valueOf(Integer.compare(l.intValue(), r.intValue()));
}
}
}
Loading

0 comments on commit 4a8386f

Please sign in to comment.