Skip to content

Commit

Permalink
sql: Ensure inserts can be performed on columns with
Browse files Browse the repository at this point in the history
default/on update expr with differing types

Release note: None
  • Loading branch information
e-mbrown committed May 31, 2022
1 parent 78745ae commit e9962a8
Show file tree
Hide file tree
Showing 26 changed files with 96 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pkg/ccl/changefeedccl/avro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func parseValues(tableDesc catalog.TableDescriptor, values string) ([]rowenc.Enc
for colIdx, expr := range rowTuple {
col := tableDesc.PublicColumns()[colIdx]
typedExpr, err := schemaexpr.SanitizeVarFreeExpr(
ctx, expr, col.GetType(), "avro", &semaCtx, volatility.Stable)
ctx, expr, col.GetType(), "avro", &semaCtx, volatility.Stable, false)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions pkg/ccl/partitionccl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func valueEncodePartitionTuple(
typedExpr, err := schemaexpr.SanitizeVarFreeExpr(evalCtx.Context, expr, cols[i].GetType(), "partition",
&semaCtx,
volatility.Immutable,
false,
)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ func sanitizeColumnExpression(
) (tree.TypedExpr, string, error) {
colDatumType := col.GetType()
typedExpr, err := schemaexpr.SanitizeVarFreeExpr(
p.ctx, expr, colDatumType, opName, &p.p.semaCtx, volatility.Volatile,
p.ctx, expr, colDatumType, opName, &p.p.semaCtx, volatility.Volatile, false,
)
if err != nil {
return nil, "", pgerror.WithCandidateCode(err, pgcode.DatatypeMismatch)
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/alter_table_locality.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ func (n *alterTableSetLocalityNode) alterTableLocalityToRegionalByRow(
"REGIONAL BY ROW DEFAULT",
params.p.SemaCtx(),
volatility.Volatile,
false,
)
if err != nil {
return err
Expand Down
3 changes: 1 addition & 2 deletions pkg/sql/analyze_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ func (p *planner) analyzeExpr(
var err error
p.semaCtx.IVarContainer = iVarHelper.Container()
if requireType {
typedExpr, err = tree.TypeCheckAndRequire(ctx, resolved, &p.semaCtx,
expectedType, typingContext)
typedExpr, err = tree.TypeCheckAndRequire(ctx, resolved, &p.semaCtx, expectedType, typingContext, false)
} else {
typedExpr, err = tree.TypeCheck(ctx, resolved, &p.semaCtx, expectedType)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/catalog/schemaexpr/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go_library(
"//pkg/sql/parser",
"//pkg/sql/pgwire/pgcode",
"//pkg/sql/pgwire/pgerror",
"//pkg/sql/sem/cast",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/transform",
Expand Down
12 changes: 10 additions & 2 deletions pkg/sql/catalog/schemaexpr/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/cast"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
Expand Down Expand Up @@ -105,6 +106,7 @@ func DequalifyAndValidateExpr(
context,
semaCtx,
maxVolatility,
false,
)

if err != nil {
Expand Down Expand Up @@ -295,7 +297,7 @@ func deserializeExprForFormatting(
// typedExpr.
if fmtFlags == tree.FmtPGCatalog {
sanitizedExpr, err := SanitizeVarFreeExpr(ctx, expr, typedExpr.ResolvedType(), "FORMAT", semaCtx,
volatility.Immutable)
volatility.Immutable, false)
// If the expr has no variables and has Immutable, we can evaluate
// it and turn it into a constant.
if err == nil {
Expand Down Expand Up @@ -398,6 +400,7 @@ func SanitizeVarFreeExpr(
context string,
semaCtx *tree.SemaContext,
maxVolatility volatility.V,
allowAssignmentCast bool,
) (tree.TypedExpr, error) {
if tree.ContainsVars(expr) {
return nil, pgerror.Newf(pgcode.Syntax,
Expand Down Expand Up @@ -436,7 +439,12 @@ func SanitizeVarFreeExpr(
actualType := typedExpr.ResolvedType()
if !expectedType.Equivalent(actualType) && typedExpr != tree.DNull {
// The expression must match the column type exactly unless it is a constant
// NULL value.
// NULL value or assignment casts are allowed.
if allowAssignmentCast {
if ok := cast.ValidCast(actualType, expectedType, cast.ContextAssignment); ok {
return typedExpr, nil
}
}
return nil, fmt.Errorf("expected %s expression to have type %s, but '%s' has type %s",
context, expectedType, expr, actualType)
}
Expand Down
28 changes: 4 additions & 24 deletions pkg/sql/catalog/tabledesc/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/cast"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
Expand Down Expand Up @@ -159,22 +158,13 @@ func MakeColumnDefDescs(
col.Type = resType

if d.HasDefaultExpr() {
var innerErr error
// Verify the default expression type is compatible with the column type
// and does not contain invalid functions.
ret.DefaultExpr, err = schemaexpr.SanitizeVarFreeExpr(
ctx, d.DefaultExpr.Expr, resType, "DEFAULT", semaCtx, volatility.Volatile,
ctx, d.DefaultExpr.Expr, resType, "DEFAULT", semaCtx, volatility.Volatile, true,
)
if err != nil {
// Check if the default expression type can be assignment-cast into the
// column type. If it can allow the default and column type to differ.
ret.DefaultExpr, innerErr = d.DefaultExpr.Expr.TypeCheck(ctx, semaCtx, types.Any)
if innerErr != nil {
return nil, err
}
if ok := cast.ValidCast(ret.DefaultExpr.ResolvedType(), resType, cast.ContextAssignment); !ok {
return nil, err
}
return nil, err
}

// Keep the type checked expression so that the type annotation gets
Expand All @@ -190,20 +180,10 @@ func MakeColumnDefDescs(
if d.HasOnUpdateExpr() {
// Verify the on update expression type is compatible with the column type
// and does not contain invalid functions.
var innerErr error
ret.OnUpdateExpr, err = schemaexpr.SanitizeVarFreeExpr(
ctx, d.OnUpdateExpr.Expr, resType, "ON UPDATE", semaCtx, volatility.Volatile,
ctx, d.OnUpdateExpr.Expr, resType, "ON UPDATE", semaCtx, volatility.Volatile, true,
)
if err != nil {
// Check if the on update expression type can be assignment-cast into the
// column type. If it can allow the on update expr and column type to differ.
ret.OnUpdateExpr, innerErr = d.OnUpdateExpr.Expr.TypeCheck(ctx, semaCtx, types.Any)
if innerErr != nil {
return nil, err
}
if ok := cast.ValidCast(ret.OnUpdateExpr.ResolvedType(), resType, cast.ContextAssignment); !ok {
return nil, err
}
return nil, err
}

Expand Down Expand Up @@ -286,7 +266,7 @@ func EvalShardBucketCount(
shardBuckets = paramVal
}
typedExpr, err := schemaexpr.SanitizeVarFreeExpr(
ctx, shardBuckets, types.Int, "BUCKET_COUNT", semaCtx, volatility.Volatile,
ctx, shardBuckets, types.Int, "BUCKET_COUNT", semaCtx, volatility.Volatile, false,
)
if err != nil {
return 0, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (p *planner) fillInPlaceholders(
}
}
typedExpr, err := schemaexpr.SanitizeVarFreeExpr(
ctx, e, typ, "EXECUTE parameter" /* context */, &semaCtx, volatility.Volatile,
ctx, e, typ, "EXECUTE parameter" /* context */, &semaCtx, volatility.Volatile, false,
)
if err != nil {
return nil, pgerror.WithCandidateCode(err, pgcode.WrongObjectType)
Expand Down
40 changes: 28 additions & 12 deletions pkg/sql/logictest/testdata/logic_test/cast
Original file line number Diff line number Diff line change
Expand Up @@ -1403,26 +1403,42 @@ select CASE WHEN false THEN 1::REGTYPE ELSE 2::REGPROCEDURE END;
select CASE WHEN false THEN 1::REGTYPE ELSE 2::REGROLE END;


# Test that default/on update expression differing from column type
# Test that default/on update expression is allowed to differ from column type
statement ok
CREATE TABLE def_assn_cast (
a INT4 DEFAULT 1.0::FLOAT4,
b VARCHAR DEFAULT 'true'::BOOL,
c NAME DEFAULT 'foo'::BPCHAR,
d JSONB DEFAULT 'null'::CHAR,
id INT4,
a INT4 DEFAULT 1.0::FLOAT4,
b VARCHAR DEFAULT 'true'::BOOL,
c NAME DEFAULT 'foo'::BPCHAR,
d VARCHAR DEFAULT '{ "customer": "John Doe"}'::JSONB
)

statement error pq: could not parse . as type .: invalid . value
# Ensure insertions are allowed
statement ok
INSERT INTO def_assn_cast(id) VALUES (1)

query IITTT
SELECT * from def_assn_cast
----
1 1 true f {"customer": "John Doe"}


statement error pq: could not parse .* as type .*: invalid .* value
CREATE TABLE fail_assn_cast (
a BOOL DEFAULT 'foo'
a BOOL DEFAULT 'foo'
)

statement error pq: could not parse . as type .: invalid . value
statement error pq: expected DEFAULT expression to have type .*, but .* has type .*
CREATE TABLE fail_assn_cast (
b JSONB DEFAULT 'null'::CHAR
a DATE DEFAULT 1.0::FLOAT4
)

statement error pq: expected DEFAULT expression to have type ., but . has type .
CREATE TABLE def_assn_cast (
a INT4 DEFAULT 1.0::BOOL
statement error pq: expected DEFAULT expression to have type .*, but .* has type .*
CREATE TABLE fail_assn_cast (
b JSONB DEFAULT 'null'::CHAR
)

statement error pq: expected DEFAULT expression to have type .*, but .* has type .*
CREATE TABLE fail_assn_cast (
a INT4 DEFAULT 1.0::BOOL
)
1 change: 1 addition & 0 deletions pkg/sql/opt/bench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ func newHarness(tb testing.TB, query benchQuery) *harness {
"", /* context */
&h.semaCtx,
volatility.Volatile,
false,
)
if err != nil {
tb.Fatalf("%v", err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func (b *Builder) analyzeHaving(having *tree.Where, fromScope *scope) tree.Typed
exprKindHaving.String(), tree.RejectWindowApplications|tree.RejectGenerators,
)
fromScope.context = exprKindHaving
return fromScope.resolveAndRequireType(having.Expr, types.Bool)
return fromScope.resolveAndRequireType(having.Expr, types.Bool, "")
}

// buildHaving builds a set of memo groups that represent the given HAVING
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (b *Builder) buildJoin(
)
outScope.context = exprKindOn
filter := b.buildScalar(
outScope.resolveAndRequireType(on.Expr, types.Bool), outScope, nil, nil, nil,
outScope.resolveAndRequireType(on.Expr, types.Bool, ""), outScope, nil, nil, nil,
)
filters = memo.FiltersExpr{b.factory.ConstructFiltersItem(filter)}
} else {
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/opt/optbuilder/mutation_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ func (mb *mutationBuilder) addCheckConstraintCols(isUpdate bool) {
panic(err)
}

texpr := mb.outScope.resolveAndRequireType(expr, types.Bool)
texpr := mb.outScope.resolveAndRequireType(expr, types.Bool, "")

// Use an anonymous name because the column cannot be referenced
// in other expressions.
Expand Down Expand Up @@ -835,7 +835,7 @@ func (mb *mutationBuilder) projectPartialIndexColsImpl(putScope, delScope *scope

// Build synthesized PUT columns.
if putScope != nil {
texpr := putScope.resolveAndRequireType(expr, types.Bool)
texpr := putScope.resolveAndRequireType(expr, types.Bool, "")

// Use an anonymous name because the column cannot be referenced
// in other expressions.
Expand All @@ -848,7 +848,7 @@ func (mb *mutationBuilder) projectPartialIndexColsImpl(putScope, delScope *scope

// Build synthesized DEL columns.
if delScope != nil {
texpr := delScope.resolveAndRequireType(expr, types.Bool)
texpr := delScope.resolveAndRequireType(expr, types.Bool, "")

// Use an anonymous name because the column cannot be referenced
// in other expressions.
Expand Down
10 changes: 5 additions & 5 deletions pkg/sql/opt/optbuilder/mutation_builder_arbiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (mb *mutationBuilder) buildAntiJoinForDoNothingArbiter(
// wraps the scan on the right side of the anti-join with the partial
// index predicate expression as the filter.
if pred != nil {
texpr := fetchScope.resolveAndRequireType(pred, types.Bool)
texpr := fetchScope.resolveAndRequireType(pred, types.Bool, "")
predScalar := mb.b.buildScalar(texpr, fetchScope, nil, nil, nil)
fetchScope.expr = mb.b.factory.ConstructSelect(
fetchScope.expr,
Expand Down Expand Up @@ -333,7 +333,7 @@ func (mb *mutationBuilder) buildAntiJoinForDoNothingArbiter(
// rows in the unique partial index. Therefore, the partial index
// predicate expression is added to the ON filters.
if pred != nil {
texpr := mb.outScope.resolveAndRequireType(pred, types.Bool)
texpr := mb.outScope.resolveAndRequireType(pred, types.Bool, "")
predScalar := mb.b.buildScalar(texpr, mb.outScope, nil, nil, nil)
on = append(on, mb.b.factory.ConstructFiltersItem(predScalar))
}
Expand Down Expand Up @@ -388,7 +388,7 @@ func (mb *mutationBuilder) buildLeftJoinForUpsertArbiter(
// the scan on the right side of the left outer join with the partial index
// predicate expression as the filter.
if pred != nil {
texpr := mb.fetchScope.resolveAndRequireType(pred, types.Bool)
texpr := mb.fetchScope.resolveAndRequireType(pred, types.Bool, "")
predScalar := mb.b.buildScalar(texpr, mb.fetchScope, nil, nil, nil)
mb.fetchScope.expr = mb.b.factory.ConstructSelect(
mb.fetchScope.expr,
Expand Down Expand Up @@ -418,7 +418,7 @@ func (mb *mutationBuilder) buildLeftJoinForUpsertArbiter(
// the unique partial index. Therefore, the partial index predicate
// expression is added to the ON filters.
if pred != nil {
texpr := mb.outScope.resolveAndRequireType(pred, types.Bool)
texpr := mb.outScope.resolveAndRequireType(pred, types.Bool, "")
predScalar := mb.b.buildScalar(texpr, mb.outScope, nil, nil, nil)
on = append(on, mb.b.factory.ConstructFiltersItem(predScalar))
}
Expand Down Expand Up @@ -523,7 +523,7 @@ func (mb *mutationBuilder) projectPartialArbiterDistinctColumn(
Left: pred,
Right: tree.DNull,
}
texpr := insertScope.resolveAndRequireType(expr, types.Bool)
texpr := insertScope.resolveAndRequireType(expr, types.Bool, "")

// Use an anonymous name because the column cannot be referenced
// in other expressions.
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/opt/optbuilder/mutation_builder_unique.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (mb *mutationBuilder) uniqueColsUpdated(uniqueOrdinal cat.UniqueOrdinal) bo

if _, isPartial := uc.Predicate(); isPartial {
pred := mb.parseUniqueConstraintPredicateExpr(uniqueOrdinal)
typedPred := mb.fetchScope.resolveAndRequireType(pred, types.Bool)
typedPred := mb.fetchScope.resolveAndRequireType(pred, types.Bool, "")

var predCols opt.ColSet
mb.b.buildScalar(typedPred, mb.fetchScope, nil, nil, &predCols)
Expand Down Expand Up @@ -343,11 +343,11 @@ func (h *uniqueCheckHelper) buildInsertionCheck() memo.UniqueChecksItem {
if isPartial {
pred := h.mb.parseUniqueConstraintPredicateExpr(h.uniqueOrdinal)

typedPred := withScanScope.resolveAndRequireType(pred, types.Bool)
typedPred := withScanScope.resolveAndRequireType(pred, types.Bool, "")
withScanPred := h.mb.b.buildScalar(typedPred, withScanScope, nil, nil, nil)
semiJoinFilters = append(semiJoinFilters, f.ConstructFiltersItem(withScanPred))

typedPred = h.scanScope.resolveAndRequireType(pred, types.Bool)
typedPred = h.scanScope.resolveAndRequireType(pred, types.Bool, "")
scanPred := h.mb.b.buildScalar(typedPred, h.scanScope, nil, nil, nil)
semiJoinFilters = append(semiJoinFilters, f.ConstructFiltersItem(scanPred))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/partial_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,5 +214,5 @@ func resolvePartialIndexPredicate(tableScope *scope, expr tree.Expr) tree.TypedE
panic(errors.AssertionFailedf("unexpected error during partial index predicate type resolution: %v", r))
}
}()
return tableScope.resolveAndRequireType(expr, types.Bool)
return tableScope.resolveAndRequireType(expr, types.Bool, "")
}
10 changes: 6 additions & 4 deletions pkg/sql/opt/optbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,11 @@ func (s *scope) resolveType(expr tree.Expr, desired *types.T) tree.TypedExpr {
// expression with no error). If the result type is types.Unknown, then
// resolveType will wrap the expression in a type cast in order to produce the
// desired type.
func (s *scope) resolveAndRequireType(expr tree.Expr, desired *types.T) tree.TypedExpr {
func (s *scope) resolveAndRequireType(
expr tree.Expr, desired *types.T, mutSuffix string,
) tree.TypedExpr {
expr = s.walkExprTree(expr)
texpr, err := tree.TypeCheckAndRequire(s.builder.ctx, expr, s.builder.semaCtx, desired, s.context.String())
texpr, err := tree.TypeCheckAndRequire(s.builder.ctx, expr, s.builder.semaCtx, desired, s.context.String(), true)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -1452,13 +1454,13 @@ func analyzeWindowFrame(s *scope, windowDef *tree.WindowDef) error {
if startBound != nil && startBound.OffsetExpr != nil {
oldContext := s.context
s.context = exprKindWindowFrameStart
startBound.OffsetExpr = s.resolveAndRequireType(startBound.OffsetExpr, requiredType)
startBound.OffsetExpr = s.resolveAndRequireType(startBound.OffsetExpr, requiredType, "")
s.context = oldContext
}
if endBound != nil && endBound.OffsetExpr != nil {
oldContext := s.context
s.context = exprKindWindowFrameEnd
endBound.OffsetExpr = s.resolveAndRequireType(endBound.OffsetExpr, requiredType)
endBound.OffsetExpr = s.resolveAndRequireType(endBound.OffsetExpr, requiredType, "")
s.context = oldContext
}
return nil
Expand Down
Loading

0 comments on commit e9962a8

Please sign in to comment.