Skip to content

Commit

Permalink
sql: add assignment casts for UPSERTs
Browse files Browse the repository at this point in the history
Assignment casts are now added to query plans for upserts, including
`UPSERT`, `INSERT .. ON CONFLICT DO NOTHING`, and
`INSERT .. ON CONFLICT DO UPDATE ..` statements.

Assignment casts are a more general form of the logic for rounding
decimal values, so the use of `round_decimal_values` in mutations is no
longer needed. This logic has been removed.

Fixes cockroachdb#67083

There is no release note because the behavior of upserts should not
change with this commit.

Release note: None
  • Loading branch information
mgartner committed Jan 12, 2022
1 parent 4149ca7 commit b0a382a
Show file tree
Hide file tree
Showing 12 changed files with 2,382 additions and 558 deletions.
452 changes: 450 additions & 2 deletions pkg/sql/logictest/testdata/logic_test/cast

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/fk_cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (cb *onDeleteSetBuilder) Build(
updateExprs[i].Expr = tree.DefaultVal{}
}
}
mb.addUpdateCols(updateExprs, false /* isUpsert */)
mb.addUpdateCols(updateExprs)

// TODO(radu): consider plumbing a flag to prevent building the FK check
// against the parent we are cascading from. Need to investigate in which
Expand Down Expand Up @@ -687,7 +687,7 @@ func (cb *onUpdateCascadeBuilder) Build(
panic(errors.AssertionFailedf("unsupported action"))
}
}
mb.addUpdateCols(updateExprs, false /* isUpsert */)
mb.addUpdateCols(updateExprs)

mb.buildUpdate(nil /* returning */)
return mb.outScope.expr
Expand Down
47 changes: 13 additions & 34 deletions pkg/sql/opt/optbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope
//
// INSERT INTO <table> DEFAULT VALUES
//
isUpsert := ins.OnConflict != nil && !ins.OnConflict.DoNothing
if !ins.DefaultValues() {
// Replace any DEFAULT expressions in the VALUES clause, if a VALUES clause
// exists:
Expand All @@ -268,15 +267,15 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope
//
rows := mb.replaceDefaultExprs(ins.Rows)

mb.buildInputForInsert(inScope, rows, isUpsert)
mb.buildInputForInsert(inScope, rows)
} else {
mb.buildInputForInsert(inScope, nil /* rows */, isUpsert)
mb.buildInputForInsert(inScope, nil /* rows */)
}

// Add default columns that were not explicitly specified by name or
// implicitly targeted by input columns. Also add any computed columns. In
// both cases, include columns undergoing mutations in the write-only state.
mb.addSynthesizedColsForInsert(isUpsert)
mb.addSynthesizedColsForInsert()

var returning tree.ReturningExprs
if resultsNeeded(ins.Returning) {
Expand Down Expand Up @@ -317,7 +316,7 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope

// Add additional columns for computed expressions that may depend on any
// updated columns, as well as mutation columns with default values.
mb.addSynthesizedColsForUpdate(true /* isUpsert */)
mb.addSynthesizedColsForUpdate()
}

// Build the final upsert statement, including any returned expressions.
Expand All @@ -334,7 +333,7 @@ func (b *Builder) buildInsert(ins *tree.Insert, inScope *scope) (outScope *scope
mb.addTargetColsForUpdate(ins.OnConflict.Exprs)

// Build each of the SET expressions.
mb.addUpdateCols(ins.OnConflict.Exprs, true /* isUpsert */)
mb.addUpdateCols(ins.OnConflict.Exprs)

// Build the final upsert statement, including any returned expressions.
mb.buildUpsert(returning)
Expand Down Expand Up @@ -558,9 +557,7 @@ func (mb *mutationBuilder) addTargetTableColsForInsert(maxCols int) {

// buildInputForInsert constructs the memo group for the input expression and
// constructs a new output scope containing that expression's output columns.
func (mb *mutationBuilder) buildInputForInsert(
inScope *scope, inputRows *tree.Select, isUpsert bool,
) {
func (mb *mutationBuilder) buildInputForInsert(inScope *scope, inputRows *tree.Select) {
// Handle DEFAULT VALUES case by creating a single empty row as input.
if inputRows == nil {
mb.outScope = inScope.push()
Expand Down Expand Up @@ -619,11 +616,6 @@ func (mb *mutationBuilder) buildInputForInsert(
inCol := &mb.outScope.cols[i]
ord := mb.tabID.ColumnOrdinal(mb.targetColList[i])

if isUpsert {
// Type check the input column against the corresponding table column.
checkDatumTypeFitsColumnType(mb.tab.Column(ord), inCol.typ)
}

// Raise an error if the target column is a `GENERATED ALWAYS AS
// IDENTITY` column. Such a column is not allowed to be explicitly
// written to.
Expand All @@ -645,18 +637,16 @@ func (mb *mutationBuilder) buildInputForInsert(
mb.insertColIDs[ord] = inCol.id
}

if !isUpsert {
// Add assignment casts for insert columns.
mb.addAssignmentCasts(mb.insertColIDs)
}
// Add assignment casts for insert columns.
mb.addAssignmentCasts(mb.insertColIDs)
}

// addSynthesizedColsForInsert wraps an Insert input expression with a Project
// operator containing any default (or nullable) columns and any computed
// columns that are not yet part of the target column list. This includes all
// write-only mutation columns, since they must always have default or computed
// values.
func (mb *mutationBuilder) addSynthesizedColsForInsert(isUpsert bool) {
func (mb *mutationBuilder) addSynthesizedColsForInsert() {
// Start by adding non-computed columns that have not already been explicitly
// specified in the query. Do this before adding computed columns, since those
// may depend on non-computed columns.
Expand All @@ -666,25 +656,14 @@ func (mb *mutationBuilder) addSynthesizedColsForInsert(isUpsert bool) {
false, /* applyOnUpdate */
)

if isUpsert {
// Possibly round DECIMAL-related columns containing insertion values (whether
// synthesized or not).
mb.roundDecimalValues(mb.insertColIDs, false /* roundComputedCols */)
} else {
// Add assignment casts for default column values.
mb.addAssignmentCasts(mb.insertColIDs)
}
// Add assignment casts for default column values.
mb.addAssignmentCasts(mb.insertColIDs)

// Now add all computed columns.
mb.addSynthesizedComputedCols(mb.insertColIDs, false /* restrict */)

// Possibly round DECIMAL-related computed columns.
if isUpsert {
mb.roundDecimalValues(mb.insertColIDs, true /* roundComputedCols */)
} else {
// Add assignment casts for computed column values.
mb.addAssignmentCasts(mb.insertColIDs)
}
// Add assignment casts for computed column values.
mb.addAssignmentCasts(mb.insertColIDs)
}

// buildInsert constructs an Insert operator, possibly wrapped by a Project
Expand Down
168 changes: 0 additions & 168 deletions pkg/sql/opt/optbuilder/mutation_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
Expand Down Expand Up @@ -137,10 +136,6 @@ type mutationBuilder struct {
// detect conflicts for UPSERT and INSERT ON CONFLICT statements.
arbiters arbiterSet

// roundedDecimalCols is the set of columns that have already been rounded.
// Keeping this set avoids rounding the same column multiple times.
roundedDecimalCols opt.ColSet

// subqueries temporarily stores subqueries that were built during initial
// analysis of SET expressions. They will be used later when the subqueries
// are joined into larger LEFT OUTER JOIN expressions.
Expand Down Expand Up @@ -721,148 +716,6 @@ func (mb *mutationBuilder) addSynthesizedComputedCols(colIDs opt.OptionalColList
mb.outScope = pb.Finish()
}

// roundDecimalValues wraps each DECIMAL-related column (including arrays of
// decimals) with a call to the crdb_internal.round_decimal_values function, if
// column values may need to be rounded. This is necessary when mutating table
// columns that have a limited scale (e.g. DECIMAL(10, 1)). Here is the PG docs
// description:
//
// http://www.postgresql.org/docs/9.5/static/datatype-numeric.html
// "If the scale of a value to be stored is greater than
// the declared scale of the column, the system will round the
// value to the specified number of fractional digits. Then,
// if the number of digits to the left of the decimal point
// exceeds the declared precision minus the declared scale, an
// error is raised."
//
// Note that this function only handles the rounding portion of that. The
// precision check is done by the execution engine. The rounding cannot be done
// there, since it needs to happen before check constraints are computed, and
// before UPSERT joins.
//
// If roundComputedCols is false, then don't wrap computed columns. If true,
// then only wrap computed columns. This is necessary because computed columns
// can depend on other columns mutated by the operation; it is necessary to
// first round those values, then evaluated the computed expression, and then
// round the result of the computation.
//
// roundDecimalValues will only round decimal columns that are part of the
// colIDs list (i.e. are not 0). If a column is rounded, then the list will be
// updated with the column ID of the new synthesized column.
func (mb *mutationBuilder) roundDecimalValues(colIDs opt.OptionalColList, roundComputedCols bool) {
var projectionsScope *scope

for i, id := range colIDs {
if id == 0 {
// Column not mutated, so nothing to do.
continue
}

// Include or exclude computed columns, depending on the value of
// roundComputedCols.
col := mb.tab.Column(i)
if col.IsComputed() != roundComputedCols {
continue
}

// Check whether the target column's type may require rounding of the
// input value.
colType := col.DatumType()
precision, width := colType.Precision(), colType.Width()
if colType.Family() == types.ArrayFamily {
innerType := colType.ArrayContents()
if innerType.Family() == types.ArrayFamily {
panic(errors.AssertionFailedf("column type should never be a nested array"))
}
precision, width = innerType.Precision(), innerType.Width()
}

props, overload := findRoundingFunction(colType, precision)
if props == nil {
continue
}

// If column has already been rounded, then skip it.
if mb.roundedDecimalCols.Contains(id) {
continue
}

private := &memo.FunctionPrivate{
Name: "crdb_internal.round_decimal_values",
Typ: col.DatumType(),
Properties: props,
Overload: overload,
}
variable := mb.b.factory.ConstructVariable(id)
scale := mb.b.factory.ConstructConstVal(tree.NewDInt(tree.DInt(width)), types.Int)
fn := mb.b.factory.ConstructFunction(memo.ScalarListExpr{variable, scale}, private)

// Lazily create new scope and update the scope column to be rounded.
if projectionsScope == nil {
projectionsScope = mb.outScope.replace()
projectionsScope.appendColumnsFromScope(mb.outScope)
}
scopeCol := projectionsScope.getColumn(id)
mb.b.populateSynthesizedColumn(scopeCol, fn)

// Overwrite the input column ID with the new synthesized column ID.
colIDs[i] = scopeCol.id
mb.roundedDecimalCols.Add(scopeCol.id)

// When building an UPDATE..FROM expression the projectionScope may have
// two columns with different names but the same ID. As a result, the
// scope column with the correct name (the name of the target column)
// may not be returned from projectionScope.getColumn. We set the name
// of the new scope column to the target column name to ensure it is
// in-scope when building CHECK constraint and partial index PUT
// expressions. See #61520.
// TODO(mgartner): Find a less brittle way to manage the scopes of
// mutations so that this isn't necessary. Ideally the scope produced by
// addUpdateColumns would not include columns in the FROM clause. Those
// columns are only in-scope in the RETURNING clause via
// mb.extraAccessibleCols.
scopeCol.name = scopeColName(mb.tab.Column(i).ColName())
}

if projectionsScope != nil {
mb.b.constructProjectForScope(mb.outScope, projectionsScope)
mb.outScope = projectionsScope
}
}

// findRoundingFunction returns the builtin function overload needed to round
// input values. This is only necessary for DECIMAL or DECIMAL[] types that have
// limited precision, such as:
//
// DECIMAL(15, 1)
// DECIMAL(10, 3)[]
//
// If an input decimal value has more than the required number of fractional
// digits, it must be rounded before being inserted into these types.
//
// NOTE: CRDB does not allow nested array storage types, so only one level of
// array nesting needs to be checked.
func findRoundingFunction(
typ *types.T, precision int32,
) (*tree.FunctionProperties, *tree.Overload) {
if precision == 0 {
// Unlimited precision decimal target type never needs rounding.
return nil, nil
}

props, overloads := builtins.GetBuiltinProperties("crdb_internal.round_decimal_values")

if typ.Equivalent(types.Decimal) {
return props, &overloads[0]
}
if typ.Equivalent(types.DecimalArray) {
return props, &overloads[1]
}

// Not DECIMAL or DECIMAL[].
return nil, nil
}

// addCheckConstraintCols synthesizes a boolean output column for each check
// constraint defined on the target table. The mutation operator will report a
// constraint violation error if the value of the column is false.
Expand Down Expand Up @@ -1369,27 +1222,6 @@ func resultsNeeded(r tree.ReturningClause) bool {
}
}

// checkDatumTypeFitsColumnType verifies that a given scalar value type is valid
// to be stored in a column of the given column type.
//
// For the purpose of this analysis, column type aliases are not considered to
// be different (eg. TEXT and VARCHAR will fit the same scalar type String).
//
// This is used by the UPDATE, INSERT and UPSERT code.
// TODO(mgartner): Remove this once assignment casts are fully supported.
func checkDatumTypeFitsColumnType(col *cat.Column, typ *types.T) {
if typ.Equivalent(col.DatumType()) {
return
}

colName := string(col.ColName())
err := pgerror.Newf(pgcode.DatatypeMismatch,
"value type %s doesn't match type %s of column %q",
typ, col.DatumType(), tree.ErrNameString(colName))
err = errors.WithHint(err, "you will need to rewrite or cast the expression")
panic(err)
}

// addAssignmentCasts builds a projection that wraps columns in srcCols with
// assignment casts when necessary so that the resulting columns have types
// identical to their target column types.
Expand Down
Loading

0 comments on commit b0a382a

Please sign in to comment.