From 95030f212501449a18cd55d7aef2c27f9d2427a1 Mon Sep 17 00:00:00 2001 From: Andrew Kimball Date: Fri, 22 Mar 2019 11:02:10 -0700 Subject: [PATCH] opt: Round decimal values before check constraints The PG spec requires the following DECIMAL type behavior: If the scale of a value to be stored is greater than the declared scale of the column, the system will round the value to the specified number of fractional digits. This is currently happening at the very end of the mutation operator, after any check constraints. This commit moves the rounding to occur before check constraints. Rounding must be performed on inserted and updated values before computed columns are evaluated as well, since computed columns should run on the final values to be inserted/updated. Fixes #35364 Release note (sql change): Computed columns are now evaluated after rounding any decimal values in input columns. Previously, computed columns used input values before rounding. --- docs/generated/sql/functions.md | 4 + pkg/sql/logictest/testdata/logic_test/insert | 19 ++ pkg/sql/logictest/testdata/logic_test/update | 19 ++ pkg/sql/logictest/testdata/logic_test/upsert | 66 +++++ pkg/sql/opt/cat/column.go | 20 ++ pkg/sql/opt/exec/execbuilder/testdata/insert | 18 ++ pkg/sql/opt/exec/execbuilder/testdata/update | 23 ++ pkg/sql/opt/exec/execbuilder/testdata/upsert | 31 ++ pkg/sql/opt/memo/testdata/logprops/upsert | 12 +- pkg/sql/opt/memo/testdata/stats/upsert | 4 +- pkg/sql/opt/norm/testdata/rules/prune_cols | 6 +- pkg/sql/opt/optbuilder/insert.go | 50 ++-- pkg/sql/opt/optbuilder/mutation_builder.go | 102 +++++++ pkg/sql/opt/optbuilder/testdata/insert | 65 ++++- pkg/sql/opt/optbuilder/testdata/update | 64 ++++ pkg/sql/opt/optbuilder/testdata/upsert | 274 +++++++++++++++++- pkg/sql/opt/optbuilder/update.go | 14 +- pkg/sql/opt/testutils/testcat/create_table.go | 11 + pkg/sql/opt/testutils/testcat/test_catalog.go | 19 +- pkg/sql/sem/builtins/builtins.go | 76 ++++- pkg/sql/sem/tree/testdata/eval/func | 94 ++++++ pkg/sql/sqlbase/structured.go | 10 + 22 files changed, 958 insertions(+), 43 deletions(-) diff --git a/docs/generated/sql/functions.md b/docs/generated/sql/functions.md index 35451a6df5c2..e69464cdbb33 100644 --- a/docs/generated/sql/functions.md +++ b/docs/generated/sql/functions.md @@ -958,6 +958,10 @@ may increase either contention or retry errors, or both.

crdb_internal.pretty_key(raw_key: bytes, skip_fields: int) → string

This function is used only by CockroachDB’s developers for testing purposes.

+crdb_internal.round_decimal_values(val: decimal, scale: int) → decimal

This function is used internally to round decimal values during mutations.

+
+crdb_internal.round_decimal_values(val: decimal[], scale: int) → decimal[]

This function is used internally to round decimal array values during mutations.

+
crdb_internal.set_vmodule(vmodule_string: string) → int

Set the equivalent of the --vmodule flag on the gateway node processing this request; it affords control over the logging verbosity of different files. Example syntax: crdb_internal.set_vmodule('recordio=2,file=1,gfs*=3'). Reset with: crdb_internal.set_vmodule(''). Raising the verbosity can severely affect performance.

current_database() → string

Returns the current database.

diff --git a/pkg/sql/logictest/testdata/logic_test/insert b/pkg/sql/logictest/testdata/logic_test/insert index af262116c4c0..b225fd4a157e 100644 --- a/pkg/sql/logictest/testdata/logic_test/insert +++ b/pkg/sql/logictest/testdata/logic_test/insert @@ -686,3 +686,22 @@ INSERT INTO t35611 (a) VALUES (1) statement ok COMMIT + +# ------------------------------------------------------------------------------ +# Regression for #35364. +# ------------------------------------------------------------------------------ +subtest regression_35364 + +statement ok +CREATE TABLE t35364(x DECIMAL(1,0) CHECK (x = 0)) + +statement ok +INSERT INTO t35364(x) VALUES (0.1) + +query T +SELECT x FROM t35364 +---- +0 + +statement ok +DROP TABLE t35364 diff --git a/pkg/sql/logictest/testdata/logic_test/update b/pkg/sql/logictest/testdata/logic_test/update index ada5aba89fc3..691f58e7440e 100644 --- a/pkg/sql/logictest/testdata/logic_test/update +++ b/pkg/sql/logictest/testdata/logic_test/update @@ -527,3 +527,22 @@ query II SELECT * FROM t32054 ---- NULL NULL + +# ------------------------------------------------------------------------------ +# Regression for #35364. +# ------------------------------------------------------------------------------ +subtest regression_35364 + +statement ok +CREATE TABLE t35364(x DECIMAL(1,0) CHECK (x >= 1)) + +statement ok +INSERT INTO t35364 VALUES (1) + +statement ok +UPDATE t35364 SET x=0.5 + +query T +SELECT x FROM t35364 +---- +1 diff --git a/pkg/sql/logictest/testdata/logic_test/upsert b/pkg/sql/logictest/testdata/logic_test/upsert index 9ccc04a99347..67df4603413a 100644 --- a/pkg/sql/logictest/testdata/logic_test/upsert +++ b/pkg/sql/logictest/testdata/logic_test/upsert @@ -868,3 +868,69 @@ INSERT INTO test35040(a,b) VALUES (0,1) ON CONFLICT(a) DO UPDATE SET c = 1111111 statement ok DROP TABLE test35040 + +# ------------------------------------------------------------------------------ +# Regression for #35364. +# ------------------------------------------------------------------------------ +subtest regression_35364 + +statement ok +CREATE TABLE t35364(x INT PRIMARY KEY, y DECIMAL(10,1) CHECK(y >= 8.0), UNIQUE INDEX (y)) + +statement ok +INSERT INTO t35364(x, y) VALUES (1, 10.2) + +# 10.18 should be mapped to 10.2 before the left outer join so that the conflict +# can be detected, and 7.95 should be mapped to 8.0 so that check constraint +# will pass. +statement ok +INSERT INTO t35364(x, y) VALUES (2, 10.18) ON CONFLICT (y) DO UPDATE SET y=7.95 + +query IT +SELECT * FROM t35364 +---- +1 8.0 + +statement ok +DROP TABLE t35364 + +# Check UPSERT syntax. +statement ok +CREATE TABLE t35364( + x DECIMAL(10,0) CHECK (x >= 0) PRIMARY KEY, + y DECIMAL(10,0) CHECK (y >= 0) +) + +statement ok +UPSERT INTO t35364 (x) VALUES (-0.1) + +query TT +SELECT * FROM t35364 +---- +-0 NULL + +statement ok +UPSERT INTO t35364 (x, y) VALUES (-0.2, -0.3) + +query TT +SELECT * FROM t35364 +---- +-0 -0 + +statement ok +UPSERT INTO t35364 (x, y) VALUES (1.5, 2.5) + +query TT rowsort +SELECT * FROM t35364 +---- +-0 -0 +2 3 + +statement ok +INSERT INTO t35364 (x) VALUES (1.5) ON CONFLICT (x) DO UPDATE SET x=2.5, y=3.5 + +query TT rowsort +SELECT * FROM t35364 +---- +-0 -0 +3 4 diff --git a/pkg/sql/opt/cat/column.go b/pkg/sql/opt/cat/column.go index db00bfab962f..e4de3c2337a8 100644 --- a/pkg/sql/opt/cat/column.go +++ b/pkg/sql/opt/cat/column.go @@ -35,6 +35,26 @@ type Column interface { // DatumType returns the data type of the column. DatumType() types.T + // ColTypePrecision returns the precision of the column's SQL data type. This + // is only defined for the Decimal data type and represents the max number of + // decimal digits in the decimal (including fractional digits). If precision + // is 0, then the decimal has no max precision. + ColTypePrecision() int + + // ColTypeWidth returns the width of the column's SQL data type. This has + // different meanings depending on the data type: + // + // Decimal : scale + // Int : # bits (16, 32, 64, etc) + // Bit Array: # bits + // String : rune count + // + // TODO(andyk): It'd be better to expose the attributes of the column type + // using a different type or interface. However, currently that's hard to do, + // since using sqlbase.ColumnType creates an import cycle, and there's no good + // way to create a coltypes.T from sqlbase.ColumnType. + ColTypeWidth() int + // ColTypeStr returns the SQL data type of the column, as a string. Note that // this is sometimes different than DatumType().String(), since datum types // are a subset of column types. diff --git a/pkg/sql/opt/exec/execbuilder/testdata/insert b/pkg/sql/opt/exec/execbuilder/testdata/insert index f455d2e2cc95..16539a03e2b1 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/insert +++ b/pkg/sql/opt/exec/execbuilder/testdata/insert @@ -494,3 +494,21 @@ render · · (z) └── scan · · (a, b, c) +c · table abc@abc_c_idx · · · spans ALL · · + +# ------------------------------------------------------------------------------ +# Regression for #35364. This tests behavior that is different between the CBO +# and the HP. The CBO will (deliberately) round any input columns *before* +# evaluating any computed columns, as well as rounding the output. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TABLE t35364( + x DECIMAL(10,0) CHECK(round(x) = x) PRIMARY KEY, + y DECIMAL(10,0) DEFAULT (1.5), + z DECIMAL(10,0) AS (x+y+2.5) STORED CHECK(z >= 7) +) + +query TTT +INSERT INTO t35364 (x) VALUES (1.5) RETURNING * +---- +2 2 7 diff --git a/pkg/sql/opt/exec/execbuilder/testdata/update b/pkg/sql/opt/exec/execbuilder/testdata/update index 02ee30a15baf..a55117f4709a 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/update +++ b/pkg/sql/opt/exec/execbuilder/testdata/update @@ -322,3 +322,26 @@ render · · (a) + └── scan · · (a, b, c, rowid[hidden]) +c · table abc@abc_c_idx · · · spans ALL · · + +# ------------------------------------------------------------------------------ +# Regression for #35364. This tests behavior that is different between the CBO +# and the HP. The CBO will (deliberately) round any input columns *before* +# evaluating any computed columns, as well as rounding the output. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TABLE t35364( + x DECIMAL(10,0) CHECK(round(x) = x) PRIMARY KEY, + y DECIMAL(10,0) DEFAULT (1.5), + z DECIMAL(10,0) AS (x+y+2.5) STORED CHECK(z >= 7) +) + +query TTT +INSERT INTO t35364 (x) VALUES (1.5) RETURNING * +---- +2 2 7 + +query TTT +UPDATE t35364 SET x=2.5 RETURNING * +---- +3 2 8 diff --git a/pkg/sql/opt/exec/execbuilder/testdata/upsert b/pkg/sql/opt/exec/execbuilder/testdata/upsert index f1ca91140d3c..20725184206f 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/upsert +++ b/pkg/sql/opt/exec/execbuilder/testdata/upsert @@ -342,3 +342,34 @@ render · · (z) └── scan · · (a, b, c) +c · table abc@abc_c_idx · · · spans ALL · · + +# ------------------------------------------------------------------------------ +# Regression for #35364. This tests behavior that is different between the CBO +# and the HP. The CBO will (deliberately) round any input columns *before* +# evaluating any computed columns, as well as rounding the output. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TABLE t35364( + x DECIMAL(10,0) CHECK(round(x) = x) PRIMARY KEY, + y DECIMAL(10,0) DEFAULT (1.5), + z DECIMAL(10,0) AS (x+y+2.5) STORED CHECK(z >= 7) +) + +query TTT +UPSERT INTO t35364 (x) VALUES (1.5) RETURNING * +---- +2 2 7 + +query TTT +UPSERT INTO t35364 (x, y) VALUES (1.5, 2.5) RETURNING * +---- +2 3 8 + +query TTT +INSERT INTO t35364 (x) VALUES (1.5) ON CONFLICT (x) DO UPDATE SET x=2.5 RETURNING * +---- +3 3 9 + +statement error pq: failed to satisfy CHECK constraint \(z >= 7\) +UPSERT INTO t35364 (x) VALUES (0) diff --git a/pkg/sql/opt/memo/testdata/logprops/upsert b/pkg/sql/opt/memo/testdata/logprops/upsert index 7467db754803..4a705cfd0b9b 100644 --- a/pkg/sql/opt/memo/testdata/logprops/upsert +++ b/pkg/sql/opt/memo/testdata/logprops/upsert @@ -78,11 +78,11 @@ project │ └── upsert_rowid:20 => rowid:4 ├── side-effects, mutations └── project - ├── columns: upsert_a:17(int) upsert_b:18(int) upsert_c:19(int) upsert_rowid:20(int) x:5(int!null) y:6(int!null) column8:8(int) column9:9(int) a:10(int) b:11(int) c:12(int) rowid:13(int) + ├── columns: upsert_a:17(int) upsert_b:18(int) upsert_c:19(int) upsert_rowid:20(int) x:5(int!null) y:6(int!null) column8:8(int) column9:9(int) a:10(int) b:11(int) c:12(int) rowid:13(int) column14:14(int!null) column15:15(int) column16:16(int) ├── side-effects ├── key: (5,13) - ├── fd: ()-->(6,9), (5)-->(8), (13)-->(10-12), (10)-->(11-13), (11,12)~~>(10,13), (5,13)-->(17-19), (8,13)-->(20) - ├── prune: (5,6,8-13,17-20) + ├── fd: ()-->(6,9,14), (5)-->(8), (13)-->(10-12), (10)-->(11-13), (11,12)~~>(10,13), (12)-->(15), (15)-->(16), (5,13)-->(17), (13,15)-->(18), (13,16)-->(19), (8,13)-->(20) + ├── prune: (5,6,8-20) ├── interesting orderings: (+5) (+6) (+13) (+10) (+11,+12,+13) ├── project │ ├── columns: column16:16(int) x:5(int!null) y:6(int!null) column8:8(int) column9:9(int) a:10(int) b:11(int) c:12(int) rowid:13(int) column14:14(int!null) column15:15(int) @@ -400,11 +400,11 @@ project │ ├── cardinality: [2 - ] │ ├── side-effects, mutations │ └── project - │ ├── columns: upsert_b:14(int) upsert_c:15(int) upsert_rowid:16(int) column1:5(int) column6:6(int!null) column7:7(int) column8:8(int) a:9(int) b:10(int) c:11(int) rowid:12(int) + │ ├── columns: upsert_b:14(int) upsert_c:15(int) upsert_rowid:16(int) column1:5(int) column6:6(int!null) column7:7(int) column8:8(int) a:9(int) b:10(int) c:11(int) rowid:12(int) column13:13(int) │ ├── cardinality: [2 - ] │ ├── side-effects - │ ├── fd: ()-->(6,8), (12)-->(9-11), (9)-->(10-12), (10,11)~~>(9,12), (10,12)-->(14), (7,12)-->(16) - │ ├── prune: (5-12,14-16) + │ ├── fd: ()-->(6,8), (12)-->(9-11), (9)-->(10-12), (10,11)~~>(9,12), (10)-->(13), (10,12)-->(14), (12,13)-->(15), (7,12)-->(16) + │ ├── prune: (5-16) │ ├── interesting orderings: (+12) (+9) (+10,+11,+12) │ ├── project │ │ ├── columns: column13:13(int) column1:5(int) column6:6(int!null) column7:7(int) column8:8(int) a:9(int) b:10(int) c:11(int) rowid:12(int) diff --git a/pkg/sql/opt/memo/testdata/stats/upsert b/pkg/sql/opt/memo/testdata/stats/upsert index 1933d15e1ccb..6458e275326e 100644 --- a/pkg/sql/opt/memo/testdata/stats/upsert +++ b/pkg/sql/opt/memo/testdata/stats/upsert @@ -79,9 +79,9 @@ select │ ├── side-effects, mutations │ ├── stats: [rows=200, distinct(1)=181.351171, null(1)=0, distinct(2)=200, null(2)=0] │ └── project - │ ├── columns: upsert_x:13(string) upsert_y:14(int) upsert_z:15(float) a:4(int!null) b:5(string!null) column8:8(float) x:9(string) y:10(int) z:11(float) + │ ├── columns: upsert_x:13(string) upsert_y:14(int) upsert_z:15(float) a:4(int!null) b:5(string!null) column8:8(float) x:9(string) y:10(int) z:11(float) column12:12(int!null) │ ├── stats: [rows=200, distinct(13)=181.351171, null(13)=0, distinct(14)=200, null(14)=0] - │ ├── fd: ()-->(5,8), (9)-->(10,11), (9)-->(13), (4,9)-->(14) + │ ├── fd: ()-->(5,8,12), (9)-->(10,11), (9)-->(13), (4,9)-->(14) │ ├── project │ │ ├── columns: column12:12(int!null) a:4(int!null) b:5(string!null) column8:8(float) x:9(string) y:10(int) z:11(float) │ │ ├── stats: [rows=200, distinct(5,9)=181.351171, null(5,9)=0, distinct(4,9,12)=200, null(4,9,12)=0] diff --git a/pkg/sql/opt/norm/testdata/rules/prune_cols b/pkg/sql/opt/norm/testdata/rules/prune_cols index 05fd605d0790..e86623e9ac43 100644 --- a/pkg/sql/opt/norm/testdata/rules/prune_cols +++ b/pkg/sql/opt/norm/testdata/rules/prune_cols @@ -1845,9 +1845,9 @@ upsert a └── projections └── CASE WHEN k IS NULL THEN column7 ELSE i + 1 END [type=int, outer=(7,9,10)] -# No pruning when RETURNING clause is present. -# TODO(andyk): Need to prune output columns. -opt expect-not=(PruneMutationFetchCols,PruneMutationInputCols) +# Prune update columns replaced by upsert columns. +# TODO(andyk): Need to also prune output columns. +opt expect=PruneMutationInputCols expect-not=PruneMutationFetchCols INSERT INTO a (k, s) VALUES (1, 'foo') ON CONFLICT (k) DO UPDATE SET i=a.i+1 RETURNING * ---- upsert a diff --git a/pkg/sql/opt/optbuilder/insert.go b/pkg/sql/opt/optbuilder/insert.go index 6ca5ffbb8ed1..9a01ef8e3bb5 100644 --- a/pkg/sql/opt/optbuilder/insert.go +++ b/pkg/sql/opt/optbuilder/insert.go @@ -223,11 +223,19 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope mb.buildInputForInsert(inScope, nil /* rows */) } - // Add default and computed columns that were not explicitly specified by - // name or implicitly targeted by input columns. This includes any columns - // undergoing write mutations, as they must always have a default or computed - // value. - mb.addDefaultAndComputedColsForInsert() + // Add default columns that were not explicitly specified by name or + // implicitly targeted by input columns. This includes columns undergoing + // write mutations, if they have a default value. + mb.addDefaultColsForInsert() + + // Possibly round DECIMAL-related columns containing insertion values. Do + // this before evaluating computed expressions, since those may depend on + // the inserted columns. + mb.roundDecimalValues(mb.insertOrds, false /* roundComputedCols */) + + // Add any computed columns. This includes columns undergoing write mutations, + // if they have a computed value. + mb.addComputedColsForInsert() var returning tree.ReturningExprs if resultsNeeded(ins.Returning) { @@ -284,10 +292,6 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope // Build each of the SET expressions. mb.addUpdateCols(ins.OnConflict.Exprs) - // Add additional columns for computed expressions that may depend on any - // updated columns. - mb.addComputedColsForUpdate() - // Build the final upsert statement, including any returned expressions. mb.buildUpsert(returning) } @@ -565,27 +569,31 @@ func (mb *mutationBuilder) buildInputForInsert(inScope *scope, inputRows *tree.S } } -// addDefaultAndComputedColsForInsert wraps an Insert input expression with -// Project operator(s) containing any default (or nullable) and computed columns -// that are not yet part of the target column list. This includes mutation -// columns, since they must always have default or computed values. -// -// After this call, the input expression will provide values for every one of -// the target table columns, whether it was explicitly specified or implicitly -// added. -func (mb *mutationBuilder) addDefaultAndComputedColsForInsert() { - // Add any missing default and nullable columns. +// addDefaultColsForInsert wraps an Insert input expression with a Project +// operator containing any default (or nullable) columns that are not yet part +// of the target column list. This includes mutation columns, since they must +// always have default or computed values. +func (mb *mutationBuilder) addDefaultColsForInsert() { mb.addSynthesizedCols( mb.insertOrds, func(tabCol cat.Column) bool { return !tabCol.IsComputed() }, ) +} - // Add any missing computed columns. This must be done after adding default - // columns above, because computed columns can depend on default columns. +// addComputedColsForInsert wraps an Insert input expression with a Project +// operator containing computed columns that are not yet part of the target +// column list. This includes mutation columns, since they must always have +// default or computed values. This must be done after calling +// addDefaultColsForInsert, because computed columns can depend on default +// columns. +func (mb *mutationBuilder) addComputedColsForInsert() { mb.addSynthesizedCols( mb.insertOrds, func(tabCol cat.Column) bool { return tabCol.IsComputed() }, ) + + // Possibly round DECIMAL-related computed columns. + mb.roundDecimalValues(mb.insertOrds, true /* roundComputedCols */) } // buildInsert constructs an Insert operator, possibly wrapped by a Project diff --git a/pkg/sql/opt/optbuilder/mutation_builder.go b/pkg/sql/opt/optbuilder/mutation_builder.go index ed65b21ab8cc..46c6522b2c92 100644 --- a/pkg/sql/opt/optbuilder/mutation_builder.go +++ b/pkg/sql/opt/optbuilder/mutation_builder.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/types" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" @@ -403,6 +404,107 @@ func (mb *mutationBuilder) addSynthesizedCols( } } +// roundDecimalValues wraps each DECIMAL-related column (including arrays of +// decimals) with a call to the crdb_internal.round_decimal_values function, if +// column values may need to be rounded. This is necessary when mutating table +// columns that have a limited scale (e.g. DECIMAL(10, 1)). Here is the PG docs +// description: +// +// http://www.postgresql.org/docs/9.5/static/datatype-numeric.html +// "If the scale of a value to be stored is greater than +// the declared scale of the column, the system will round the +// value to the specified number of fractional digits. Then, +// if the number of digits to the left of the decimal point +// exceeds the declared precision minus the declared scale, an +// error is raised." +// +// Note that this function only handles the rounding portion of that. The +// precision check is done by the execution engine. The rounding cannot be done +// there, since it needs to happen before check constraints are computed, and +// before UPSERT joins. +// +// if roundComputedCols is false, then don't wrap computed columns. If true, +// then only wrap computed columns. This is necessary because computed columns +// can depend on other columns mutated by the operation; it is necessary to +// first round those values, then evaluated the computed expression, and then +// round the result of the computation. +func (mb *mutationBuilder) roundDecimalValues(scopeOrds []scopeOrdinal, roundComputedCols bool) { + var projectionsScope *scope + + for i, ord := range scopeOrds { + if ord == -1 { + // Column not mutated, so nothing to do. + continue + } + + // Include or exclude computed columns, depending on the value of + // roundComputedCols. + col := mb.tab.Column(i) + if col.IsComputed() != roundComputedCols { + continue + } + + // Check whether the target column's type may require rounding of the + // input value. + props, overload := findRoundingFunction(col.DatumType(), col.ColTypePrecision()) + if props == nil { + continue + } + private := &memo.FunctionPrivate{ + Name: "crdb_internal.round_decimal_values", + Typ: mb.outScope.cols[ord].typ, + Properties: props, + Overload: overload, + } + variable := mb.b.factory.ConstructVariable(mb.scopeOrdToColID(ord)) + scale := mb.b.factory.ConstructConstVal(tree.NewDInt(tree.DInt(col.ColTypeWidth())), types.Int) + fn := mb.b.factory.ConstructFunction(memo.ScalarListExpr{variable, scale}, private) + + // Lazily create new scope and update the scope column to be rounded. + if projectionsScope == nil { + projectionsScope = mb.outScope.replace() + projectionsScope.appendColumnsFromScope(mb.outScope) + } + mb.b.populateSynthesizedColumn(&projectionsScope.cols[ord], fn) + } + + if projectionsScope != nil { + mb.b.constructProjectForScope(mb.outScope, projectionsScope) + mb.outScope = projectionsScope + } +} + +// findRoundingFunction returns the builtin function overload needed to round +// input values. This is only necessary for DECIMAL or DECIMAL[] types that have +// limited precision, such as: +// +// DECIMAL(15, 1) +// DECIMAL(10, 3)[] +// +// If an input decimal value has more than the required number of fractional +// digits, it must be rounded before being inserted into these types. +// +// NOTE: CRDB does not allow nested array storage types, so only one level of +// array nesting needs to be checked. +func findRoundingFunction(typ types.T, precision int) (*tree.FunctionProperties, *tree.Overload) { + if precision == 0 { + // Unlimited precision decimal target type never needs rounding. + return nil, nil + } + + props, overloads := builtins.GetBuiltinProperties("crdb_internal.round_decimal_values") + + if typ.Equivalent(types.Decimal) { + return props, &overloads[0] + } + if arr, ok := typ.(types.TArray); ok && arr.Typ.Equivalent(types.Decimal) { + return props, &overloads[1] + } + + // Not DECIMAL or DECIMAL[]. + return nil, nil +} + // addCheckConstraintCols synthesizes a boolean output column for each check // constraint defined on the target table. The mutation operator will report // a constraint violation error if the value of the column is false. diff --git a/pkg/sql/opt/optbuilder/testdata/insert b/pkg/sql/opt/optbuilder/testdata/insert index 27f49ef7e594..028141e5b31c 100644 --- a/pkg/sql/opt/optbuilder/testdata/insert +++ b/pkg/sql/opt/optbuilder/testdata/insert @@ -83,6 +83,24 @@ TABLE checks ├── CHECK (checks.b < d) └── CHECK (a > 0) +exec-ddl +CREATE TABLE decimals ( + a DECIMAL(10,0) PRIMARY KEY CHECK (round(a) = a), + b DECIMAL(5,1)[] CHECK (b[0] > 1), + c DECIMAL(10,1) DEFAULT (1.23), + d DECIMAL(10,1) AS (a+c) STORED +) +---- +TABLE decimals + ├── a decimal not null + ├── b decimal[] + ├── c decimal + ├── d decimal + ├── INDEX primary + │ └── a decimal not null + ├── CHECK (round(a) = a) + └── CHECK (b[0] > 1) + # Unknown target table. build INSERT INTO unknown VALUES (1, 2, 3) @@ -1264,7 +1282,7 @@ INSERT INTO mutation (m, n, p) VALUES (1, 2, 3) error (42703): column "p" does not exist # ------------------------------------------------------------------------------ -# Test check constraints +# Test check constraints. # ------------------------------------------------------------------------------ # Insert constants. @@ -1332,3 +1350,48 @@ insert checks └── gt [type=bool] ├── variable: abcde.a [type=int] └── const: 0 [type=int] + +# ------------------------------------------------------------------------------ +# Test decimal column rounding. +# ------------------------------------------------------------------------------ + +opt +INSERT INTO decimals (a, b) VALUES (1.1, ARRAY[0.95, NULL, 15]) +---- +insert decimals + ├── columns: + ├── insert-mapping: + │ ├── a:8 => decimals.a:1 + │ ├── b:9 => decimals.b:2 + │ ├── c:10 => decimals.c:3 + │ └── d:12 => decimals.d:4 + ├── check columns: check1:13(bool) check2:14(bool) + └── project + ├── columns: check1:13(bool) check2:14(bool) d:12(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) + ├── values + │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) + │ └── tuple [type=tuple{decimal, decimal[], decimal}] + │ ├── function: crdb_internal.round_decimal_values [type=decimal] + │ │ ├── const: 1.1 [type=decimal] + │ │ └── const: 0 [type=int] + │ ├── function: crdb_internal.round_decimal_values [type=decimal[]] + │ │ ├── const: ARRAY[0.95,NULL,15] [type=decimal[]] + │ │ └── const: 1 [type=int] + │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ ├── const: 1.23 [type=decimal] + │ └── const: 1 [type=int] + └── projections + ├── eq [type=bool] + │ ├── variable: a [type=decimal] + │ └── function: round [type=decimal] + │ └── variable: a [type=decimal] + ├── gt [type=bool] + │ ├── indirection [type=decimal] + │ │ ├── variable: b [type=decimal[]] + │ │ └── const: 0 [type=int] + │ └── const: 1 [type=decimal] + └── function: crdb_internal.round_decimal_values [type=decimal] + ├── plus [type=decimal] + │ ├── variable: a [type=decimal] + │ └── variable: c [type=decimal] + └── const: 1 [type=int] diff --git a/pkg/sql/opt/optbuilder/testdata/update b/pkg/sql/opt/optbuilder/testdata/update index 9ebd0db91001..2ed115b32e92 100644 --- a/pkg/sql/opt/optbuilder/testdata/update +++ b/pkg/sql/opt/optbuilder/testdata/update @@ -83,6 +83,24 @@ TABLE checks ├── CHECK (checks.b < d) └── CHECK (a > 0) +exec-ddl +CREATE TABLE decimals ( + a DECIMAL(10,0) PRIMARY KEY CHECK (round(a) = a), + b DECIMAL(5,1)[] CHECK (b[0] > 1), + c DECIMAL(10,1) DEFAULT (1.23), + d DECIMAL(10,1) AS (a+c) STORED +) +---- +TABLE decimals + ├── a decimal not null + ├── b decimal[] + ├── c decimal + ├── d decimal + ├── INDEX primary + │ └── a decimal not null + ├── CHECK (round(a) = a) + └── CHECK (b[0] > 1) + # ------------------------------------------------------------------------------ # Basic tests. # ------------------------------------------------------------------------------ @@ -1538,3 +1556,49 @@ update checks └── gt [type=bool] ├── variable: abcde.a [type=int] └── const: 0 [type=int] + +# ------------------------------------------------------------------------------ +# Test decimal column truncation. +# ------------------------------------------------------------------------------ + +opt +UPDATE decimals SET a=1.1, b=ARRAY[0.95, NULL, 15] +---- +update decimals + ├── columns: + ├── fetch columns: decimals.a:5(decimal) decimals.b:6(decimal[]) c:7(decimal) decimals.d:8(decimal) + ├── update-mapping: + │ ├── a:11 => decimals.a:1 + │ ├── b:12 => decimals.b:2 + │ └── d:15 => decimals.d:4 + ├── check columns: check1:16(bool) check2:17(bool) + └── project + ├── columns: check1:16(bool) check2:17(bool) d:15(decimal) decimals.a:5(decimal!null) decimals.b:6(decimal[]) c:7(decimal) decimals.d:8(decimal) a:11(decimal) b:12(decimal[]) + ├── project + │ ├── columns: a:11(decimal) b:12(decimal[]) decimals.a:5(decimal!null) decimals.b:6(decimal[]) c:7(decimal) decimals.d:8(decimal) + │ ├── scan decimals + │ │ └── columns: decimals.a:5(decimal!null) decimals.b:6(decimal[]) c:7(decimal) decimals.d:8(decimal) + │ └── projections + │ ├── function: crdb_internal.round_decimal_values [type=decimal] + │ │ ├── const: 1.1 [type=decimal] + │ │ └── const: 0 [type=int] + │ └── function: crdb_internal.round_decimal_values [type=decimal[]] + │ ├── const: ARRAY[0.95,NULL,15] [type=decimal[]] + │ └── const: 1 [type=int] + └── projections + ├── eq [type=bool] + │ ├── variable: a [type=decimal] + │ └── function: round [type=decimal] + │ └── variable: a [type=decimal] + ├── gt [type=bool] + │ ├── indirection [type=decimal] + │ │ ├── variable: b [type=decimal[]] + │ │ └── const: 0 [type=int] + │ └── const: 1 [type=decimal] + └── function: crdb_internal.round_decimal_values [type=decimal] + ├── function: crdb_internal.round_decimal_values [type=decimal] + │ ├── plus [type=decimal] + │ │ ├── variable: a [type=decimal] + │ │ └── variable: c [type=decimal] + │ └── const: 1 [type=int] + └── const: 1 [type=int] diff --git a/pkg/sql/opt/optbuilder/testdata/upsert b/pkg/sql/opt/optbuilder/testdata/upsert index 19146db7e81c..284fd25e9f39 100644 --- a/pkg/sql/opt/optbuilder/testdata/upsert +++ b/pkg/sql/opt/optbuilder/testdata/upsert @@ -117,6 +117,24 @@ TABLE checks ├── CHECK (checks.b < d) └── CHECK (a > 0) +exec-ddl +CREATE TABLE decimals ( + a DECIMAL(10,0) PRIMARY KEY CHECK (round(a) = a), + b DECIMAL(5,1)[] CHECK (b[0] > 1), + c DECIMAL(10,1) DEFAULT (1.23), + d DECIMAL(10,1) AS (a+c) STORED +) +---- +TABLE decimals + ├── a decimal not null + ├── b decimal[] + ├── c decimal + ├── d decimal + ├── INDEX primary + │ └── a decimal not null + ├── CHECK (round(a) = a) + └── CHECK (b[0] > 1) + # ------------------------------------------------------------------------------ # Basic tests. # ------------------------------------------------------------------------------ @@ -1580,7 +1598,7 @@ UPSERT INTO xyz (x, unknown) VALUES (1) error (42703): column "unknown" does not exist # ------------------------------------------------------------------------------ -# Test check constraints +# Test check constraints. # ------------------------------------------------------------------------------ # INSERT..ON CONFLICT @@ -1920,3 +1938,257 @@ upsert checks └── gt [type=bool] ├── variable: abc.a [type=int] └── const: 0 [type=int] + +# ------------------------------------------------------------------------------ +# Test decimal column truncation. +# ------------------------------------------------------------------------------ + +# Fast UPSERT case. +opt +UPSERT INTO decimals (a, b) VALUES (1.1, ARRAY[0.95]) +---- +upsert decimals + ├── columns: + ├── canary column: 13 + ├── fetch columns: decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + ├── insert-mapping: + │ ├── a:8 => decimals.a:1 + │ ├── b:9 => decimals.b:2 + │ ├── c:10 => decimals.c:3 + │ └── d:12 => decimals.d:4 + ├── update-mapping: + │ ├── b:9 => decimals.b:2 + │ └── upsert_d:21 => decimals.d:4 + ├── check columns: check1:22(bool) check2:23(bool) + └── project + ├── columns: check1:22(bool) check2:23(bool) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) upsert_d:21(decimal) + ├── project + │ ├── columns: upsert_a:19(decimal) upsert_d:21(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ ├── left-join (lookup decimals) + │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ │ ├── key columns: [8] = [13] + │ │ ├── project + │ │ │ ├── columns: d:12(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ ├── values + │ │ │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ │ └── tuple [type=tuple{decimal, decimal[], decimal}] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ │ ├── const: 1.1 [type=decimal] + │ │ │ │ │ └── const: 0 [type=int] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal[]] + │ │ │ │ │ ├── const: ARRAY[0.95] [type=decimal[]] + │ │ │ │ │ └── const: 1 [type=int] + │ │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ ├── const: 1.23 [type=decimal] + │ │ │ │ └── const: 1 [type=int] + │ │ │ └── projections + │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ ├── plus [type=decimal] + │ │ │ │ ├── variable: a [type=decimal] + │ │ │ │ └── variable: c [type=decimal] + │ │ │ └── const: 1 [type=int] + │ │ └── filters (true) + │ └── projections + │ ├── case [type=decimal] + │ │ ├── true [type=bool] + │ │ ├── when [type=decimal] + │ │ │ ├── is [type=bool] + │ │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ │ └── null [type=unknown] + │ │ │ └── variable: a [type=decimal] + │ │ └── variable: decimals.a [type=decimal] + │ └── case [type=decimal] + │ ├── true [type=bool] + │ ├── when [type=decimal] + │ │ ├── is [type=bool] + │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ └── null [type=unknown] + │ │ └── variable: d [type=decimal] + │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ ├── plus [type=decimal] + │ │ ├── variable: decimals.a [type=decimal] + │ │ └── variable: decimals.c [type=decimal] + │ └── const: 1 [type=int] + └── projections + ├── eq [type=bool] + │ ├── variable: upsert_a [type=decimal] + │ └── function: round [type=decimal] + │ └── variable: upsert_a [type=decimal] + └── gt [type=bool] + ├── indirection [type=decimal] + │ ├── variable: b [type=decimal[]] + │ └── const: 0 [type=int] + └── const: 1 [type=decimal] + +# Regular UPSERT case. +opt +UPSERT INTO decimals (a) VALUES (1.1) +---- +upsert decimals + ├── columns: + ├── canary column: 13 + ├── fetch columns: decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + ├── insert-mapping: + │ ├── a:8 => decimals.a:1 + │ ├── b:9 => decimals.b:2 + │ ├── c:10 => decimals.c:3 + │ └── d:12 => decimals.d:4 + ├── update-mapping: + │ └── upsert_d:22 => decimals.d:4 + ├── check columns: check1:23(bool) check2:24(bool) + └── project + ├── columns: check1:23(bool) check2:24(bool) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) upsert_d:22(decimal) + ├── project + │ ├── columns: upsert_a:19(decimal) upsert_b:20(decimal[]) upsert_d:22(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ ├── left-join (lookup decimals) + │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ │ ├── key columns: [8] = [13] + │ │ ├── project + │ │ │ ├── columns: d:12(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ ├── values + │ │ │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ │ └── tuple [type=tuple{decimal, decimal[], decimal}] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ │ ├── const: 1.1 [type=decimal] + │ │ │ │ │ └── const: 0 [type=int] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal[]] + │ │ │ │ │ ├── null [type=decimal[]] + │ │ │ │ │ └── const: 1 [type=int] + │ │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ ├── const: 1.23 [type=decimal] + │ │ │ │ └── const: 1 [type=int] + │ │ │ └── projections + │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ ├── plus [type=decimal] + │ │ │ │ ├── variable: a [type=decimal] + │ │ │ │ └── variable: c [type=decimal] + │ │ │ └── const: 1 [type=int] + │ │ └── filters (true) + │ └── projections + │ ├── case [type=decimal] + │ │ ├── true [type=bool] + │ │ ├── when [type=decimal] + │ │ │ ├── is [type=bool] + │ │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ │ └── null [type=unknown] + │ │ │ └── variable: a [type=decimal] + │ │ └── variable: decimals.a [type=decimal] + │ ├── case [type=decimal[]] + │ │ ├── true [type=bool] + │ │ ├── when [type=decimal[]] + │ │ │ ├── is [type=bool] + │ │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ │ └── null [type=unknown] + │ │ │ └── variable: b [type=decimal[]] + │ │ └── variable: decimals.b [type=decimal[]] + │ └── case [type=decimal] + │ ├── true [type=bool] + │ ├── when [type=decimal] + │ │ ├── is [type=bool] + │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ └── null [type=unknown] + │ │ └── variable: d [type=decimal] + │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ ├── plus [type=decimal] + │ │ ├── variable: decimals.a [type=decimal] + │ │ └── variable: decimals.c [type=decimal] + │ └── const: 1 [type=int] + └── projections + ├── eq [type=bool] + │ ├── variable: upsert_a [type=decimal] + │ └── function: round [type=decimal] + │ └── variable: upsert_a [type=decimal] + └── gt [type=bool] + ├── indirection [type=decimal] + │ ├── variable: upsert_b [type=decimal[]] + │ └── const: 0 [type=int] + └── const: 1 [type=decimal] + +# INSERT...ON CONFLICT case. +opt +INSERT INTO decimals (a, b) VALUES (1.1, ARRAY[0.95]) +ON CONFLICT (a) +DO UPDATE SET b=ARRAY[0.99] +---- +upsert decimals + ├── columns: + ├── canary column: 13 + ├── fetch columns: decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + ├── insert-mapping: + │ ├── a:8 => decimals.a:1 + │ ├── b:9 => decimals.b:2 + │ ├── c:10 => decimals.c:3 + │ └── d:12 => decimals.d:4 + ├── update-mapping: + │ ├── upsert_b:22 => decimals.b:2 + │ └── upsert_d:24 => decimals.d:4 + ├── check columns: check1:25(bool) check2:26(bool) + └── project + ├── columns: check1:25(bool) check2:26(bool) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) upsert_b:22(decimal[]) upsert_d:24(decimal) + ├── project + │ ├── columns: upsert_a:21(decimal) upsert_b:22(decimal[]) upsert_d:24(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ ├── left-join (lookup decimals) + │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) d:12(decimal) decimals.a:13(decimal) decimals.b:14(decimal[]) decimals.c:15(decimal) decimals.d:16(decimal) + │ │ ├── key columns: [8] = [13] + │ │ ├── project + │ │ │ ├── columns: d:12(decimal) a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ ├── values + │ │ │ │ ├── columns: a:8(decimal) b:9(decimal[]) c:10(decimal) + │ │ │ │ └── tuple [type=tuple{decimal, decimal[], decimal}] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ │ ├── const: 1.1 [type=decimal] + │ │ │ │ │ └── const: 0 [type=int] + │ │ │ │ ├── function: crdb_internal.round_decimal_values [type=decimal[]] + │ │ │ │ │ ├── const: ARRAY[0.95] [type=decimal[]] + │ │ │ │ │ └── const: 1 [type=int] + │ │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ │ ├── const: 1.23 [type=decimal] + │ │ │ │ └── const: 1 [type=int] + │ │ │ └── projections + │ │ │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ │ │ ├── plus [type=decimal] + │ │ │ │ ├── variable: a [type=decimal] + │ │ │ │ └── variable: c [type=decimal] + │ │ │ └── const: 1 [type=int] + │ │ └── filters (true) + │ └── projections + │ ├── case [type=decimal] + │ │ ├── true [type=bool] + │ │ ├── when [type=decimal] + │ │ │ ├── is [type=bool] + │ │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ │ └── null [type=unknown] + │ │ │ └── variable: a [type=decimal] + │ │ └── variable: decimals.a [type=decimal] + │ ├── case [type=decimal[]] + │ │ ├── true [type=bool] + │ │ ├── when [type=decimal[]] + │ │ │ ├── is [type=bool] + │ │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ │ └── null [type=unknown] + │ │ │ └── variable: b [type=decimal[]] + │ │ └── function: crdb_internal.round_decimal_values [type=decimal[]] + │ │ ├── const: ARRAY[0.99] [type=decimal[]] + │ │ └── const: 1 [type=int] + │ └── case [type=decimal] + │ ├── true [type=bool] + │ ├── when [type=decimal] + │ │ ├── is [type=bool] + │ │ │ ├── variable: decimals.a [type=decimal] + │ │ │ └── null [type=unknown] + │ │ └── variable: d [type=decimal] + │ └── function: crdb_internal.round_decimal_values [type=decimal] + │ ├── plus [type=decimal] + │ │ ├── variable: decimals.a [type=decimal] + │ │ └── variable: decimals.c [type=decimal] + │ └── const: 1 [type=int] + └── projections + ├── eq [type=bool] + │ ├── variable: upsert_a [type=decimal] + │ └── function: round [type=decimal] + │ └── variable: upsert_a [type=decimal] + └── gt [type=bool] + ├── indirection [type=decimal] + │ ├── variable: upsert_b [type=decimal[]] + │ └── const: 0 [type=int] + └── const: 1 [type=decimal] diff --git a/pkg/sql/opt/optbuilder/update.go b/pkg/sql/opt/optbuilder/update.go index 88a895ea5f16..c924f3b32109 100644 --- a/pkg/sql/opt/optbuilder/update.go +++ b/pkg/sql/opt/optbuilder/update.go @@ -210,7 +210,7 @@ func (mb *mutationBuilder) addUpdateCols(exprs tree.UpdateExprs) { ord := mb.tabID.ColumnOrdinal(targetColID) checkDatumTypeFitsColumnType(mb.tab.Column(ord), sourceCol.typ) - // Add ordinal of new column scope to the list of columns to update. + // Add ordinal of new scope column to the list of columns to update. mb.updateOrds[ord] = scopeOrd // Rename the column to match the target column being updated. @@ -283,6 +283,15 @@ func (mb *mutationBuilder) addUpdateCols(exprs tree.UpdateExprs) { mb.b.constructProjectForScope(mb.outScope, projectionsScope) mb.outScope = projectionsScope + + // Possibly round DECIMAL-related columns that were updated. Do this + // before evaluating computed expressions, since those may depend on the + // inserted columns. + mb.roundDecimalValues(mb.updateOrds, false /* roundComputedCols */) + + // Add additional columns for computed expressions that may depend on any + // updated columns. + mb.addComputedColsForUpdate() } // addComputedColsForUpdate wraps an Update input expression with a Project @@ -305,6 +314,9 @@ func (mb *mutationBuilder) addComputedColsForUpdate() { mb.updateOrds, func(tabCol cat.Column) bool { return tabCol.IsComputed() }, ) + + // Possibly round DECIMAL-related computed columns. + mb.roundDecimalValues(mb.updateOrds, true /* roundComputedCols */) } // buildUpdate constructs an Update operator, possibly wrapped by a Project diff --git a/pkg/sql/opt/testutils/testcat/create_table.go b/pkg/sql/opt/testutils/testcat/create_table.go index efbc6d65b172..f1e6313e2aa1 100644 --- a/pkg/sql/opt/testutils/testcat/create_table.go +++ b/pkg/sql/opt/testutils/testcat/create_table.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt/cat" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/types" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util" ) @@ -304,6 +305,16 @@ func (tt *Table) addColumn(def *tree.ColumnTableDef) { Nullable: nullable, } + var err error + col.ColType, err = sqlbase.DatumTypeToColumnType(typ) + if err != nil { + panic(err) + } + col.ColType, err = sqlbase.PopulateTypeAttrs(col.ColType, def.Type) + if err != nil { + panic(err) + } + // Look for name suffixes indicating this is a mutation column. if name, ok := extractWriteOnlyColumn(def); ok { col.Name = name diff --git a/pkg/sql/opt/testutils/testcat/test_catalog.go b/pkg/sql/opt/testutils/testcat/test_catalog.go index b0f79a4235a5..bad2248b5512 100644 --- a/pkg/sql/opt/testutils/testcat/test_catalog.go +++ b/pkg/sql/opt/testutils/testcat/test_catalog.go @@ -20,13 +20,13 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/config" - "github.com/cockroachdb/cockroach/pkg/sql/coltypes" "github.com/cockroachdb/cockroach/pkg/sql/opt/cat" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/types" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/sql/stats" "github.com/cockroachdb/cockroach/pkg/util/treeprinter" ) @@ -712,6 +712,7 @@ type Column struct { Nullable bool Name string Type types.T + ColType sqlbase.ColumnType DefaultExpr *string ComputedExpr *string } @@ -738,13 +739,19 @@ func (tc *Column) DatumType() types.T { return tc.Type } +// ColTypePrecision is part of the cat.Column interface. +func (tc *Column) ColTypePrecision() int { + return int(tc.ColType.Precision) +} + +// ColTypeWidth is part of the cat.Column interface. +func (tc *Column) ColTypeWidth() int { + return int(tc.ColType.Width) +} + // ColTypeStr is part of the cat.Column interface. func (tc *Column) ColTypeStr() string { - t, err := coltypes.DatumTypeToColumnType(tc.Type) - if err != nil { - panic(err) - } - return t.String() + return tc.ColType.SQLString() } // IsHidden is part of the cat.Column interface. diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index 27167ce5f3eb..d0809c712c54 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -2991,6 +2991,67 @@ may increase either contention or retry errors, or both.`, Info: "This function is used only by CockroachDB's developers for testing purposes.", }, ), + + "crdb_internal.round_decimal_values": makeBuiltin( + tree.FunctionProperties{ + Category: categorySystemInfo, + }, + tree.Overload{ + Types: tree.ArgTypes{ + {"val", types.Decimal}, + {"scale", types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Decimal), + Fn: func(_ *tree.EvalContext, args tree.Datums) (tree.Datum, error) { + value := args[0].(*tree.DDecimal) + scale := int32(tree.MustBeDInt(args[1])) + return roundDDecimal(value, scale) + }, + Info: "This function is used internally to round decimal values during mutations.", + }, + tree.Overload{ + Types: tree.ArgTypes{ + {"val", types.TArray{Typ: types.Decimal}}, + {"scale", types.Int}, + }, + ReturnType: tree.FixedReturnType(types.TArray{Typ: types.Decimal}), + Fn: func(_ *tree.EvalContext, args tree.Datums) (tree.Datum, error) { + value := args[0].(*tree.DArray) + scale := int32(tree.MustBeDInt(args[1])) + + // Lazily allocate a new array only if/when one of its elements + // is rounded. + var newArr tree.Datums + for i, elem := range value.Array { + // Skip NULL values. + if elem == tree.DNull { + continue + } + + rounded, err := roundDDecimal(elem.(*tree.DDecimal), scale) + if err != nil { + return nil, err + } + if rounded != elem { + if newArr == nil { + newArr = make(tree.Datums, len(value.Array)) + copy(newArr, value.Array) + } + newArr[i] = rounded + } + } + if newArr != nil { + return &tree.DArray{ + ParamTyp: value.ParamTyp, + Array: newArr, + HasNulls: value.HasNulls, + }, nil + } + return value, nil + }, + Info: "This function is used internally to round decimal array values during mutations.", + }, + ), } var lengthImpls = makeBuiltin(tree.FunctionProperties{Category: categoryString}, @@ -4013,9 +4074,20 @@ func overlay(s, to string, pos, size int) (tree.Datum, error) { return tree.NewDString(string(runes[:pos]) + to + string(runes[after:])), nil } -func roundDecimal(x *apd.Decimal, n int32) (tree.Datum, error) { +// roundDDecimal avoids creation of a new DDecimal in common case where no +// rounding is necessary. +func roundDDecimal(d *tree.DDecimal, scale int32) (tree.Datum, error) { + // Fast path: check if number of digits after decimal point is already low + // enough. + if -d.Exponent <= scale { + return d, nil + } + return roundDecimal(&d.Decimal, scale) +} + +func roundDecimal(x *apd.Decimal, scale int32) (tree.Datum, error) { dd := &tree.DDecimal{} - _, err := tree.HighPrecisionCtx.Quantize(&dd.Decimal, x, -n) + _, err := tree.HighPrecisionCtx.Quantize(&dd.Decimal, x, -scale) return dd, err } diff --git a/pkg/sql/sem/tree/testdata/eval/func b/pkg/sql/sem/tree/testdata/eval/func index 2ef78a68f74a..dc84cb8b42ed 100644 --- a/pkg/sql/sem/tree/testdata/eval/func +++ b/pkg/sql/sem/tree/testdata/eval/func @@ -15,3 +15,97 @@ eval UPPER('hello') ---- 'HELLO' + +# Scale < -Exponent. +eval +crdb_internal.round_decimal_values(1.23:::decimal, 1) +---- +1.2 + +# Scale = -Exponent. +eval +crdb_internal.round_decimal_values(1.23:::decimal, 2) +---- +1.23 + +# Scale > -Exponent. +eval +crdb_internal.round_decimal_values(1.23:::decimal, 3) +---- +1.23 + +# Scale=0 with whole number. +eval +crdb_internal.round_decimal_values(123:::decimal, 0) +---- +123 + +# Scale=0 with fractional number. +eval +crdb_internal.round_decimal_values(0.123:::decimal, 0) +---- +0 + +# Special-value cases. +eval +crdb_internal.round_decimal_values('NaN'::decimal, 0) +---- +NaN + +eval +crdb_internal.round_decimal_values('-inf'::decimal, 0) +---- +-Infinity + +eval +crdb_internal.round_decimal_values('inf'::decimal, 0) +---- +Infinity + +# NULL value. +eval +crdb_internal.round_decimal_values(null, 0) +---- +NULL + +# NULL decimal value. +eval +crdb_internal.round_decimal_values(null::decimal, 0) +---- +NULL + +# Round 10th fractional digit. +eval +crdb_internal.round_decimal_values(1000000000000000.0000000005::decimal, 9) +---- +1000000000000000.000000001 + +# Truncate extra zeros. +eval +crdb_internal.round_decimal_values(1000000000000000.0000000000::decimal, 3) +---- +1000000000000000.000 + +# Round with 1 digit coefficient and large negative exponent. +eval +crdb_internal.round_decimal_values(0.0000000005::decimal, 9) +---- +1E-9 + +# Decimal in array. +eval +crdb_internal.round_decimal_values(ARRAY[1.25::decimal], 1) +---- +ARRAY[1.3] + +# Multiple array values need to be rounded + NULL values. +eval +crdb_internal.round_decimal_values(ARRAY[NULL, 1.25::decimal, NULL, 1.23::decimal], 1) +---- +ARRAY[NULL,1.3,NULL,1.2] + +# None of the array values need to be rounded. +eval +crdb_internal.round_decimal_values(ARRAY[1.2::decimal, 5::decimal, NULL], 1) +---- +ARRAY[1.2,5,NULL] diff --git a/pkg/sql/sqlbase/structured.go b/pkg/sql/sqlbase/structured.go index bd8ba01cb457..a94aeb79defb 100644 --- a/pkg/sql/sqlbase/structured.go +++ b/pkg/sql/sqlbase/structured.go @@ -2804,6 +2804,16 @@ func (desc *ColumnDescriptor) DatumType() types.T { return desc.Type.ToDatumType() } +// ColTypePrecision is part of the cat.Column interface. +func (desc *ColumnDescriptor) ColTypePrecision() int { + return int(desc.Type.Precision) +} + +// ColTypeWidth is part of the cat.Column interface. +func (desc *ColumnDescriptor) ColTypeWidth() int { + return int(desc.Type.Width) +} + // ColTypeStr is part of the cat.Column interface. func (desc *ColumnDescriptor) ColTypeStr() string { return desc.Type.SQLString()