diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 32f7947a86f2a..3c2c8739b43d0 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -347,7 +347,14 @@ func (sf *ScalarFunction) Equal(ctx sessionctx.Context, e Expression) bool { if sf.FuncName.L != fun.FuncName.L { return false } +<<<<<<< HEAD return sf.Function.equal(fun.Function) +======= + if !sf.RetType.Equal(fun.RetType) { + return false + } + return sf.Function.equal(ctx, fun.Function) +>>>>>>> 1a24c032126 (expression: correct the erroneous scalar function equivalence check (#54067)) } // IsCorrelated implements Expression interface. diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index a275871663dfb..950b9d68514b5 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -257,8 +257,13 @@ func TestSubstituteCorCol2Constant(t *testing.T) { plus3 := newFunction(ast.Plus, plus2, col1) ret, err = SubstituteCorCol2Constant(plus3) require.NoError(t, err) +<<<<<<< HEAD ans3 := newFunction(ast.Plus, ans1, col1) require.True(t, ret.Equal(ctx, ans3)) +======= + ans3 := newFunctionWithMockCtx(ast.Plus, ans1, col1) + require.False(t, ret.Equal(ctx, ans3)) +>>>>>>> 1a24c032126 (expression: correct the erroneous scalar function equivalence check (#54067)) } func TestPushDownNot(t *testing.T) { diff --git a/pkg/parser/types/field_type.go b/pkg/parser/types/field_type.go index 2e80bbf3c7d2b..2befced6393bb 100644 --- a/pkg/parser/types/field_type.go +++ b/pkg/parser/types/field_type.go @@ -289,7 +289,7 @@ func (ft *FieldType) Equal(other *FieldType) bool { // because flen for them is useless. // The decimal field can be ignored if the type is int or string. tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar) - flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) + flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) || ft.EvalType() == ETJson ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString partialEqual := tpEqual && (ignoreDecimal || ft.decimal == other.decimal) && diff --git a/pkg/planner/core/issuetest/planner_issue_test.go b/pkg/planner/core/issuetest/planner_issue_test.go index be9e5ef29d971..772a66b644388 100644 --- a/pkg/planner/core/issuetest/planner_issue_test.go +++ b/pkg/planner/core/issuetest/planner_issue_test.go @@ -57,3 +57,31 @@ func TestIssue43461(t *testing.T) { require.NotEqual(t, is.Columns, ts.Columns) } + +func Test53726(t *testing.T) { + // test for RemoveUnnecessaryFirstRow + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t7(c int); ") + tk.MustExec("insert into t7 values (575932053), (-258025139);") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort().Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_8 8000.00 root group by:Column#7, Column#8, funcs:firstrow(Column#7)->Column#3, funcs:firstrow(Column#8)->Column#4", + "└─TableReader_9 8000.00 root data:HashAgg_4", + " └─HashAgg_4 8000.00 cop[tikv] group by:cast(test.t7.c, bigint(22) BINARY), cast(test.t7.c, decimal(10,0) BINARY), ", + " └─TableFullScan_7 10000.00 cop[tikv] table:t7 keep order:false, stats:pseudo")) + + tk.MustExec("analyze table t7") + tk.MustQuery("select distinct cast(c as decimal), cast(c as signed) from t7"). + Sort(). + Check(testkit.Rows("-258025139 -258025139", "575932053 575932053")) + tk.MustQuery("explain select distinct cast(c as decimal), cast(c as signed) from t7"). + Check(testkit.Rows( + "HashAgg_6 2.00 root group by:Column#13, Column#14, funcs:firstrow(Column#11)->Column#3, funcs:firstrow(Column#12)->Column#4", + "└─Projection_12 2.00 root cast(test.t7.c, decimal(10,0) BINARY)->Column#11, cast(test.t7.c, bigint(22) BINARY)->Column#12, cast(test.t7.c, decimal(10,0) BINARY)->Column#13, cast(test.t7.c, bigint(22) BINARY)->Column#14", + " └─TableReader_11 2.00 root data:TableFullScan_10", + " └─TableFullScan_10 2.00 cop[tikv] table:t7 keep order:false")) +}