Skip to content

Commit

Permalink
SQL: Enable adding missing equals to bool fields as filters (#66252)
Browse files Browse the repository at this point in the history
* Enable the AddMissingEqualsToBoolField rule in SQL

This will enable QL's AnalyzerRules.AddMissingEqualsToBoolField to SQL's
analyzer as well.
  • Loading branch information
bpintea authored Dec 15, 2020
1 parent 7c69823 commit ea137f9
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
import org.elasticsearch.xpack.ql.rule.Rule;

import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.ql.type.DataTypes.BOOLEAN;

public final class AnalyzerRules {

public static class AddMissingEqualsToBoolField extends AnalyzerRule<Filter> {

@Override
protected LogicalPlan rule(Filter filter) {
if (filter.resolved() == false) {
return filter;
}
// check the condition itself
Expression condition = replaceRawBoolFieldWithEquals(filter.condition());
// otherwise look for binary logic
Expand All @@ -39,7 +43,7 @@ protected LogicalPlan rule(Filter filter) {
}

private Expression replaceRawBoolFieldWithEquals(Expression e) {
if (e instanceof FieldAttribute) {
if (e instanceof FieldAttribute && e.dataType() == BOOLEAN) {
e = new Equals(e.source(), e, Literal.of(e, Boolean.TRUE));
}
return e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.AddMissingEqualsToBoolField;
import org.elasticsearch.xpack.ql.capabilities.Resolvables;
import org.elasticsearch.xpack.ql.common.Failure;
import org.elasticsearch.xpack.ql.expression.Alias;
Expand Down Expand Up @@ -119,6 +120,7 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
);
Batch finish = new Batch("Finish Analysis",
new PruneSubqueryAliases(),
new AddMissingEqualsToBoolField(),
CleanAliases.INSTANCE
);
return Arrays.asList(substitution, resolution, finish);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.ql.index.EsIndex;
import org.elasticsearch.xpack.ql.index.IndexResolution;
Expand Down Expand Up @@ -88,13 +89,15 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
import static org.elasticsearch.xpack.ql.type.DataTypes.BOOLEAN;
import static org.elasticsearch.xpack.ql.type.DataTypes.DATETIME;
import static org.elasticsearch.xpack.ql.type.DataTypes.DOUBLE;
Expand Down Expand Up @@ -2459,4 +2462,61 @@ public void testReplaceSumWithStats() {
assertThat(eqe.queryContainer().toString().replaceAll("\\s+", ""), containsString("{\"stats\":{\"field\":\"int\"}}"));
}
}

public void testAddMissingEqualsToBoolField() {
LogicalPlan p = plan("SELECT bool FROM test WHERE bool");
assertTrue(p instanceof Project);

p = ((Project) p).child();
assertTrue(p instanceof Filter);

Expression condition = ((Filter) p).condition();
assertTrue(condition instanceof Equals);
Equals eq = (Equals) condition;

assertTrue(eq.left() instanceof FieldAttribute);
assertEquals("bool", ((FieldAttribute) eq.left()).name());

assertTrue(eq.right() instanceof Literal);
assertEquals(TRUE, eq.right());
}

public void testAddMissingEqualsToNestedBoolField() {
LogicalPlan p = plan("SELECT bool FROM test " +
"WHERE int > 1 and (bool or int < 2) or (int = 3 and bool) or (int = 4 and bool = false) or bool");
LogicalPlan expectedPlan = plan("SELECT bool FROM test " +
"WHERE int > 1 and (bool = true or int < 2) or (int = 3 and bool = true) or (int = 4 and bool = false) or bool = true");

assertTrue(p instanceof Project);
p = ((Project) p).child();
assertTrue(p instanceof Filter);
Expression condition = ((Filter) p).condition();

Expression expectedCondition = ((Filter) ((Project) expectedPlan).child()).condition();

List<Expression> expectedFields = expectedCondition.collect(x -> x instanceof FieldAttribute);
Set<Expression> expectedBools = expectedFields.stream()
.filter(x -> ((FieldAttribute) x).name().equals("bool")).collect(Collectors.toSet());
assertEquals(1, expectedBools.size());
Set<Expression> expectedInts = expectedFields.stream()
.filter(x -> ((FieldAttribute) x).name().equals("int")).collect(Collectors.toSet());
assertEquals(1, expectedInts.size());

condition = condition
.transformDown(x -> x.name().equals("bool") ? (FieldAttribute) expectedBools.toArray()[0] : x, FieldAttribute.class)
.transformDown(x -> x.name().equals("int") ? (FieldAttribute) expectedInts.toArray()[0] : x , FieldAttribute.class);

assertEquals(expectedCondition, condition);
}

public void testNotAddMissingEqualsToNonBoolField() {
LogicalPlan p = plan("SELECT bool FROM test WHERE " + randomFrom("int", "text", "keyword", "date"));
assertTrue(p instanceof Project);

p = ((Project) p).child();
assertTrue(p instanceof Filter);

Expression condition = ((Filter) p).condition();
assertTrue(condition instanceof FieldAttribute);
}
}

0 comments on commit ea137f9

Please sign in to comment.