Skip to content

Commit

Permalink
[Fix](Nereids) Fix problem of infer predicates not completely (apache…
Browse files Browse the repository at this point in the history
…#22145)

Problem:
When inferring predicate in nereids, new inferred predicates can not be the source of next round. For example:

create table tt1(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1');
create table tt2(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1');
create table tt3(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1');
explain select * from tt1 left join tt2 on tt1.c1 = tt2.c1 left join tt3 on tt2.c1 = tt3.c1 where tt1.c1 = 123;

we expect to get t33.c1 = 123, but we can just get t22.c1 = 123. Because when infer tt1.c1 = 123 and tt2.c1 = tt3.c1, we can
not get any relationship of these two predicates.

Solution:
We need to cache middle results of source predicates like t22.c1 = 123 in example.
  • Loading branch information
LiBinfeng-01 authored Jul 25, 2023
1 parent a0463ea commit 3c58e9b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@
*/
public class PredicatePropagation {

/**
* equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates
*/
private Set<Expression> sourcePredicates = Sets.newHashSet();

/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
predicates.addAll(sourcePredicates);
for (Expression predicate : predicates) {
if (canEquivalentInfer(predicate)) {
List<Expression> newInferred = predicates.stream()
Expand All @@ -55,6 +61,7 @@ public Set<Expression> infer(Set<Expression> predicates) {
}
}
inferred.removeAll(predicates);
sourcePredicates.addAll(inferred);
return inferred;
}

Expand All @@ -76,8 +83,10 @@ public Expression visit(Expression expr, Void context) {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
sourcePredicates.add(cp);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
sourcePredicates.add(cp);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ suite("test_infer_predicate") {

sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'

sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

Expand All @@ -47,4 +48,10 @@ suite("test_infer_predicate") {
sql "select * from infer_tb1 inner join infer_tb3 where infer_tb3.k1 = infer_tb1.k2 and infer_tb3.k1 = '123';"
notContains "PREDICATES: k2[#6] = '123'"
}

explain {
sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = infer_tb2.k3 left join infer_tb3 on " +
"infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;"
contains "PREDICATES: k3[#4] = 1"
}
}

0 comments on commit 3c58e9b

Please sign in to comment.