From dee851dc079b3eedf4a481746f45db175dea6dc3 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 19 Sep 2022 14:47:01 +0800 Subject: [PATCH] expression: resize the result for IfXXSig (#37417) (#37431) close pingcap/tidb#37414 --- expression/builtin_control_vec_generated.go | 10 ++++++++++ expression/generator/control_vec.go | 2 ++ expression/integration_test.go | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/expression/builtin_control_vec_generated.go b/expression/builtin_control_vec_generated.go index 8766dedbfa569..cc69b6d07e999 100644 --- a/expression/builtin_control_vec_generated.go +++ b/expression/builtin_control_vec_generated.go @@ -802,6 +802,7 @@ func (b *builtinCaseWhenJSONSig) vectorized() bool { func (b *builtinIfNullIntSig) fallbackEvalInt(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeInt64(n, false) x := result.Int64s() for i := 0; i < n; i++ { res, isNull, err := b.evalInt(input.GetRow(i)) @@ -856,6 +857,7 @@ func (b *builtinIfNullIntSig) vectorized() bool { func (b *builtinIfNullRealSig) fallbackEvalReal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeFloat64(n, false) x := result.Float64s() for i := 0; i < n; i++ { res, isNull, err := b.evalReal(input.GetRow(i)) @@ -910,6 +912,7 @@ func (b *builtinIfNullRealSig) vectorized() bool { func (b *builtinIfNullDecimalSig) fallbackEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeDecimal(n, false) x := result.Decimals() for i := 0; i < n; i++ { res, isNull, err := b.evalDecimal(input.GetRow(i)) @@ -1024,6 +1027,7 @@ func (b *builtinIfNullStringSig) vectorized() bool { func (b *builtinIfNullTimeSig) fallbackEvalTime(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeTime(n, false) x := result.Times() for i := 0; i < n; i++ { res, isNull, err := b.evalTime(input.GetRow(i)) @@ -1078,6 +1082,7 @@ func (b *builtinIfNullTimeSig) vectorized() bool { func (b *builtinIfNullDurationSig) fallbackEvalDuration(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeGoDuration(n, false) x := result.GoDurations() for i := 0; i < n; i++ { res, isNull, err := b.evalDuration(input.GetRow(i)) @@ -1192,6 +1197,7 @@ func (b *builtinIfNullJSONSig) vectorized() bool { func (b *builtinIfIntSig) fallbackEvalInt(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeInt64(n, false) x := result.Int64s() for i := 0; i < n; i++ { res, isNull, err := b.evalInt(input.GetRow(i)) @@ -1270,6 +1276,7 @@ func (b *builtinIfIntSig) vectorized() bool { func (b *builtinIfRealSig) fallbackEvalReal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeFloat64(n, false) x := result.Float64s() for i := 0; i < n; i++ { res, isNull, err := b.evalReal(input.GetRow(i)) @@ -1348,6 +1355,7 @@ func (b *builtinIfRealSig) vectorized() bool { func (b *builtinIfDecimalSig) fallbackEvalDecimal(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeDecimal(n, false) x := result.Decimals() for i := 0; i < n; i++ { res, isNull, err := b.evalDecimal(input.GetRow(i)) @@ -1510,6 +1518,7 @@ func (b *builtinIfStringSig) vectorized() bool { func (b *builtinIfTimeSig) fallbackEvalTime(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeTime(n, false) x := result.Times() for i := 0; i < n; i++ { res, isNull, err := b.evalTime(input.GetRow(i)) @@ -1588,6 +1597,7 @@ func (b *builtinIfTimeSig) vectorized() bool { func (b *builtinIfDurationSig) fallbackEvalDuration(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() + result.ResizeGoDuration(n, false) x := result.GoDurations() for i := 0; i < n; i++ { res, isNull, err := b.evalDuration(input.GetRow(i)) diff --git a/expression/generator/control_vec.go b/expression/generator/control_vec.go index 628ec3b2fec1a..454931bde2ab2 100644 --- a/expression/generator/control_vec.go +++ b/expression/generator/control_vec.go @@ -230,6 +230,7 @@ var builtinIfNullVec = template.Must(template.New("builtinIfNullVec").Parse(` func (b *builtinIfNull{{ .TypeName }}Sig) fallbackEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() {{- if .Fixed }} + result.Resize{{ .TypeNameInColumn }}(n, false) x := result.{{ .TypeNameInColumn }}s() for i := 0; i < n; i++ { res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) @@ -345,6 +346,7 @@ var builtinIfVec = template.Must(template.New("builtinIfVec").Parse(` func (b *builtinIf{{ .TypeName }}Sig) fallbackEval{{ .TypeName }}(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() {{- if .Fixed }} + result.Resize{{ .TypeNameInColumn }}(n, false) x := result.{{ .TypeNameInColumn }}s() for i := 0; i < n; i++ { res, isNull, err := b.eval{{ .TypeName }}(input.GetRow(i)) diff --git a/expression/integration_test.go b/expression/integration_test.go index 1a45111155354..df0c1969d5dd8 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -6281,6 +6281,20 @@ func TestRedundantColumnResolve(t *testing.T) { tk.MustQuery("select t1.a, t2.a from t1 natural join t2").Check(testkit.Rows("1 1")) } +func TestIssue37414(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists foo") + tk.MustExec("drop table if exists bar") + tk.MustExec("create table foo(a decimal(65,0));") + tk.MustExec("create table bar(a decimal(65,0), b decimal(65,0));") + tk.MustExec("insert into bar values(0,0),(1,1),(2,2);") + tk.MustExec("insert into foo select if(b>0, if(a/b>1, 1, 2), null) from bar;") +} + func TestControlFunctionWithEnumOrSet(t *testing.T) { // issue 23114 store, clean := testkit.CreateMockStore(t)