From a3ec07b20d9b30de4efc7187fab7aa67c34f0482 Mon Sep 17 00:00:00 2001 From: Rohan Yadav Date: Wed, 7 Aug 2019 14:26:28 -0400 Subject: [PATCH] exec: Add framework for vectorized casts. This PR adds in a framework for performing casts within the vectorized engine, along with a few casts implemented already. Additionally, it implements a randomized testing framework for the cast operators. Release note: None --- Makefile | 2 + pkg/sql/distsqlrun/column_exec_setup.go | 12 + pkg/sql/exec/.gitignore | 1 + pkg/sql/exec/cast_test.go | 104 ++++++++ pkg/sql/exec/cast_tmpl.go | 171 +++++++++++++ pkg/sql/exec/execgen/cmd/execgen/cast_gen.go | 52 ++++ pkg/sql/exec/execgen/cmd/execgen/overloads.go | 232 ++++++++++++++++++ .../logictest/testdata/logic_test/vectorize | 12 + pkg/sql/sem/tree/datum.go | 20 ++ 9 files changed, 606 insertions(+) create mode 100644 pkg/sql/exec/cast_test.go create mode 100644 pkg/sql/exec/cast_tmpl.go create mode 100644 pkg/sql/exec/execgen/cmd/execgen/cast_gen.go diff --git a/Makefile b/Makefile index c27b0feaa080..b64d87d96656 100644 --- a/Makefile +++ b/Makefile @@ -788,6 +788,7 @@ EXECGEN_TARGETS = \ pkg/col/coldata/vec.eg.go \ pkg/sql/exec/any_not_null_agg.eg.go \ pkg/sql/exec/avg_agg.eg.go \ + pkg/sql/exec/cast.eg.go \ pkg/sql/exec/const.eg.go \ pkg/sql/exec/distinct.eg.go \ pkg/sql/exec/hashjoiner.eg.go \ @@ -1460,6 +1461,7 @@ $(SETTINGS_DOC_PAGE): $(settings-doc-gen) pkg/col/coldata/vec.eg.go: pkg/col/coldata/vec_tmpl.go pkg/sql/exec/any_not_null_agg.eg.go: pkg/sql/exec/any_not_null_agg_tmpl.go pkg/sql/exec/avg_agg.eg.go: pkg/sql/exec/avg_agg_tmpl.go +pkg/sql/exec/cast.eg.go: pkg/sql/exec/cast_tmpl.go pkg/sql/exec/const.eg.go: pkg/sql/exec/const_tmpl.go pkg/sql/exec/distinct.eg.go: pkg/sql/exec/distinct_tmpl.go pkg/sql/exec/hashjoiner.eg.go: pkg/sql/exec/hashjoiner_tmpl.go diff --git a/pkg/sql/distsqlrun/column_exec_setup.go b/pkg/sql/distsqlrun/column_exec_setup.go index 20709ee5e7ed..a7832b5e5e00 100644 --- a/pkg/sql/distsqlrun/column_exec_setup.go +++ b/pkg/sql/distsqlrun/column_exec_setup.go @@ -768,6 +768,18 @@ func planProjectionOperators( return planProjectionExpr(ctx, t.Operator, t.ResolvedType(), t.TypedLeft(), t.TypedRight(), columnTypes, input) case *tree.BinaryExpr: return planProjectionExpr(ctx, t.Operator, t.ResolvedType(), t.TypedLeft(), t.TypedRight(), columnTypes, input) + case *tree.CastExpr: + op, resultIdx, ct, memUsed, err = planProjectionOperators(ctx, t.Expr.(tree.TypedExpr), columnTypes, input) + if err != nil { + return nil, 0, nil, 0, err + } + outputIdx := len(ct) + op, err = exec.GetCastOperator(op, resultIdx, outputIdx, t.Expr.(tree.TypedExpr).ResolvedType(), t.Type) + ct = append(ct, *t.Type) + if sMem, ok := op.(exec.StaticMemoryOperator); ok { + memUsed += sMem.EstimateStaticMemoryUsage() + } + return op, outputIdx, ct, memUsed, err case *tree.FuncExpr: var ( inputCols []int diff --git a/pkg/sql/exec/.gitignore b/pkg/sql/exec/.gitignore index c992fba1f2a2..713df1c34033 100644 --- a/pkg/sql/exec/.gitignore +++ b/pkg/sql/exec/.gitignore @@ -1,5 +1,6 @@ any_not_null_agg.eg.go avg_agg.eg.go +cast.eg.go const.eg.go distinct.eg.go hashjoiner.eg.go diff --git a/pkg/sql/exec/cast_test.go b/pkg/sql/exec/cast_test.go new file mode 100644 index 000000000000..b18a6eb3d1cd --- /dev/null +++ b/pkg/sql/exec/cast_test.go @@ -0,0 +1,104 @@ +// Copyright 2019 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package exec + +import ( + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/randutil" +) + +func TestRandomizedCast(t *testing.T) { + + datumAsBool := func(d tree.Datum) interface{} { + return bool(tree.MustBeDBool(d)) + } + datumAsInt := func(d tree.Datum) interface{} { + return int(tree.MustBeDInt(d)) + } + datumAsFloat := func(d tree.Datum) interface{} { + return float64(tree.MustBeDFloat(d)) + } + datumAsDecimal := func(d tree.Datum) interface{} { + return tree.MustBeDDecimal(d).Decimal + } + + tc := []struct { + fromTyp *types.T + fromPhysType func(tree.Datum) interface{} + toTyp *types.T + toPhysType func(tree.Datum) interface{} + // Some types casting can fail, so retry if we + // generate a datum that is unable to be casted. + retryGeneration bool + }{ + //bool -> t tests + {types.Bool, datumAsBool, types.Bool, datumAsBool, false}, + {types.Bool, datumAsBool, types.Int, datumAsInt, false}, + {types.Bool, datumAsBool, types.Float, datumAsFloat, false}, + // decimal -> t tests + {types.Decimal, datumAsDecimal, types.Bool, datumAsBool, false}, + // int -> t tests + {types.Int, datumAsInt, types.Bool, datumAsBool, false}, + {types.Int, datumAsInt, types.Float, datumAsFloat, false}, + {types.Int, datumAsInt, types.Decimal, datumAsDecimal, false}, + // float -> t tests + {types.Float, datumAsFloat, types.Bool, datumAsBool, false}, + // We can sometimes generate a float outside of the range of the integers, + // so we want to retry with generation if that occurs. + {types.Float, datumAsFloat, types.Int, datumAsInt, true}, + {types.Float, datumAsFloat, types.Decimal, datumAsDecimal, false}, + } + + evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) + rng, _ := randutil.NewPseudoRand() + + for _, c := range tc { + t.Run(fmt.Sprintf("%sTo%s", c.fromTyp.String(), c.toTyp.String()), func(t *testing.T) { + n := 100 + // Make an input vector of length n. + input := tuples{} + output := tuples{} + for i := 0; i < n; i++ { + // We don't allow any NULL datums to be generated, so disable + // this ability in the RandDatum function. + fromDatum := sqlbase.RandDatum(rng, c.fromTyp, false) + var ( + toDatum tree.Datum + err error + ) + toDatum, err = tree.PerformCast(evalCtx, fromDatum, c.toTyp) + if c.retryGeneration { + for err != nil { + // If we are allowed to retry, make a new datum and cast it on error. + fromDatum = sqlbase.RandDatum(rng, c.fromTyp, false) + toDatum, err = tree.PerformCast(evalCtx, fromDatum, c.toTyp) + } + } else { + if err != nil { + t.Fatal(err) + } + } + input = append(input, tuple{c.fromPhysType(fromDatum)}) + output = append(output, tuple{c.toPhysType(toDatum)}) + } + runTests(t, []tuples{input}, output, orderedVerifier, []int{1}, + func(input []Operator) (Operator, error) { + return GetCastOperator(input[0], 0 /* inputIdx*/, 1 /* resultIdx */, c.fromTyp, c.toTyp) + }) + }) + } +} diff --git a/pkg/sql/exec/cast_tmpl.go b/pkg/sql/exec/cast_tmpl.go new file mode 100644 index 000000000000..59d20754f40c --- /dev/null +++ b/pkg/sql/exec/cast_tmpl.go @@ -0,0 +1,171 @@ +// Copyright 2019 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// {{/* +// +build execgen_template +// +// This file is the execgen template for cast.eg.go. It's formatted in a +// special way, so it's both valid Go and a valid text/template input. This +// permits editing this file with editor support. +// +// */}} + +package exec + +import ( + "context" + "math" + + "github.com/cockroachdb/apd" + "github.com/cockroachdb/cockroach/pkg/col/coldata" + "github.com/cockroachdb/cockroach/pkg/col/coltypes" + "github.com/cockroachdb/cockroach/pkg/sql/exec/execerror" + "github.com/cockroachdb/cockroach/pkg/sql/exec/execgen" + "github.com/cockroachdb/cockroach/pkg/sql/exec/typeconv" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + semtypes "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/pkg/errors" +) + +// {{/* + +type _ALLTYPES interface{} +type _OVERLOADTYPES interface{} +type _TOTYPE interface{} +type _GOTYPE interface{} +type _FROMTYPE interface{} + +var _ apd.Decimal +var _ = math.MaxInt8 +var _ tree.Datum + +func _ASSIGN_CAST(to, from interface{}) { + execerror.VectorizedInternalPanic("") +} + +// */}} + +// Use execgen package to remove unused import warning. +var _ interface{} = execgen.GET + +func GetCastOperator( + input Operator, colIdx int, resultIdx int, fromType *semtypes.T, toType *semtypes.T, +) (Operator, error) { + switch from := typeconv.FromColumnType(fromType); from { + // {{ range $typ, $overloads := . }} + case coltypes._ALLTYPES: + switch to := typeconv.FromColumnType(toType); to { + // {{ range $overloads }} + // {{ if isCastFuncSet . }} + case coltypes._OVERLOADTYPES: + return &castOp_FROMTYPE_TOTYPE{ + OneInputNode: NewOneInputNode(input), + colIdx: colIdx, + outputIdx: resultIdx, + fromType: from, + toType: to, + }, nil + // {{end}} + // {{end}} + default: + return nil, errors.Errorf("unhandled cast FROM -> TO type: %s -> %s", from, to) + } + // {{end}} + default: + return nil, errors.Errorf("unhandled FROM type: %s", from) + } +} + +// {{ range $typ, $overloads := . }} +// {{ range $overloads }} +// {{ if isCastFuncSet . }} + +type castOp_FROMTYPE_TOTYPE struct { + OneInputNode + colIdx int + outputIdx int + fromType coltypes.T + toType coltypes.T +} + +var _ StaticMemoryOperator = &castOp_FROMTYPE_TOTYPE{} + +func (c *castOp_FROMTYPE_TOTYPE) EstimateStaticMemoryUsage() int { + return EstimateBatchSizeBytes([]coltypes.T{c.toType}, coldata.BatchSize) +} + +func (c *castOp_FROMTYPE_TOTYPE) Init() { + c.input.Init() +} + +func (c *castOp_FROMTYPE_TOTYPE) Next(ctx context.Context) coldata.Batch { + batch := c.input.Next(ctx) + n := batch.Length() + if n == 0 { + return batch + } + if c.outputIdx == batch.Width() { + batch.AppendCol(coltypes._TOTYPE) + } + vec := batch.ColVec(c.colIdx) + col := vec._FROMTYPE() + projVec := batch.ColVec(c.outputIdx) + projCol := projVec._TOTYPE() + if vec.MaybeHasNulls() { + vecNulls := vec.Nulls() + projNulls := projVec.Nulls() + if sel := batch.Selection(); sel != nil { + sel = sel[:n] + for _, i := range sel { + if vecNulls.NullAt(i) { + projNulls.SetNull(i) + } else { + v := execgen.GET(col, int(i)) + var r _GOTYPE + _ASSIGN_CAST(r, v) + execgen.SET(projCol, int(i), r) + } + } + } else { + for execgen.RANGE(i, col) { + if vecNulls.NullAt(uint16(i)) { + projNulls.SetNull(uint16(i)) + } else { + v := execgen.GET(col, i) + var r _GOTYPE + _ASSIGN_CAST(r, v) + execgen.SET(projCol, int(i), r) + } + } + } + } else { + if sel := batch.Selection(); sel != nil { + sel = sel[:n] + for _, i := range sel { + v := execgen.GET(col, int(i)) + var r _GOTYPE + _ASSIGN_CAST(r, v) + execgen.SET(projCol, int(i), r) + } + } else { + for execgen.RANGE(i, col) { + v := execgen.GET(col, i) + var r _GOTYPE + _ASSIGN_CAST(r, v) + execgen.SET(projCol, int(i), r) + } + } + } + return batch +} + +// {{end}} +// {{end}} +// {{end}} diff --git a/pkg/sql/exec/execgen/cmd/execgen/cast_gen.go b/pkg/sql/exec/execgen/cmd/execgen/cast_gen.go new file mode 100644 index 000000000000..db27fd089dd8 --- /dev/null +++ b/pkg/sql/exec/execgen/cmd/execgen/cast_gen.go @@ -0,0 +1,52 @@ +// Copyright 2019 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package main + +import ( + "io" + "io/ioutil" + "strings" + "text/template" +) + +func genCastOperators(wr io.Writer) error { + t, err := ioutil.ReadFile("pkg/sql/exec/cast_tmpl.go") + if err != nil { + return err + } + + s := string(t) + + assignCast := makeFunctionRegex("_ASSIGN_CAST", 2) + s = assignCast.ReplaceAllString(s, `{{.Assign "$1" "$2"}}`) + s = strings.Replace(s, "_ALLTYPES", "{{$typ}}", -1) + s = strings.Replace(s, "_OVERLOADTYPES", "{{.ToTyp}}", -1) + s = strings.Replace(s, "_FROMTYPE", "{{.FromTyp}}", -1) + s = strings.Replace(s, "_TOTYPE", "{{.ToTyp}}", -1) + s = strings.Replace(s, "_GOTYPE", "{{.ToGoTyp}}", -1) + + s = replaceManipulationFuncs(".FromTyp", s) + + isCastFuncSet := func(ov castOverload) bool { + return ov.AssignFunc != nil + } + + tmpl, err := template.New("cast").Funcs(template.FuncMap{"isCastFuncSet": isCastFuncSet}).Parse(s) + if err != nil { + return err + } + + return tmpl.Execute(wr, castOverloads) +} + +func init() { + registerGenerator(genCastOperators, "cast.eg.go") +} diff --git a/pkg/sql/exec/execgen/cmd/execgen/overloads.go b/pkg/sql/exec/execgen/cmd/execgen/overloads.go index 484176b26026..ef748bdb6b05 100644 --- a/pkg/sql/exec/execgen/cmd/execgen/overloads.go +++ b/pkg/sql/exec/execgen/cmd/execgen/overloads.go @@ -149,6 +149,72 @@ func (o overload) UnaryAssign(target, v string) string { return fmt.Sprintf("%s = %s(%s)", target, o.OpStr, v) } +type castOverload struct { + FromTyp coltypes.T + ToTyp coltypes.T + ToGoTyp string + AssignFunc castAssignFunc +} + +func (o castOverload) Assign(to, from string) string { + return o.AssignFunc(to, from) +} + +type castAssignFunc func(to, from string) string + +func castIdentity(to, from string) string { + return fmt.Sprintf("%s = %s", to, from) +} + +func intToDecimal(to, from string) string { + convStr := ` + %[1]s = *apd.New(int64(%[2]s), 0) + ` + return fmt.Sprintf(convStr, to, from) +} + +func intToFloat(floatSize int) func(string, string) string { + return func(to, from string) string { + convStr := ` + %[1]s = float%[3]d(%[2]s) + ` + return fmt.Sprintf(convStr, to, from, floatSize) + } +} + +func floatToInt(intSize int, floatSize int) func(string, string) string { + return func(to, from string) string { + convStr := ` + if math.IsNaN(float64(%[2]s)) || %[2]s <= float%[4]d(math.MinInt%[3]d) || %[2]s >= float%[4]d(math.MaxInt%[3]d) { + execerror.NonVectorizedPanic(tree.ErrIntOutOfRange) + } + %[1]s = int%[3]d(%[2]s) + ` + return fmt.Sprintf(convStr, to, from, intSize, floatSize) + } +} + +func numToBool(to, from string) string { + convStr := ` + %[1]s = %[2]s != 0 + ` + return fmt.Sprintf(convStr, to, from) +} + +func floatToDecimal(to, from string) string { + convStr := ` + var tmpDec apd.Decimal + _, tmpErr := tmpDec.SetFloat64(float64(%[2]s)) + if tmpErr != nil { + execerror.NonVectorizedPanic(tmpErr) + } + %[1]s = tmpDec + ` + return fmt.Sprintf(convStr, to, from) +} + +var castOverloads map[coltypes.T][]castOverload + func init() { registerTypeCustomizers() @@ -248,6 +314,171 @@ func init() { } } } + + // Build cast overloads. We omit cases of type casts that we do not support. + castOverloads = make(map[coltypes.T][]castOverload) + for _, from := range inputTypes { + switch from { + case coltypes.Bool: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = castIdentity + case coltypes.Int8, coltypes.Int16, coltypes.Int32, + coltypes.Int64, coltypes.Float32, coltypes.Float64: + ov.AssignFunc = func(to, from string) string { + convStr := ` + %[1]s = 0 + if %[2]s { + %[1]s = 1 + } + ` + return fmt.Sprintf(convStr, to, from) + } + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Bytes: + // TODO (rohany): It's unclear what to do here in the bytes case. + // There are different conversion rules for the multiple things + // that a bytes type can implemented, but we don't know each of the + // things is contained here. Additionally, we don't really know + // what to do even if it is a bytes to bytes operation here. + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Decimal: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = func(to, from string) string { + convStr := ` + %[1]s = %[2]s.Sign() != 0 + ` + return fmt.Sprintf(convStr, to, from) + } + case coltypes.Decimal: + ov.AssignFunc = castIdentity + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Int8: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = intToDecimal + case coltypes.Int8: + ov.AssignFunc = castIdentity + case coltypes.Float32: + ov.AssignFunc = intToFloat(32) + case coltypes.Float64: + ov.AssignFunc = intToFloat(64) + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Int16: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = intToDecimal + case coltypes.Int16: + ov.AssignFunc = castIdentity + case coltypes.Float32: + ov.AssignFunc = intToFloat(32) + case coltypes.Float64: + ov.AssignFunc = intToFloat(64) + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Int32: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = intToDecimal + case coltypes.Int32: + ov.AssignFunc = castIdentity + case coltypes.Float32: + ov.AssignFunc = intToFloat(32) + case coltypes.Float64: + ov.AssignFunc = intToFloat(64) + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Int64: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = intToDecimal + case coltypes.Int64: + ov.AssignFunc = castIdentity + case coltypes.Float32: + ov.AssignFunc = intToFloat(32) + case coltypes.Float64: + ov.AssignFunc = intToFloat(64) + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Float32: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = floatToDecimal + case coltypes.Int8: + ov.AssignFunc = floatToInt(8, 32) + case coltypes.Int16: + ov.AssignFunc = floatToInt(16, 32) + case coltypes.Int32: + ov.AssignFunc = floatToInt(32, 32) + case coltypes.Int64: + ov.AssignFunc = floatToInt(64, 32) + case coltypes.Float32: + ov.AssignFunc = castIdentity + } + castOverloads[from] = append(castOverloads[from], ov) + } + case coltypes.Float64: + for _, to := range inputTypes { + ov := castOverload{FromTyp: from, ToTyp: to, ToGoTyp: to.GoTypeName()} + switch to { + case coltypes.Bool: + ov.AssignFunc = numToBool + case coltypes.Decimal: + ov.AssignFunc = floatToDecimal + case coltypes.Int8: + ov.AssignFunc = floatToInt(8, 64) + case coltypes.Int16: + ov.AssignFunc = floatToInt(16, 64) + case coltypes.Int32: + ov.AssignFunc = floatToInt(32, 64) + case coltypes.Int64: + ov.AssignFunc = floatToInt(64, 64) + case coltypes.Float64: + ov.AssignFunc = castIdentity + } + castOverloads[from] = append(castOverloads[from], ov) + } + } + } } // typeCustomizer is a marker interface for something that implements one or @@ -760,6 +991,7 @@ func registerTypeCustomizers() { var _ = overload{}.Assign var _ = overload{}.Compare var _ = overload{}.UnaryAssign +var _ = castOverload{}.Assign // buildDict is a template function that builds a dictionary out of its // arguments. The argument to this function should be an alternating sequence of diff --git a/pkg/sql/logictest/testdata/logic_test/vectorize b/pkg/sql/logictest/testdata/logic_test/vectorize index c74099ffba1e..37bd3620166c 100644 --- a/pkg/sql/logictest/testdata/logic_test/vectorize +++ b/pkg/sql/logictest/testdata/logic_test/vectorize @@ -682,3 +682,15 @@ CREATE TABLE t_40227 AS SELECT g FROM generate_series(0, 5) AS g statement ok SELECT '' FROM t_40227 AS t1 JOIN t_40227 AS t2 ON true + +# Tests for #39417 +statement ok +CREATE TABLE t39417 (x int8) + +statement ok +INSERT INTO t39417 VALUES (10) + +query R +select (x/1) from t39417 +---- +10 diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index 08ad4485afa1..398e8b8eb615 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -733,6 +733,16 @@ func (d *DInt) Size() uintptr { // DFloat is the float Datum. type DFloat float64 +// MustBeDFloat attempts to retrieve a DFloat from an Expr, panicking if the +// assertion fails. +func MustBeDFloat(e Expr) DFloat { + switch t := e.(type) { + case *DFloat: + return *t + } + panic(errors.AssertionFailedf("expected *DFloat, found %T", e)) +} + // NewDFloat is a helper routine to create a *DFloat initialized from its // argument. func NewDFloat(d DFloat) *DFloat { @@ -886,6 +896,16 @@ type DDecimal struct { apd.Decimal } +// MustBeDDecimal attempts to retrieve a DDecimal from an Expr, panicking if the +// assertion fails. +func MustBeDDecimal(e Expr) DDecimal { + switch t := e.(type) { + case *DDecimal: + return *t + } + panic(errors.AssertionFailedf("expected *DDecimal, found %T", e)) +} + // ParseDDecimal parses and returns the *DDecimal Datum value represented by the // provided string, or an error if parsing is unsuccessful. func ParseDDecimal(s string) (*DDecimal, error) {