Skip to content

Commit

Permalink
sql: remove usages of types.IsRecordType
Browse files Browse the repository at this point in the history
Previously, the `types.IsRecordType` function was used in different
contexts (with vs without OUT-params, function params vs return type).
This made it difficult to determine whether a particular usage was
correct, and led to a few bugs in cases where additional checks
were necessary.

This commit replaces usages of `types.IsRecordType` with either:
1. `typ.Identical(types.AnyTuple)`, or
2. `typ.Oid() == oid.T_record`

The former should be used for a RECORD-returning routine with no
OUT-parameters, as well as for a RECORD-typed variable. The latter
should be used to match either a RECORD-returning routine, or one
with multiple OUT-parameters.

Informs #114846

Release note: None
  • Loading branch information
DrewKimball committed Apr 12, 2024
1 parent 900f8b6 commit 5b6e564
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 39 deletions.
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Smither) makePLpgSQLDeclarations(
varName = s.name("decl")
}
varTyp := s.randType()
for types.IsRecordType(varTyp) || varTyp.Family() == types.CollatedStringFamily {
for varTyp.Identical(types.AnyTuple) || varTyp.Family() == types.CollatedStringFamily {
// TODO(#114874): allow record types here when they are supported.
// TODO(#105245): allow collated strings when they are supported.
varTyp = s.randType()
Expand Down
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ func (s *Smither) makeCreateFunc() (cf *tree.CreateRoutine, ok bool) {
// TODO(#105713): lift the RECORD-type restriction.
ptyp := s.randType()
for ptyp.Family() == types.CollatedStringFamily ||
(lang == tree.RoutineLangPLpgSQL && types.IsRecordType(ptyp)) {
(lang == tree.RoutineLangPLpgSQL && ptyp.Identical(types.AnyTuple)) {
ptyp = s.randType()
}
pname := fmt.Sprintf("p%d", i)
Expand Down
4 changes: 1 addition & 3 deletions pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,9 +809,7 @@ func (desc *immutable) ToOverload() (ret *tree.Overload, err error) {
ret.RoutineParams = append(ret.RoutineParams, routineParam)
}
ret.ReturnType = tree.FixedReturnType(desc.ReturnType.Type)
// TODO(yuzefovich): we should not be setting ReturnsRecordType to 'true'
// when the return type is based on output parameters.
ret.ReturnsRecordType = types.IsRecordType(desc.ReturnType.Type)
ret.ReturnsRecordType = desc.ReturnType.Type.Identical(types.AnyTuple)
ret.Types = signatureTypes
ret.Volatility, err = desc.getOverloadVolatility()
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/opt/optbuilder/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o
}
// The parameter type must be supported by the current cluster version.
checkUnsupportedType(b.ctx, b.semaCtx, typ)
if types.IsRecordType(typ) {
if typ.Identical(types.AnyTuple) {
if language == tree.RoutineLangSQL {
panic(pgerror.Newf(pgcode.InvalidFunctionDefinition,
"SQL functions cannot have arguments of type record"))
Expand Down Expand Up @@ -478,15 +478,15 @@ func validateReturnType(

// If return type is RECORD and the tuple content types unspecified by OUT
// parameters, any column types are valid. This is the case when we have
// RETURNS RECORD without OUT params - we don't need to check the types
// RETURNS RECORD without OUT-params - we don't need to check the types
// below.
if types.IsRecordType(expected) && types.IsWildcardTupleType(expected) {
if expected.Identical(types.AnyTuple) {
return nil
}

if len(cols) == 1 {
typeToCheck := expected
if isSQLProcedure && types.IsRecordType(expected) && len(expected.TupleContents()) == 1 {
if isSQLProcedure && len(expected.TupleContents()) == 1 {
// For SQL procedures with output parameters we get a record type
// even with a single column.
typeToCheck = expected.TupleContents()[0]
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func (b *plpgsqlBuilder) buildBlock(astBlock *ast.Block, s *scope) *scope {
if err != nil {
panic(err)
}
if types.IsRecordType(typ) {
if typ.Identical(types.AnyTuple) {
panic(recordVarErr)
}
b.addVariable(dec.Var, typ)
Expand All @@ -360,7 +360,7 @@ func (b *plpgsqlBuilder) buildBlock(astBlock *ast.Block, s *scope) *scope {
block.cursors[dec.Name] = *dec
}
}
if types.IsRecordType(b.returnType) && types.IsWildcardTupleType(b.returnType) {
if b.returnType.Identical(types.AnyTuple) {
// For a RECORD-returning routine, infer the concrete type by examining the
// RETURN statements. This has to happen after building the declaration
// block because RETURN statements can reference declared variables. Only
Expand Down
52 changes: 34 additions & 18 deletions pkg/sql/opt/optbuilder/routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (b *Builder) buildRoutine(
// be concrete in order to decode them correctly. We can determine the types
// from the result columns or tuple of the last statement.
finishResolveType := func(lastStmtScope *scope) {
if types.IsWildcardTupleType(rtyp) {
if rtyp.Identical(types.AnyTuple) {
if len(lastStmtScope.cols) == 1 &&
lastStmtScope.cols[0].typ.Family() == types.TupleFamily {
// When the final statement returns a single tuple, we can use
Expand Down Expand Up @@ -502,6 +502,10 @@ func (b *Builder) finishBuildLastStmt(
expr, physProps = stmtScope.expr, stmtScope.makePhysicalProps()
rtyp := f.ResolvedType()

// Note: since the final return type has already been resolved by this point,
// we can't check if this is a RECORD-returning routine by examining rTyp.
isRecordReturning := f.ResolvedOverload().ReturnsRecordType

// Add a LIMIT 1 to the last statement if the UDF is not
// set-returning. This is valid because any other rows after the
// first can simply be ignored. The limit could be beneficial
Expand Down Expand Up @@ -542,21 +546,32 @@ func (b *Builder) finishBuildLastStmt(
expr = b.constructProject(expr, elems)
physProps = stmtScope.makePhysicalProps()
}
} else if len(cols) > 1 || (types.IsRecordType(rtyp) && !isSingleTupleResult) {
// Only a single column can be returned from a UDF, unless it is used as a
// data source (see comment above). If there are multiple columns, combine
// them into a tuple. If the last statement is already returning a tuple
// and the function has a record return type, then do not wrap the
// output in another tuple.
elems := make(memo.ScalarListExpr, len(cols))
for i := range cols {
elems[i] = b.factory.ConstructVariable(cols[i].ID)
} else {
// Only a single column can be returned from a routine, unless it is a UDF
// used as a data source (see comment above). There are three cases in which
// we must wrap the column(s) from the last statement into a single tuple:
// 1. The last statement has multiple result columns.
// 2. The routine returns RECORD, and the last statement does not already
// return a tuple column.
// 3. The routine is a stored procedure that returns a non-VOID type, and
// the last statement does not already return a tuple column.
overload := f.ResolvedOverload()
mustWrapColsInTuple := len(cols) > 1
if len(cols) == 1 && !isSingleTupleResult {
mustWrapColsInTuple = mustWrapColsInTuple || isRecordReturning ||
(rtyp.Family() != types.VoidFamily && overload.Type == tree.ProcedureRoutine)
}
if mustWrapColsInTuple {
elems := make(memo.ScalarListExpr, len(cols))
for i := range cols {
elems[i] = b.factory.ConstructVariable(cols[i].ID)
}
tup := b.factory.ConstructTuple(elems, rtyp)
stmtScope = bodyScope.push()
col := b.synthesizeColumn(stmtScope, scopeColName(""), rtyp, nil /* expr */, tup)
expr = b.constructProject(expr, []scopeColumn{*col})
physProps = stmtScope.makePhysicalProps()
}
tup := b.factory.ConstructTuple(elems, rtyp)
stmtScope = bodyScope.push()
col := b.synthesizeColumn(stmtScope, scopeColName(""), rtyp, nil /* expr */, tup)
expr = b.constructProject(expr, []scopeColumn{*col})
physProps = stmtScope.makePhysicalProps()
}

// We must preserve the presentation of columns as physical
Expand All @@ -569,17 +584,18 @@ func (b *Builder) finishBuildLastStmt(
if len(cols) > 0 {
returnCol := physProps.Presentation[0].ID
returnColMeta := b.factory.Metadata().ColumnMeta(returnCol)
if !types.IsRecordType(rtyp) && !isMultiColDataSource && !returnColMeta.Type.Identical(rtyp) {
if !isRecordReturning && !isMultiColDataSource &&
!returnColMeta.Type.Identical(rtyp) {
if !cast.ValidCast(returnColMeta.Type, rtyp, cast.ContextAssignment) {
panic(sqlerrors.NewInvalidAssignmentCastError(
returnColMeta.Type, rtyp, returnColMeta.Alias))
}
cast := b.factory.ConstructAssignmentCast(
assignCast := b.factory.ConstructAssignmentCast(
b.factory.ConstructVariable(physProps.Presentation[0].ID),
rtyp,
)
stmtScope = bodyScope.push()
col := b.synthesizeColumn(stmtScope, scopeColName(""), rtyp, nil /* expr */, cast)
col := b.synthesizeColumn(stmtScope, scopeColName(""), rtyp, nil /* expr */, assignCast)
expr = b.constructProject(expr, []scopeColumn{*col})
physProps = stmtScope.makePhysicalProps()
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/testutils/testcat/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (tc *Catalog) CreateRoutine(c *tree.CreateRoutine) {
OutParamTypes: outParams,
DefaultExprs: defaultExprs,
}
overload.ReturnsRecordType = types.IsRecordType(retType)
overload.ReturnsRecordType = retType.Identical(types.AnyTuple)
if c.ReturnType != nil && c.ReturnType.SetOf {
overload.Class = tree.GeneratorClass
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)

func CreateFunction(b BuildCtx, n *tree.CreateRoutine) {
Expand Down Expand Up @@ -66,7 +67,7 @@ func CreateFunction(b BuildCtx, n *tree.CreateRoutine) {
if n.IsProcedure {
if n.ReturnType != nil {
returnType := b.ResolveTypeRef(n.ReturnType.Type)
if returnType.Type.Family() != types.VoidFamily && !types.IsRecordType(returnType.Type) {
if returnType.Type.Family() != types.VoidFamily && returnType.Type.Oid() != oid.T_record {
panic(errors.AssertionFailedf(
"CreateRoutine.ReturnType is expected to be empty, VOID, or RECORD for procedures",
))
Expand All @@ -79,7 +80,7 @@ func CreateFunction(b BuildCtx, n *tree.CreateRoutine) {
}
} else if n.ReturnType != nil {
typ = n.ReturnType.Type
if returnType := b.ResolveTypeRef(typ); types.IsRecordType(returnType.Type) {
if returnType := b.ResolveTypeRef(typ); returnType.Type.Oid() == oid.T_record {
// If the function returns a RECORD type, then we need to check
// whether its OUT parameters specify labels for the return type.
outParamTypes, outParamNames := getOutputParameters(b, n.Params)
Expand Down
7 changes: 0 additions & 7 deletions pkg/sql/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2836,13 +2836,6 @@ func IsWildcardTupleType(t *T) bool {
return len(t.TupleContents()) == 1 && t.TupleContents()[0].Family() == AnyFamily
}

// IsRecordType returns true if this is a RECORD type. This should only be used
// when processing UDFs. A record differs from AnyTuple in that the tuple
// contents may contain types other than Any.
func IsRecordType(typ *T) bool {
return typ.Family() == TupleFamily && typ.Oid() == oid.T_record
}

// collatedStringTypeSQL returns the string representation of a COLLATEDSTRING
// or []COLLATEDSTRING type. This is tricky in the case of an array of collated
// string, since brackets must precede the COLLATE identifier:
Expand Down

0 comments on commit 5b6e564

Please sign in to comment.