Skip to content

Commit

Permalink
[fix](nereids)subquery unnesting get wrong result if correlated conju…
Browse files Browse the repository at this point in the history
…ncts is not slot_a = slot_b (apache#37644) (apache#37684)

pick from master apache#37644
  • Loading branch information
starocean999 authored Jul 17, 2024
1 parent 3bb5291 commit 971a7c8
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -74,6 +75,11 @@ public Expression visitNot(Not not, CascadesContext context) {

@Override
public Expression visitExistsSubquery(Exists exists, CascadesContext context) {
LogicalPlan queryPlan = exists.getQueryPlan();
// distinct is useless, remove it
if (queryPlan instanceof LogicalProject && ((LogicalProject) queryPlan).isDistinct()) {
exists = exists.withSubquery(((LogicalProject) queryPlan).withDistinct(false));
}
AnalyzedResult analyzedResult = analyzeSubquery(exists);
if (analyzedResult.rootIsLimitZero()) {
return BooleanLiteral.of(exists.isNot());
Expand All @@ -88,6 +94,11 @@ public Expression visitExistsSubquery(Exists exists, CascadesContext context) {

@Override
public Expression visitInSubquery(InSubquery expr, CascadesContext context) {
LogicalPlan queryPlan = expr.getQueryPlan();
// distinct is useless, remove it
if (queryPlan instanceof LogicalProject && ((LogicalProject) queryPlan).isDistinct()) {
expr = expr.withSubquery(((LogicalProject) queryPlan).withDistinct(false));
}
AnalyzedResult analyzedResult = analyzeSubquery(expr);

checkOutputColumn(analyzedResult.getLogicalPlan());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
Expand All @@ -29,8 +31,8 @@
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -75,12 +77,19 @@ public Rule build() {
}

List<NamedExpression> newAggOutput = new ArrayList<>(agg.getOutputExpressions());
List<Expression> newGroupby = Utils.getCorrelatedSlots(correlatedPredicate,
apply.getCorrelationSlot());
List<Expression> newGroupby = Utils.getUnCorrelatedExprs(correlatedPredicate, apply.getCorrelationSlot());
newGroupby.addAll(agg.getGroupByExpressions());
newAggOutput.addAll(newGroupby.stream()
.map(NamedExpression.class::cast)
.collect(ImmutableList.toImmutableList()));
Map<Expression, Slot> unCorrelatedExprToSlot = Maps.newHashMap();
for (Expression expression : newGroupby) {
if (expression instanceof Slot) {
newAggOutput.add((NamedExpression) expression);
} else {
Alias alias = new Alias(expression);
unCorrelatedExprToSlot.put(expression, alias.toSlot());
newAggOutput.add(alias);
}
}
correlatedPredicate = ExpressionUtils.replace(correlatedPredicate, unCorrelatedExprToSlot);
LogicalAggregate newAgg = new LogicalAggregate<>(
newGroupby, newAggOutput,
PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ public LogicalProject<Plan> withProjectsAndChild(List<NamedExpression> projects,
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, child);
}

public LogicalProject<Plan> withDistinct(boolean isDistinct) {
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, child());
}

public boolean isDistinct() {
return isDistinct;
}
Expand Down
56 changes: 46 additions & 10 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

package org.apache.doris.nereids.util;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;

Expand Down Expand Up @@ -147,18 +150,51 @@ public static String toSqlString(String planName, Object... variables) {
}

/**
* Get the correlated columns that belong to the subquery,
* that is, the correlated columns that can be resolved within the subquery.
* Get the unCorrelated exprs that belong to the subquery,
* that is, the unCorrelated exprs that can be resolved within the subquery.
* eg:
* select * from t1 where t1.a = (select sum(t2.b) from t2 where t1.c = t2.d));
* correlatedPredicates : t1.c = t2.d
* correlatedSlots : t1.c
* return t2.d
* select * from t1 where t1.a = (select sum(t2.b) from t2 where t1.c = abs(t2.d));
* correlatedPredicates : t1.c = abs(t2.d)
* unCorrelatedExprs : abs(t2.d)
* return abs(t2.d)
*/
public static List<Expression> getCorrelatedSlots(List<Expression> correlatedPredicates,
List<Expression> correlatedSlots) {
return ExpressionUtils.getInputSlotSet(correlatedPredicates).stream()
.filter(slot -> !correlatedSlots.contains(slot)).collect(Collectors.toList());
public static List<Expression> getUnCorrelatedExprs(List<Expression> correlatedPredicates,
List<Expression> correlatedSlots) {
List<Expression> unCorrelatedExprs = new ArrayList<>();
correlatedPredicates.forEach(predicate -> {
if (!(predicate instanceof BinaryExpression) && (!(predicate instanceof Not)
|| !(predicate.child(0) instanceof BinaryExpression))) {
throw new AnalysisException(
"Unsupported correlated subquery with correlated predicate "
+ predicate.toString());
}

BinaryExpression binaryExpression;
if (predicate instanceof Not) {
binaryExpression = (BinaryExpression) ((Not) predicate).child();
} else {
binaryExpression = (BinaryExpression) predicate;
}
Expression left = binaryExpression.left();
Expression right = binaryExpression.right();
Set<Slot> leftInputSlots = left.getInputSlots();
Set<Slot> rightInputSlots = right.getInputSlots();
boolean correlatedToLeft = !leftInputSlots.isEmpty()
&& leftInputSlots.stream().allMatch(correlatedSlots::contains)
&& rightInputSlots.stream().noneMatch(correlatedSlots::contains);
boolean correlatedToRight = !rightInputSlots.isEmpty()
&& rightInputSlots.stream().allMatch(correlatedSlots::contains)
&& leftInputSlots.stream().noneMatch(correlatedSlots::contains);
if (!correlatedToLeft && !correlatedToRight) {
throw new AnalysisException(
"Unsupported correlated subquery with correlated predicate " + predicate);
} else if (correlatedToLeft && !rightInputSlots.isEmpty()) {
unCorrelatedExprs.add(right);
} else if (correlatedToRight && !leftInputSlots.isEmpty()) {
unCorrelatedExprs.add(left);
}
});
return unCorrelatedExprs;
}

private static List<Expression> collectCorrelatedSlotsFromChildren(
Expand Down
54 changes: 54 additions & 0 deletions regression-test/data/nereids_syntax_p0/test_subquery_conjunct.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_simple_scalar --
-2 -2
2 2
3 2

-- !select_complex_scalar --
2 2
3 2

-- !select_simple_in --
1 1
2 1

-- !select_complex_in --
1 1
2 1

-- !select_simple_not_in --
-2 -2
-1 -1
1 1
2 1
2 2
3 2

-- !select_complex_not_in --
-2 -2
-1 -1
1 1
2 1
2 2
3 2

-- !select_simple_exists --
-2 -2
2 2
3 2

-- !select_complex_exists --
2 2
3 2

-- !select_simple_not_exists --
-1 -1
1 1
2 1

-- !select_complex_not_exists --
-2 -2
-1 -1
1 1
2 1

Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ suite ("sub_query_diff_old_optimize") {
sql """
SELECT DISTINCT k1 FROM sub_query_diff_old_optimize_subquery1 i1 WHERE ((SELECT count(*) FROM sub_query_diff_old_optimize_subquery1 WHERE ((k1 = i1.k1) AND (k2 = 2)) or ((k2 = i1.k1) AND (k2 = 1)) ) > 0);
"""
exception "java.sql.SQLException: errCode = 2, detailMessage = Unexpected exception: scalar subquery's correlatedPredicates's operator must be EQ"
exception "Unsupported correlated subquery with correlated predicate"

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_subquery_conjunct") {
sql "set enable_nereids_planner=true"
sql "set enable_fallback_to_original_planner=false"
sql """drop table if exists subquery_conjunct_table;"""
sql """CREATE TABLE `subquery_conjunct_table` (
`id` INT NOT NULL,
`c1` INT NOT NULL
) ENGINE=OLAP
DUPLICATE KEY(`id`, `c1`)
DISTRIBUTED BY RANDOM BUCKETS AUTO
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);"""

sql """insert into subquery_conjunct_table values(1, 1),(2,2),(-1,-1),(-2,-2),(2,1),(3,2);"""
qt_select_simple_scalar """select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_scalar """select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
qt_select_simple_in """select * from subquery_conjunct_table t1 where abs(t1.c1) in (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id -1 = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_in """select * from subquery_conjunct_table t1 where abs(t1.c1) in (select c1 from subquery_conjunct_table t2 where abs(t2.c1+ t2.id -1) = t1.c1) order by t1.id, t1.c1;"""
qt_select_simple_not_in """select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_not_in """select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
qt_select_simple_exists """select * from subquery_conjunct_table t1 where exists (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_exists """select * from subquery_conjunct_table t1 where exists (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
qt_select_simple_not_exists """select * from subquery_conjunct_table t1 where not exists (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_not_exists """select * from subquery_conjunct_table t1 where not exists (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) - t1.c1 = 0) order by t1.id; """
exception "Unsupported correlated subquery with correlated predicate"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != ( select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 -1) + t1.id = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with correlated predicate"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) > t1.c1) order by t1.id; """
exception "scalar subquery's correlatedPredicates's operator must be EQ"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) in (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + 1 = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) in (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + 1= t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 -1) = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
}

0 comments on commit 971a7c8

Please sign in to comment.