From 3c58e9bac919ffa3b7131a5c9b3ceec879b50c0e Mon Sep 17 00:00:00 2001 From: LiBinfeng <46676950+LiBinfeng-01@users.noreply.github.com> Date: Tue, 25 Jul 2023 10:05:00 +0800 Subject: [PATCH] [Fix](Nereids) Fix problem of infer predicates not completely (#22145) Problem: When inferring predicates in Nereids, newly inferred predicates cannot be used as the source for the next round of inference. For example: create table tt1(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1'); create table tt2(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1'); create table tt3(c1 int, c2 int) distributed by hash(c1) properties('replication_num'='1'); explain select * from tt1 left join tt2 on tt1.c1 = tt2.c1 left join tt3 on tt2.c1 = tt3.c1 where tt1.c1 = 123; We expect to infer tt3.c1 = 123, but we only get tt2.c1 = 123. This is because no relationship can be derived directly between the two predicates tt1.c1 = 123 and tt2.c1 = tt3.c1. Solution: Cache the intermediate inferred results (such as tt2.c1 = 123 in the example) as source predicates, so that they can take part in later rounds of inference. --- .../nereids/rules/rewrite/PredicatePropagation.java | 9 +++++++++ .../nereids_p0/infer_predicate/infer_predicate.groovy | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index cc45952817a845..f6f04e899bc2f3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -40,11 +40,17 @@ */ public class PredicatePropagation { + /** + * equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates + */ + private Set sourcePredicates = Sets.newHashSet(); + /** * infer additional predicates. 
*/ public Set infer(Set predicates) { Set inferred = Sets.newHashSet(); + predicates.addAll(sourcePredicates); for (Expression predicate : predicates) { if (canEquivalentInfer(predicate)) { List newInferred = predicates.stream() @@ -55,6 +61,7 @@ public Set infer(Set predicates) { } } inferred.removeAll(predicates); + sourcePredicates.addAll(inferred); return inferred; } @@ -76,8 +83,10 @@ public Expression visit(Expression expr, Void context) { public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { // we need to get expression covered by cast, because we want to infer different datatype if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { + sourcePredicates.add(cp); return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { + sourcePredicates.add(cp); return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } return super.visit(cp, context); diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index f93de7273b9f45..2ad1c250cabc09 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -21,6 +21,7 @@ suite("test_infer_predicate") { sql 'drop table if exists infer_tb1;' sql 'drop table if exists infer_tb2;' + sql 'drop table if exists infer_tb3;' sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');''' @@ -47,4 +48,10 @@ suite("test_infer_predicate") { sql "select * from infer_tb1 inner join infer_tb3 where infer_tb3.k1 = infer_tb1.k2 and infer_tb3.k1 = '123';" notContains "PREDICATES: k2[#6] = '123'" } + + explain { + sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = 
infer_tb2.k3 left join infer_tb3 on " + + "infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;" + contains "PREDICATES: k3[#4] = 1" + } }