Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
120851: sql: support output parameters in procedures r=yuzefovich a=yuzefovich

See each commit for details.

Addresses: cockroachdb#100405.
Epic: CRDB-30611

Co-authored-by: Yahor Yuzefovich <[email protected]>
  • Loading branch information
craig[bot] and yuzefovich committed Mar 23, 2024
2 parents 1084d88 + 6406f4e commit 4a9385c
Show file tree
Hide file tree
Showing 32 changed files with 1,733 additions and 270 deletions.
534 changes: 532 additions & 2 deletions pkg/ccl/logictestccl/testdata/logic_test/procedure_params

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_params
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ SELECT * FROM f();
param
1

statement error pgcode 42809 f\(unknown\) is not a procedure
CALL f(NULL);

statement error pgcode 42809 f\(int\) is not a procedure
CALL f(NULL::INT);

statement error pgcode 42809 f\(int\) is not a procedure
CALL f(1);

statement ok
DROP FUNCTION f;

Expand Down
62 changes: 62 additions & 0 deletions pkg/ccl/logictestccl/testdata/logic_test/udf_rewrite
Original file line number Diff line number Diff line change
Expand Up @@ -488,4 +488,66 @@ SELECT get_body_str('p_rewrite');
----
"BEGIN\nUPDATE test.public.t_rewrite SET w = b'\\xa0':::@100107 WHERE w = b'\\x80':::@100107 RETURNING w;\nEND;\n"

statement ok
DROP PROCEDURE p_rewrite();

statement ok
CREATE PROCEDURE p_rewrite(INOUT param1 weekday, OUT param2 weekday) AS
$$
BEGIN
param2 = param1;
param1 = 'friday'::weekday;
END
$$ LANGUAGE PLPGSQL;

query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nparam2 := param1;\nparam1 := b'\\xc0':::@100107;\nEND;\n"

query TT
SHOW CREATE PROCEDURE p_rewrite;
----
p_rewrite CREATE PROCEDURE public.p_rewrite(INOUT param1 test.public.weekday, OUT param2 test.public.weekday)
LANGUAGE plpgsql
AS $$
BEGIN
param2 := param1;
param1 := 'friday':::test.public.weekday;
END;
$$

statement ok
ALTER TYPE weekday RENAME VALUE 'friday' TO 'humpday';

statement ok
ALTER TYPE weekday RENAME TO workday;

query T
SELECT get_body_str('p_rewrite');
----
"BEGIN\nparam2 := param1;\nparam1 := b'\\xc0':::@100107;\nEND;\n"

query TT
SHOW CREATE PROCEDURE p_rewrite;
----
p_rewrite CREATE PROCEDURE public.p_rewrite(INOUT param1 test.public.workday, OUT param2 test.public.workday)
LANGUAGE plpgsql
AS $$
BEGIN
param2 := param1;
param1 := 'humpday':::test.public.workday;
END;
$$

# Reset types for subtest.
statement ok
ALTER TYPE workday RENAME TO weekday;

statement ok
ALTER TYPE weekday RENAME VALUE 'humpday' TO 'friday';

statement ok
DROP PROCEDURE p_rewrite;

subtest end
9 changes: 7 additions & 2 deletions pkg/sql/alter_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,15 @@ func toSchemaOverloadSignature(fnDesc *funcdesc.Mutable) descpb.SchemaDescriptor
ReturnSet: fnDesc.ReturnType.ReturnSet,
IsProcedure: fnDesc.IsProcedure(),
}
for _, param := range fnDesc.Params {
if tree.IsParamIncludedIntoSignature(funcdesc.ToTreeRoutineParamClass(param.Class), ret.IsProcedure) {
for paramIdx, param := range fnDesc.Params {
class := funcdesc.ToTreeRoutineParamClass(param.Class)
if tree.IsInParamClass(class) {
ret.ArgTypes = append(ret.ArgTypes, param.Type)
}
if class == tree.RoutineParamOut {
ret.OutParamOrdinals = append(ret.OutParamOrdinals, int32(paramIdx))
ret.OutParamTypes = append(ret.OutParamTypes, param.Type)
}
}
return ret
}
14 changes: 12 additions & 2 deletions pkg/sql/catalog/descpb/structured.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1558,15 +1558,25 @@ message SchemaDescriptor {
optional uint32 id = 1 [(gogoproto.nullable) = false,
(gogoproto.customname) = "ID", (gogoproto.casttype) = "ID"];

// ArgTypes contains only IN / INOUT parameters when IsProcedure is false,
// and all parameters when IsProcedure is true.
// ArgTypes contains only input parameters.
repeated sql.sem.types.T arg_types = 2;

optional sql.sem.types.T return_type = 3;

optional bool return_set = 4 [(gogoproto.nullable) = false];

optional bool is_procedure = 5 [(gogoproto.nullable) = false];

// OutParamOrdinals contains all ordinals of OUT parameters among all
// parameters in the function definition. For example, if the function /
// procedure has a definition like
// (IN p0, INOUT p1, OUT p2, INOUT p3, OUT p4),
// then OutParamOrdinals will contain [2, 4] (while ArgTypes will contain
// types of [p0, p1, p3]).
repeated int32 out_param_ordinals = 6;
// OutParamTypes contains types of all OUT parameters (it has 1-to-1 match
// with OutParamOrdinals).
repeated sql.sem.types.T out_param_types = 7;
}

// Function contains a group of UDFs with the same name.
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ func (desc *immutable) ToOverload() (ret *tree.Overload, err error) {
signatureTypes := make(tree.ParamTypes, 0, len(desc.Params))
for _, param := range desc.Params {
class := ToTreeRoutineParamClass(param.Class)
if tree.IsParamIncludedIntoSignature(class, desc.IsProcedure()) {
if tree.IsInParamClass(class) {
signatureTypes = append(signatureTypes, tree.ParamType{Name: param.Name, Typ: param.Type})
}
ret.RoutineParams = append(ret.RoutineParams, tree.RoutineParam{
Expand Down
5 changes: 5 additions & 0 deletions pkg/sql/catalog/rewrite/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,11 @@ func SchemaDescs(schemas []*schemadesc.Mutable, descriptorRewrites jobspb.DescRe
if err := rewriteIDsInTypesT(sig.ReturnType, descriptorRewrites); err != nil {
return err
}
for _, typ := range sig.OutParamTypes {
if err := rewriteIDsInTypesT(typ, descriptorRewrites); err != nil {
return err
}
}
newSigs = append(newSigs, *sig)
}
if len(newSigs) > 0 {
Expand Down
47 changes: 47 additions & 0 deletions pkg/sql/catalog/schemadesc/schema_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ func (desc *immutable) ForEachUDTDependentForHydration(fn func(t *types.T) error
return iterutil.Map(err)
}
}
for _, typ := range sig.OutParamTypes {
if !catid.IsOIDUserDefined(typ.Oid()) {
continue
}
if err := fn(typ); err != nil {
return iterutil.Map(err)
}
}
if !catid.IsOIDUserDefined(sig.ReturnType.Oid()) {
continue
}
Expand Down Expand Up @@ -484,6 +492,37 @@ func (desc *Mutable) RemoveFunction(name string, id descpb.ID) {
}
}

// ReplaceOverload updates the function signature that matches the existing
// overload with the new one. An error is returned if the function doesn't exist
// or a match is not found.
func (desc *Mutable) ReplaceOverload(
name string,
existing *tree.QualifiedOverload,
newSignature descpb.SchemaDescriptor_FunctionSignature,
) error {
fn, ok := desc.Functions[name]
if !ok {
return errors.AssertionFailedf("unexpectedly didn't find a function %s", name)
}
for i := range fn.Signatures {
sig := fn.Signatures[i]
match := existing.Types.Length() == len(sig.ArgTypes) &&
len(existing.OutParamOrdinals) == len(sig.OutParamOrdinals)
for j := 0; match && j < len(sig.ArgTypes); j++ {
match = existing.Types.GetAt(j).Equivalent(sig.ArgTypes[j])
}
for j := 0; match && j < len(sig.OutParamOrdinals); j++ {
match = existing.OutParamOrdinals[j] == sig.OutParamOrdinals[j] &&
existing.OutParamTypes.GetAt(j).Equivalent(sig.OutParamTypes[j])
}
if match {
fn.Signatures[i] = newSignature
return nil
}
}
return errors.AssertionFailedf("unexpectedly didn't find overload match for function %s with types %v", name, existing.Types.Types())
}

// GetObjectType implements the Object interface.
func (desc *immutable) GetObjectType() privilege.ObjectType {
return privilege.Schema
Expand Down Expand Up @@ -522,6 +561,7 @@ func (desc *immutable) GetResolvedFuncDefinition(
},
Type: routineType,
UDFContainsOnlySignature: true,
OutParamOrdinals: sig.OutParamOrdinals,
}
if funcDescPb.Signatures[i].ReturnSet {
overload.Class = tree.GeneratorClass
Expand All @@ -537,6 +577,13 @@ func (desc *immutable) GetResolvedFuncDefinition(
)
}
overload.Types = paramTypes
if len(sig.OutParamTypes) > 0 {
outParamTypes := make(tree.ParamTypes, len(sig.OutParamTypes))
for j := range outParamTypes {
outParamTypes[j] = tree.ParamType{Typ: sig.OutParamTypes[j]}
}
overload.OutParamTypes = outParamTypes
}
prefixedOverload := tree.MakeQualifiedOverload(desc.GetName(), overload)
funcDef.Overloads = append(funcDef.Overloads, prefixedOverload)
}
Expand Down
Loading

0 comments on commit 4a9385c

Please sign in to comment.