diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index aea82fc6ff6c..955815abf240 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -546,6 +546,7 @@ go_test( "explain_bundle_test.go", "explain_test.go", "explain_tree_test.go", + "function_resolver_test.go", "grant_revoke_test.go", "grant_role_test.go", "index_mutation_test.go", diff --git a/pkg/sql/catalog/descriptor.go b/pkg/sql/catalog/descriptor.go index 93228971d72a..4bad937abfe7 100644 --- a/pkg/sql/catalog/descriptor.go +++ b/pkg/sql/catalog/descriptor.go @@ -871,6 +871,10 @@ type FunctionDescriptor interface { // FuncDesc returns the function's underlying protobuf descriptor. FuncDesc() *descpb.FunctionDescriptor + + // ToOverload converts the function descriptor to tree.Overload object which + // can be used for execution. + ToOverload() (ret *tree.Overload, err error) } // FilterDescriptorState inspects the state of a given descriptor and returns an diff --git a/pkg/sql/catalog/funcdesc/BUILD.bazel b/pkg/sql/catalog/funcdesc/BUILD.bazel index a6e015f59b8f..510a749e0268 100644 --- a/pkg/sql/catalog/funcdesc/BUILD.bazel +++ b/pkg/sql/catalog/funcdesc/BUILD.bazel @@ -19,12 +19,15 @@ go_library( "//pkg/sql/pgwire/pgerror", "//pkg/sql/privilege", "//pkg/sql/schemachanger/scpb", + "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", + "//pkg/sql/sem/volatility", "//pkg/sql/types", "//pkg/util/hlc", "//pkg/util/protoutil", "@com_github_cockroachdb_errors//:errors", + "@com_github_lib_pq//oid", ], ) @@ -45,8 +48,12 @@ go_test( "//pkg/sql/catalog/tabledesc", "//pkg/sql/catalog/typedesc", "//pkg/sql/privilege", + "//pkg/sql/sem/tree", + "//pkg/sql/sem/volatility", "//pkg/sql/types", "//pkg/util/leaktest", + "@com_github_lib_pq//oid", + "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/catalog/funcdesc/func_desc.go b/pkg/sql/catalog/funcdesc/func_desc.go index 6b9db57709d1..cf06b8aeeb76 100644 --- a/pkg/sql/catalog/funcdesc/func_desc.go +++ b/pkg/sql/catalog/funcdesc/func_desc.go @@ -19,10 +19,14 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/errors" + "github.com/lib/pq/oid" ) var _ catalog.Descriptor = (*immutable)(nil) @@ -462,3 +466,76 @@ func (desc *immutable) ContainsUserDefinedTypes() bool { } return desc.ReturnType.Type.UserDefined() } + +func (desc *immutable) ToOverload() (ret *tree.Overload, err error) { + ret = &tree.Overload{ + Oid: catid.FuncIDToOID(desc.ID), + ReturnType: tree.FixedReturnType(desc.ReturnType.Type), + ReturnSet: desc.ReturnType.ReturnSet, + Body: desc.FunctionBody, + IsUDF: true, + } + + argTypes := make(tree.ArgTypes, 0, len(desc.Args)) + for _, arg := range desc.Args { + argTypes = append( + argTypes, + tree.ArgType{Name: arg.Name, Typ: arg.Type}, + ) + } + ret.Types = argTypes + ret.Volatility, err = desc.getOverloadVolatility() + if err != nil { + return nil, err + } + ret.NullableArgs, err = desc.getOverloadNullableArgs() + if err != nil { + return nil, err + } + + return ret, nil +} + +func (desc *immutable) getOverloadVolatility() (volatility.V, error) { + var ret volatility.V + switch desc.Volatility { + case catpb.Function_VOLATILE: + ret = volatility.Volatile + case catpb.Function_STABLE: + ret = volatility.Stable + case catpb.Function_IMMUTABLE: + ret = volatility.Immutable + default: + return 0, errors.Newf("unknown volatility") + } + if desc.LeakProof { + if desc.Volatility != catpb.Function_IMMUTABLE { + return 0, errors.Newf("function %d is leakproof but not immutable", desc.ID) + } + ret = volatility.Leakproof + } + return ret, nil +} + +func (desc *immutable) getOverloadNullableArgs() (bool, error) { + switch desc.NullInputBehavior { + case catpb.Function_CALLED_ON_NULL_INPUT: + return true, nil + case catpb.Function_RETURNS_NULL_ON_NULL_INPUT, catpb.Function_STRICT: + return false, nil + default: + return false, errors.Newf("unknown null input behavior") + } +} + +// UserDefinedFunctionOIDToID converts a UDF OID into a descriptor ID. OID of a +// UDF must be greater CockroachPredefinedOIDMax. The function returns an error +// if the given OID is less than or equal to CockroachPredefinedOIDMax. +func UserDefinedFunctionOIDToID(oid oid.Oid) (descpb.ID, error) { + return catid.UserDefinedOIDToID(oid) +} + +// IsOIDUserDefinedFunc returns true if an oid is a user-defined function oid. +func IsOIDUserDefinedFunc(oid oid.Oid) bool { + return catid.IsOIDUserDefined(oid) +} diff --git a/pkg/sql/catalog/funcdesc/func_desc_test.go b/pkg/sql/catalog/funcdesc/func_desc_test.go index ed7bf7c754dc..5d537c80554e 100644 --- a/pkg/sql/catalog/funcdesc/func_desc_test.go +++ b/pkg/sql/catalog/funcdesc/func_desc_test.go @@ -13,6 +13,7 @@ package funcdesc_test import ( "context" "fmt" + "strconv" "testing" "github.com/cockroachdb/cockroach/pkg/clusterversion" @@ -28,8 +29,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/lib/pq/oid" + "github.com/stretchr/testify/require" ) func TestValidateFuncDesc(t *testing.T) { @@ -462,3 +467,150 @@ func TestValidateFuncDesc(t *testing.T) { } } } + +func TestToOverload(t *testing.T) { + testCases := []struct { + desc descpb.FunctionDescriptor + expected tree.Overload + err string + }{ + { + // Test all fields are properly valued. + desc: descpb.FunctionDescriptor{ + ID: 1, + Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}}, + ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true}, + LeakProof: true, + Volatility: catpb.Function_IMMUTABLE, + NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT, + FunctionBody: "ANY QUERIES", + }, + expected: tree.Overload{ + Oid: oid.Oid(100001), + Types: tree.ArgTypes{ + {Name: "arg1", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + ReturnSet: true, + Volatility: volatility.Leakproof, + Body: "ANY QUERIES", + IsUDF: true, + }, + }, + { + // Test ReturnSet matters. + desc: descpb.FunctionDescriptor{ + ID: 1, + Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}}, + ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: false}, + LeakProof: true, + Volatility: catpb.Function_IMMUTABLE, + NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT, + FunctionBody: "ANY QUERIES", + }, + expected: tree.Overload{ + Oid: oid.Oid(100001), + Types: tree.ArgTypes{ + {Name: "arg1", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + ReturnSet: false, + Volatility: volatility.Leakproof, + Body: "ANY QUERIES", + IsUDF: true, + }, + }, + { + // Test Volatility matters. + desc: descpb.FunctionDescriptor{ + ID: 1, + Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}}, + ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true}, + LeakProof: false, + Volatility: catpb.Function_STABLE, + NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT, + FunctionBody: "ANY QUERIES", + }, + expected: tree.Overload{ + Oid: oid.Oid(100001), + Types: tree.ArgTypes{ + {Name: "arg1", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + ReturnSet: true, + Volatility: volatility.Stable, + Body: "ANY QUERIES", + IsUDF: true, + }, + }, + { + // Test NullableArgs matters. + desc: descpb.FunctionDescriptor{ + ID: 1, + Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}}, + ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true}, + LeakProof: true, + Volatility: catpb.Function_IMMUTABLE, + NullInputBehavior: catpb.Function_CALLED_ON_NULL_INPUT, + FunctionBody: "ANY QUERIES", + }, + expected: tree.Overload{ + Oid: oid.Oid(100001), + Types: tree.ArgTypes{ + {Name: "arg1", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + ReturnSet: true, + Volatility: volatility.Leakproof, + Body: "ANY QUERIES", + IsUDF: true, + NullableArgs: true, + }, + }, + { + // Test failure on non-immutable but leakproof function. + desc: descpb.FunctionDescriptor{ + ID: 1, + Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}}, + ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true}, + LeakProof: true, + Volatility: catpb.Function_STABLE, + NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT, + FunctionBody: "ANY QUERIES", + }, + expected: tree.Overload{ + Oid: oid.Oid(100001), + Types: tree.ArgTypes{ + {Name: "arg1", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + ReturnSet: true, + Volatility: volatility.Leakproof, + Body: "ANY QUERIES", + IsUDF: true, + }, + err: "function 1 is leakproof but not immutable", + }, + } + + for i, tc := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + desc := funcdesc.NewBuilder(&tc.desc).BuildImmutable().(catalog.FunctionDescriptor) + overload, err := desc.ToOverload() + if tc.err == "" { + require.NoError(t, err) + } else { + require.Equal(t, tc.err, err.Error()) + return + } + + returnType := overload.ReturnType([]tree.TypedExpr{}) + expectedReturnType := tc.expected.ReturnType([]tree.TypedExpr{}) + require.Equal(t, expectedReturnType, returnType) + // Set ReturnType(which is function) to nil for easier equality check. + overload.ReturnType = nil + tc.expected.ReturnType = nil + require.Equal(t, tc.expected, *overload) + }) + } +} diff --git a/pkg/sql/catalog/schema.go b/pkg/sql/catalog/schema.go index 507d96ec83e1..99edb117ba14 100644 --- a/pkg/sql/catalog/schema.go +++ b/pkg/sql/catalog/schema.go @@ -10,7 +10,10 @@ package catalog -import "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" +import ( + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" +) // SchemaDescriptor encapsulates the basic type SchemaDescriptor interface { @@ -29,6 +32,15 @@ type SchemaDescriptor interface { // GetFunction returns a list of function overloads given a name. GetFunction(name string) (descpb.SchemaDescriptor_Function, bool) + + // GetResolvedFuncDefinition returns a ResolvedFunctionDefinition given a + // function name. This is needed by function resolution and expression type + // checking during which candidate function overloads are searched for the + // best match. Only function signatures are needed during this process. Schema + // stores all the signatures of the functions created under it and this method + // returns a collection of overloads with the same function name, each + // overload is prefixed with the same schema name. + GetResolvedFuncDefinition(name string) (*tree.ResolvedFunctionDefinition, bool) } // ResolvedSchemaKind is an enum that represents what kind of schema diff --git a/pkg/sql/catalog/schemadesc/BUILD.bazel b/pkg/sql/catalog/schemadesc/BUILD.bazel index 4de125f976fd..2711f63b73a8 100644 --- a/pkg/sql/catalog/schemadesc/BUILD.bazel +++ b/pkg/sql/catalog/schemadesc/BUILD.bazel @@ -24,8 +24,10 @@ go_library( "//pkg/sql/privilege", "//pkg/sql/schemachanger/scpb", "//pkg/sql/sem/catconstants", + "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", + "//pkg/sql/types", "//pkg/util/hlc", "//pkg/util/log", "//pkg/util/protoutil", diff --git a/pkg/sql/catalog/schemadesc/schema_desc.go b/pkg/sql/catalog/schemadesc/schema_desc.go index de4946720e0b..ec94762017cd 100644 --- a/pkg/sql/catalog/schemadesc/schema_desc.go +++ b/pkg/sql/catalog/schemadesc/schema_desc.go @@ -25,7 +25,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" @@ -429,6 +432,43 @@ func (desc *immutable) ContainsUserDefinedTypes() bool { return false } +// GetResolvedFuncDefinition implements the SchemaDescriptor interface. +func (desc *immutable) GetResolvedFuncDefinition( + name string, +) (*tree.ResolvedFunctionDefinition, bool) { + funcDescPb, found := desc.GetFunction(name) + if !found { + return nil, false + } + funcDef := &tree.ResolvedFunctionDefinition{ + Name: name, + Overloads: make([]tree.QualifiedOverload, 0, len(funcDescPb.Overloads)), + } + for i := range funcDescPb.Overloads { + retType := funcDescPb.Overloads[i].ReturnType + overload := &tree.Overload{ + Oid: catid.FuncIDToOID(funcDescPb.Overloads[i].ID), + ReturnType: func(args []tree.TypedExpr) *types.T { + return retType + }, + IsUDF: true, + UDFContainsOnlySignature: true, + } + argTypes := make(tree.ArgTypes, 0, len(funcDescPb.Overloads[i].ArgTypes)) + for _, argType := range funcDescPb.Overloads[i].ArgTypes { + argTypes = append( + argTypes, + tree.ArgType{Typ: argType}, + ) + } + overload.Types = argTypes + prefixedOverload := tree.MakeQualifiedOverload(desc.GetName(), overload) + funcDef.Overloads = append(funcDef.Overloads, prefixedOverload) + } + + return funcDef, true +} + // IsSchemaNameValid returns whether the input name is valid for a user defined // schema. func IsSchemaNameValid(name string) error { diff --git a/pkg/sql/catalog/schemadesc/synthetic_schema_desc.go b/pkg/sql/catalog/schemadesc/synthetic_schema_desc.go index 402a3311a9cf..b889523014d6 100644 --- a/pkg/sql/catalog/schemadesc/synthetic_schema_desc.go +++ b/pkg/sql/catalog/schemadesc/synthetic_schema_desc.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/log" ) @@ -146,6 +147,7 @@ func (p synthetic) GetDefaultPrivilegeDescriptor() catalog.DefaultPrivilegeDescr return catprivilege.MakeDefaultPrivileges(catprivilege.MakeDefaultPrivilegeDescriptor(catpb.DefaultPrivilegeDescriptor_SCHEMA)) } +// GetFunction implements the SchemaDescriptor interface. func (p synthetic) GetFunction(name string) (descpb.SchemaDescriptor_Function, bool) { return descpb.SchemaDescriptor_Function{}, false } @@ -153,3 +155,8 @@ func (p synthetic) GetFunction(name string) (descpb.SchemaDescriptor_Function, b func (p synthetic) ContainsUserDefinedTypes() bool { return false } + +// GetResolvedFuncDefinition implements the SchemaDescriptor interface. +func (p synthetic) GetResolvedFuncDefinition(name string) (*tree.ResolvedFunctionDefinition, bool) { + return nil, false +} diff --git a/pkg/sql/catalog/typedesc/BUILD.bazel b/pkg/sql/catalog/typedesc/BUILD.bazel index f04eb1a221e6..e37ab06b087c 100644 --- a/pkg/sql/catalog/typedesc/BUILD.bazel +++ b/pkg/sql/catalog/typedesc/BUILD.bazel @@ -19,7 +19,6 @@ go_library( "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/multiregion", "//pkg/sql/enum", - "//pkg/sql/oidext", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/privilege", diff --git a/pkg/sql/catalog/typedesc/type_desc.go b/pkg/sql/catalog/typedesc/type_desc.go index 90ec29c5bd33..6613d2971c7b 100644 --- a/pkg/sql/catalog/typedesc/type_desc.go +++ b/pkg/sql/catalog/typedesc/type_desc.go @@ -25,7 +25,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/multiregion" "github.com/cockroachdb/cockroach/pkg/sql/enum" - "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" @@ -120,11 +119,7 @@ func UpdateCachedFieldsOnModifiedMutable(desc catalog.TypeDescriptor) (*Mutable, // CockroachPredefinedOIDMax. The function returns an error if the // given OID is less than or equals to CockroachPredefinedMax. func UserDefinedTypeOIDToID(oid oid.Oid) (descpb.ID, error) { - if descpb.ID(oid) <= oidext.CockroachPredefinedOIDMax { - return 0, errors.Newf("user-defined OID %d should be greater "+ - "than predefined Max: %d.", oid, oidext.CockroachPredefinedOIDMax) - } - return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax, nil + return catid.UserDefinedOIDToID(oid) } // GetUserDefinedTypeDescID gets the type descriptor ID from a user defined type. diff --git a/pkg/sql/create_function.go b/pkg/sql/create_function.go index d9fefb929149..48a758f5d74d 100644 --- a/pkg/sql/create_function.go +++ b/pkg/sql/create_function.go @@ -47,8 +47,16 @@ func (n *createFunctionNode) startExec(params runParams) error { return unimplemented.NewWithIssue(85144, "CREATE FUNCTION...sql_body unimplemented") } - scDesc := n.scDesc.NewBuilder().BuildExistingMutable().(*schemadesc.Mutable) - udfMutableDesc, err := n.getMutableFuncDesc(params) + scDesc, err := params.p.descCollection.GetMutableSchemaByName( + params.ctx, params.p.Txn(), n.dbDesc, n.scDesc.GetName(), + tree.SchemaLookupFlags{Required: true, RequireMutable: true}, + ) + if err != nil { + return err + } + mutScDesc := scDesc.(*schemadesc.Mutable) + + udfMutableDesc, err := n.getMutableFuncDesc(mutScDesc, params) if err != nil { return err } @@ -194,7 +202,7 @@ func (n *createFunctionNode) startExec(params runParams) error { for i, arg := range udfMutableDesc.Args { argTypes[i] = arg.Type } - scDesc.AddFunction( + mutScDesc.AddFunction( udfMutableDesc.GetName(), descpb.SchemaDescriptor_FunctionOverload{ ID: udfMutableDesc.GetID(), @@ -203,7 +211,7 @@ func (n *createFunctionNode) startExec(params runParams) error { ReturnSet: udfMutableDesc.ReturnType.ReturnSet, }, ) - if err := params.p.writeSchemaDescChange(params.ctx, scDesc, "Create Function"); err != nil { + if err := params.p.writeSchemaDescChange(params.ctx, mutScDesc, "Create Function"); err != nil { return err } @@ -214,7 +222,9 @@ func (*createFunctionNode) Next(params runParams) (bool, error) { return false, func (*createFunctionNode) Values() tree.Datums { return tree.Datums{} } func (*createFunctionNode) Close(ctx context.Context) {} -func (n *createFunctionNode) getMutableFuncDesc(params runParams) (*funcdesc.Mutable, error) { +func (n *createFunctionNode) getMutableFuncDesc( + scDesc catalog.SchemaDescriptor, params runParams, +) (*funcdesc.Mutable, error) { if n.cf.Replace { return nil, unimplemented.New("CREATE OR REPLACE FUNCTION", "replacing function") } @@ -237,7 +247,7 @@ func (n *createFunctionNode) getMutableFuncDesc(params runParams) (*funcdesc.Mut privileges := catprivilege.CreatePrivilegesFromDefaultPrivileges( n.dbDesc.GetDefaultPrivilegeDescriptor(), - n.scDesc.GetDefaultPrivilegeDescriptor(), + scDesc.GetDefaultPrivilegeDescriptor(), n.dbDesc.GetID(), params.SessionData().User(), privilege.Functions, @@ -247,7 +257,7 @@ func (n *createFunctionNode) getMutableFuncDesc(params runParams) (*funcdesc.Mut newUdfDesc := funcdesc.NewMutableFunctionDescriptor( funcDescID, n.dbDesc.GetID(), - n.scDesc.GetID(), + scDesc.GetID(), string(n.cf.FuncName.ObjectName), len(n.cf.Args), returnType, diff --git a/pkg/sql/function_resolver_test.go b/pkg/sql/function_resolver_test.go new file mode 100644 index 000000000000..92e9db21ab00 --- /dev/null +++ b/pkg/sql/function_resolver_test.go @@ -0,0 +1,252 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package sql_test + +import ( + "context" + "sort" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" + "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/stretchr/testify/require" +) + +func TestSimpleResolveFunction(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + tDB := sqlutils.MakeSQLRunner(sqlDB) + + tDB.Exec(t, ` +CREATE TABLE t( + a INT PRIMARY KEY, + b INT, + C INT, + INDEX t_idx_b(b), + INDEX t_idx_c(c) +); +CREATE SEQUENCE sq1; +CREATE TABLE t2(a INT PRIMARY KEY); +CREATE VIEW v AS SELECT a FROM t2; +CREATE TYPE notmyworkday AS ENUM ('Monday', 'Tuesday'); +CREATE FUNCTION f(a notmyworkday) RETURNS INT IMMUTABLE LANGUAGE SQL AS $$ + SELECT a FROM t; + SELECT b FROM t@t_idx_b; + SELECT c FROM t@t_idx_c; + SELECT a FROM v; + SELECT nextval('sq1'); +$$; +CREATE FUNCTION f() RETURNS VOID IMMUTABLE LANGUAGE SQL AS $$ SELECT 1 $$;`) + + var sessionData sessiondatapb.SessionData + { + var sessionSerialized []byte + tDB.QueryRow(t, "SELECT crdb_internal.serialize_session()").Scan(&sessionSerialized) + require.NoError(t, protoutil.Unmarshal(sessionSerialized, &sessionData)) + } + + err := sql.TestingDescsTxn(ctx, s, func(ctx context.Context, txn *kv.Txn, col *descs.Collection) error { + execCfg := s.ExecutorConfig().(sql.ExecutorConfig) + planner, cleanup := sql.NewInternalPlanner( + "resolve-index", txn, username.RootUserName(), &sql.MemoryMetrics{}, &execCfg, sessionData, + ) + defer cleanup() + ec := planner.(interface{ EvalContext() *eval.Context }).EvalContext() + // Set "defaultdb" as current database. + ec.SessionData().Database = "defaultdb" + searchPathArray := ec.SessionData().SearchPath.GetPathArray() + + funcResolver := planner.(tree.FunctionReferenceResolver) + fname := tree.UnresolvedName{NumParts: 1, Star: false} + fname.Parts[0] = "f" + path := sessiondata.MakeSearchPath(searchPathArray) + funcDef, err := funcResolver.ResolveFunction(ctx, &fname, &path) + require.NoError(t, err) + require.Equal(t, 2, len(funcDef.Overloads)) + + // Verify Function Signature looks good + sort.Slice(funcDef.Overloads, func(i, j int) bool { + return funcDef.Overloads[i].Overload.Oid < funcDef.Overloads[j].Overload.Oid + }) + require.Equal(t, 100110, int(funcDef.Overloads[0].Oid)) + require.True(t, funcDef.Overloads[0].UDFContainsOnlySignature) + require.True(t, funcDef.Overloads[0].IsUDF) + require.Equal(t, 1, len(funcDef.Overloads[0].Types.Types())) + require.NotEqual(t, funcDef.Overloads[0].Types.Types()[0].TypeMeta, types.UserDefinedTypeMetadata{}) + require.Equal(t, types.EnumFamily, funcDef.Overloads[0].Types.Types()[0].Family()) + require.Equal(t, types.Int, funcDef.Overloads[0].ReturnType([]tree.TypedExpr{})) + + require.Equal(t, 100111, int(funcDef.Overloads[1].Oid)) + require.True(t, funcDef.Overloads[1].UDFContainsOnlySignature) + require.True(t, funcDef.Overloads[1].IsUDF) + require.Equal(t, 0, len(funcDef.Overloads[1].Types.Types())) + require.Equal(t, types.Void, funcDef.Overloads[1].ReturnType([]tree.TypedExpr{})) + + overload, err := funcResolver.ResolveFunctionByOID(ctx, funcDef.Overloads[0].Oid) + require.NoError(t, err) + require.Equal(t, `SELECT a FROM defaultdb.public.t; +SELECT b FROM defaultdb.public.t@t_idx_b; +SELECT c FROM defaultdb.public.t@t_idx_c; +SELECT a FROM defaultdb.public.v; +SELECT nextval(105:::REGCLASS);`, overload.Body) + require.True(t, overload.IsUDF) + require.False(t, overload.UDFContainsOnlySignature) + require.Equal(t, 1, len(overload.Types.Types())) + require.NotEqual(t, overload.Types.Types()[0].TypeMeta, types.UserDefinedTypeMetadata{}) + require.Equal(t, types.EnumFamily, overload.Types.Types()[0].Family()) + require.Equal(t, types.Int, overload.ReturnType([]tree.TypedExpr{})) + + overload, err = funcResolver.ResolveFunctionByOID(ctx, funcDef.Overloads[1].Oid) + require.NoError(t, err) + require.Equal(t, `SELECT 1;`, overload.Body) + require.True(t, overload.IsUDF) + require.False(t, overload.UDFContainsOnlySignature) + require.Equal(t, 0, len(overload.Types.Types())) + require.Equal(t, types.Void, overload.ReturnType([]tree.TypedExpr{})) + + return nil + }) + require.NoError(t, err) +} + +func TestResolveFunctionRespectSearchPath(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + tDB := sqlutils.MakeSQLRunner(sqlDB) + + tDB.Exec(t, ` +CREATE SCHEMA sc1; +CREATE SCHEMA sc2; +CREATE FUNCTION sc1.f() RETURNS INT IMMUTABLE LANGUAGE SQL AS $$ SELECT 1 $$; +CREATE FUNCTION sc2.f() RETURNS INT IMMUTABLE LANGUAGE SQL AS $$ SELECT 2 $$; +CREATE FUNCTION sc1.lower() RETURNS INT IMMUTABLE LANGUAGE SQL AS $$ SELECT 3 $$; +`, + ) + + testCases := []struct { + testName string + funName tree.UnresolvedName + searchPath []string + expectedBody []string + expectedSchema []string + expectedErr string + }{ + { + testName: "cross db should fail", + funName: tree.UnresolvedName{NumParts: 3, Parts: tree.NameParts{"some_f", "some_sc", "some_db", ""}}, + expectedErr: "cross-database function references not allowed", + }, + { + testName: "schema not found", + funName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"some_f", "some_sc", "", ""}}, + expectedErr: "schema \"some_sc\" does not exist", + }, + { + testName: "function not found", + funName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"some_f", "sc1", "", ""}}, + searchPath: []string{"sc1", "sc2"}, + expectedErr: "unknown function: sc1.some_f()", + }, + { + testName: "function not found", + funName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"some_f", "", "", ""}}, + searchPath: []string{"sc1", "sc2"}, + expectedErr: "unknown function: some_f()", + }, + { + testName: "function with explicit schema skip first schema in path", + funName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"f", "sc2", "", ""}}, + searchPath: []string{"sc1", "sc2"}, + expectedBody: []string{"SELECT 2;"}, + expectedSchema: []string{"sc2"}, + }, + { + testName: "use functions from search path", + funName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"f", "", "", ""}}, + searchPath: []string{"sc1", "sc2"}, + expectedBody: []string{"SELECT 1;", "SELECT 2;"}, + expectedSchema: []string{"sc1", "sc2"}, + }, + { + testName: "unsupported builtin function", + funName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"querytree", "", "", ""}}, + searchPath: []string{"sc1", "sc2"}, + expectedErr: "querytree(): unimplemented: this function is not yet supported", + }, + // TODO(Chengxiong): add test case for builtin function names when builtin + // OIDs are changed to fixed IDs. + } + + var sessionData sessiondatapb.SessionData + { + var sessionSerialized []byte + tDB.QueryRow(t, "SELECT crdb_internal.serialize_session()").Scan(&sessionSerialized) + require.NoError(t, protoutil.Unmarshal(sessionSerialized, &sessionData)) + } + + err := sql.TestingDescsTxn(ctx, s, func(ctx context.Context, txn *kv.Txn, col *descs.Collection) error { + execCfg := s.ExecutorConfig().(sql.ExecutorConfig) + planner, cleanup := sql.NewInternalPlanner( + "resolve-index", txn, username.RootUserName(), &sql.MemoryMetrics{}, &execCfg, sessionData, + ) + defer cleanup() + ec := planner.(interface{ EvalContext() *eval.Context }).EvalContext() + // Set "defaultdb" as current database. + ec.SessionData().Database = "defaultdb" + + funcResolver := planner.(tree.FunctionReferenceResolver) + + for _, tc := range testCases { + path := sessiondata.MakeSearchPath(tc.searchPath) + funcDef, err := funcResolver.ResolveFunction(ctx, &tc.funName, &path) + if tc.expectedErr != "" { + require.Equal(t, tc.expectedErr, err.Error()) + continue + } + require.NoError(t, err) + + require.Equal(t, len(tc.expectedBody), len(funcDef.Overloads)) + bodies := make([]string, len(funcDef.Overloads)) + schemas := make([]string, len(funcDef.Overloads)) + for i, o := range funcDef.Overloads { + overload, err := funcResolver.ResolveFunctionByOID(ctx, o.Oid) + require.NoError(t, err) + bodies[i] = overload.Body + schemas[i] = o.Schema + } + require.Equal(t, tc.expectedBody, bodies) + require.Equal(t, tc.expectedSchema, schemas) + } + return nil + }) + require.NoError(t, err) +} diff --git a/pkg/sql/opt/cat/catalog.go b/pkg/sql/opt/cat/catalog.go index e88228abd74e..4e571a9b2cf4 100644 --- a/pkg/sql/opt/cat/catalog.go +++ b/pkg/sql/opt/cat/catalog.go @@ -139,8 +139,11 @@ type Catalog interface { // ResolveFunction resolves a function by name. ResolveFunction( - name *tree.UnresolvedName, path tree.SearchPath, - ) (*tree.FunctionDefinition, error) + ctx context.Context, name *tree.UnresolvedName, path tree.SearchPath, + ) (*tree.ResolvedFunctionDefinition, error) + + // ResolveFunctionByOID resolves a function overload by OID. + ResolveFunctionByOID(ctx context.Context, oid oid.Oid) (*tree.Overload, error) // CheckPrivilege verifies that the current user has the given privilege on // the given catalog object. If not, then CheckPrivilege returns an error. diff --git a/pkg/sql/opt/testutils/testcat/function.go b/pkg/sql/opt/testutils/testcat/function.go index 6d0a5f2cce22..87d5de16ce9f 100644 --- a/pkg/sql/opt/testutils/testcat/function.go +++ b/pkg/sql/opt/testutils/testcat/function.go @@ -20,16 +20,27 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility" "github.com/cockroachdb/cockroach/pkg/util/treeprinter" + "github.com/cockroachdb/errors" + "github.com/lib/pq/oid" ) var _ tree.FunctionReferenceResolver = (*Catalog)(nil) // ResolveFunction part of the tree.FunctionReferenceResolver interface. func (tc *Catalog) ResolveFunction( - name *tree.UnresolvedName, path tree.SearchPath, -) (*tree.FunctionDefinition, error) { + ctx context.Context, name *tree.UnresolvedName, path tree.SearchPath, +) (*tree.ResolvedFunctionDefinition, error) { + fn, err := name.ToFunctionName() + if err != nil { + return nil, err + } + // Attempt to resolve to a built-in function first. - if def, err := name.ResolveFunction(path); err == nil { + def, err := tree.GetBuiltinFuncDefinition(fn, path) + if err != nil { + return nil, err + } + if def != nil { return def, nil } // Otherwise, try to resolve to a user-defined function. @@ -39,6 +50,11 @@ func (tc *Catalog) ResolveFunction( return nil, pgerror.Newf(pgcode.UndefinedFunction, "unknown function: %s", name) } +// ResolveFunctionByOID part of the tree.FunctionReferenceResolver interface. +func (tc *Catalog) ResolveFunctionByOID(ctx context.Context, oid oid.Oid) (*tree.Overload, error) { + return nil, errors.AssertionFailedf("ResolveFunctionByOID not supported in test catalog") +} + // CreateFunction handles the CREATE FUNCTION statement. func (tc *Catalog) CreateFunction(c *tree.CreateFunction) { name := c.FuncName.String() @@ -73,21 +89,24 @@ func (tc *Catalog) CreateFunction(c *tree.CreateFunction) { body, v, nullableArgs := collectFuncOptions(c.Options) if tc.udfs == nil { - tc.udfs = make(map[string]*tree.FunctionDefinition) - } - tc.udfs[name] = tree.NewFunctionDefinition( - name, - &tree.FunctionProperties{ - // TODO(mgartner): Consider setting Class and CompositeInsensitive. - }, - []tree.Overload{{ - Types: argTypes, - ReturnType: tree.FixedReturnType(retType), - Body: body, - Volatility: v, - NullableArgs: nullableArgs, - }}, - ) + tc.udfs = make(map[string]*tree.ResolvedFunctionDefinition) + } + + overload := &tree.Overload{ + Types: argTypes, + ReturnType: tree.FixedReturnType(retType), + Body: body, + Volatility: v, + NullableArgs: nullableArgs, + } + prefixedOverload := tree.MakeQualifiedOverload("public", overload) + def := &tree.ResolvedFunctionDefinition{ + Name: name, + // TODO(mgartner): Consider setting Class and CompositeInsensitive fo + // overloads. + Overloads: []tree.QualifiedOverload{prefixedOverload}, + } + tc.udfs[name] = def } func collectFuncOptions(o tree.FunctionOptions) (body string, v volatility.V, nullableArgs bool) { @@ -146,11 +165,11 @@ func collectFuncOptions(o tree.FunctionOptions) (body string, v volatility.V, nu // formatFunction nicely formats a function definition creating in the opt test // catalog using a treeprinter for debugging and testing. -func formatFunction(fn *tree.FunctionDefinition) string { - if len(fn.Definition) != 1 { +func formatFunction(fn *tree.ResolvedFunctionDefinition) string { + if len(fn.Overloads) != 1 { panic(fmt.Errorf("functions with multiple overloads not supported")) } - o := fn.Definition[0] + o := fn.Overloads[0] tp := treeprinter.New() nullStr := "" if !o.NullableArgs { diff --git a/pkg/sql/opt/testutils/testcat/test_catalog.go b/pkg/sql/opt/testutils/testcat/test_catalog.go index 01c6941c4677..285a4151a7d1 100644 --- a/pkg/sql/opt/testutils/testcat/test_catalog.go +++ b/pkg/sql/opt/testutils/testcat/test_catalog.go @@ -50,7 +50,7 @@ type Catalog struct { testSchema Schema counter int enumTypes map[string]*types.T - udfs map[string]*tree.FunctionDefinition + udfs map[string]*tree.ResolvedFunctionDefinition } type dataSource interface { @@ -487,7 +487,7 @@ func (tc *Catalog) ExecuteDDLWithIndexVersion( case *tree.ShowCreateFunction: fn := stmt.Name.FunctionReference.(*tree.UnresolvedName) - def, err := tc.ResolveFunction(fn, tree.EmptySearchPath) + def, err := tc.ResolveFunction(context.Background(), fn, tree.EmptySearchPath) if err != nil { return "", err } diff --git a/pkg/sql/opt_catalog.go b/pkg/sql/opt_catalog.go index 14997058d863..a7c187476d47 100644 --- a/pkg/sql/opt_catalog.go +++ b/pkg/sql/opt_catalog.go @@ -329,8 +329,14 @@ func (oc *optCatalog) ResolveType( // ResolveFunction is part of the cat.Catalog interface. func (oc *optCatalog) ResolveFunction( - name *tree.UnresolvedName, path tree.SearchPath, -) (*tree.FunctionDefinition, error) { + ctx context.Context, name *tree.UnresolvedName, path tree.SearchPath, +) (*tree.ResolvedFunctionDefinition, error) { + return nil, errors.AssertionFailedf("unimplemented") +} + +func (oc *optCatalog) ResolveFunctionByOID( + ctx context.Context, oid oid.Oid, +) (*tree.Overload, error) { return nil, errors.AssertionFailedf("unimplemented") } diff --git a/pkg/sql/schema_resolver.go b/pkg/sql/schema_resolver.go index b65e69ea2906..8df779db4766 100644 --- a/pkg/sql/schema_resolver.go +++ b/pkg/sql/schema_resolver.go @@ -13,21 +13,26 @@ package sql import ( "context" "fmt" + "strings" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/funcdesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/resolver" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scbuild" + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -341,6 +346,132 @@ func (sr *schemaResolver) runWithOptions(flags resolveFlags, fn func()) { fn() } +func (sr *schemaResolver) ResolveFunction( + ctx context.Context, name *tree.UnresolvedName, path tree.SearchPath, +) (*tree.ResolvedFunctionDefinition, error) { + if name.NumParts > 3 || len(name.Parts[0]) == 0 || name.Star { + return nil, pgerror.Newf(pgcode.InvalidName, "invalid function name: %s", name) + } + + fn, err := name.ToFunctionName() + if err != nil { + return nil, err + } + + if fn.ExplicitCatalog && fn.Catalog() != sr.CurrentDatabase() { + return nil, pgerror.New(pgcode.FeatureNotSupported, "cross-database function references not allowed") + } + + // Get builtin functions if there is any match. + builtinDef, err := tree.GetBuiltinFuncDefinition(fn, path) + if err != nil { + return nil, err + } + + var udfDef *tree.ResolvedFunctionDefinition + if fn.ExplicitSchema && fn.Schema() != catconstants.CRDBInternalSchemaName { + found, prefix, err := sr.LookupSchema(ctx, sr.CurrentDatabase(), fn.Schema()) + if err != nil { + return nil, err + } + + if !found { + return nil, pgerror.Newf(pgcode.UndefinedSchema, "schema %q does not exist", fn.Schema()) + } + + sc := prefix.Schema + udfDef, _ = sc.GetResolvedFuncDefinition(fn.Object()) + } else { + if err := path.IterateSearchPath(func(schema string) error { + found, prefix, err := sr.LookupSchema(ctx, sr.CurrentDatabase(), schema) + if err != nil { + return err + } + if !found { + return nil + } + curUdfDef, found := prefix.Schema.GetResolvedFuncDefinition(fn.Object()) + if !found { + return nil + } + + udfDef, err = udfDef.MergeWith(curUdfDef) + return err + }); err != nil { + return nil, err + } + } + + if builtinDef == nil && udfDef == nil { + // If nothing found, there is a chance that user typed in a quoted function + // name which is not lowercase. So here we try to lowercase the given + // function name and find a suggested function name if possible. + extraMsg := "" + lowerName := tree.MakeUnresolvedName(strings.ToLower(name.Parts[0])) + if lowerName != *name { + alternative, err := sr.ResolveFunction(ctx, &lowerName, path) + if err == nil && alternative != nil { + extraMsg = fmt.Sprintf(", but %s() exists", alternative.Name) + } + } + return nil, pgerror.Newf(pgcode.UndefinedFunction, "unknown function: %s()%s", tree.ErrString(name), extraMsg) + } + if builtinDef == nil { + return udfDef, nil + } + if udfDef == nil { + props, _ := builtinsregistry.GetBuiltinProperties(builtinDef.Name) + if props.UnsupportedWithIssue != 0 { + // Note: no need to embed the function name in the message; the + // caller will add the function name as prefix. + const msg = "this function is not yet supported" + var unImplErr error + if props.UnsupportedWithIssue < 0 { + unImplErr = unimplemented.New(builtinDef.Name+"()", msg) + } else { + unImplErr = unimplemented.NewWithIssueDetail(props.UnsupportedWithIssue, builtinDef.Name, msg) + } + return nil, pgerror.Wrapf(unImplErr, pgcode.InvalidParameterValue, "%s()", builtinDef.Name) + } + return builtinDef, nil + } + + return builtinDef.MergeWith(udfDef) +} + +func (sr *schemaResolver) ResolveFunctionByOID( + ctx context.Context, oid oid.Oid, +) (*tree.Overload, error) { + if !funcdesc.IsOIDUserDefinedFunc(oid) { + name, ok := tree.OidToBuiltinName[oid] + if !ok { + return nil, pgerror.Newf(pgcode.UndefinedFunction, "function %d not found", oid) + } + funcDef := tree.FunDefs[name] + for _, o := range funcDef.Definition { + if o.Oid == oid { + return o, nil + } + } + } + + flags := tree.ObjectLookupFlagsWithRequired() + flags.AvoidLeased = sr.skipDescriptorCache + descID, err := funcdesc.UserDefinedFunctionOIDToID(oid) + if err != nil { + return nil, err + } + funcDesc, err := sr.descCollection.GetImmutableFunctionByID(ctx, sr.txn, descID, flags) + if err != nil { + return nil, err + } + ret, err := funcDesc.ToOverload() + if err != nil { + return nil, err + } + return ret, nil +} + // NewSkippingCacheSchemaResolver constructs a schemaResolver which always skip // descriptor cache. func NewSkippingCacheSchemaResolver( diff --git a/pkg/sql/sem/builtins/all_builtins.go b/pkg/sql/sem/builtins/all_builtins.go index 816d44d02611..d44256d23f58 100644 --- a/pkg/sql/sem/builtins/all_builtins.go +++ b/pkg/sql/sem/builtins/all_builtins.go @@ -13,12 +13,15 @@ package builtins import ( "fmt" "sort" + "strings" "github.com/cockroachdb/cockroach/pkg/server/telemetry" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/errors" ) // AllBuiltinNames is an array containing all the built-in function @@ -52,8 +55,10 @@ func init() { initProbeRangesBuiltins() tree.FunDefs = make(map[string]*tree.FunctionDefinition) + tree.ResolvedBuiltinFuncDefs = make(map[string]*tree.ResolvedFunctionDefinition) builtinsregistry.Iterate(func(name string, props *tree.FunctionProperties, overloads []tree.Overload) { fDef := tree.NewFunctionDefinition(name, props, overloads) + addResolvedFuncDef(tree.ResolvedBuiltinFuncDefs, fDef) tree.FunDefs[name] = fDef if !fDef.ShouldDocument() { // Avoid listing help for undocumented functions. @@ -72,6 +77,28 @@ func init() { sort.Strings(AllWindowBuiltinNames) } +func addResolvedFuncDef( + resolved map[string]*tree.ResolvedFunctionDefinition, def *tree.FunctionDefinition, +) { + parts := strings.Split(def.Name, ".") + if len(parts) > 2 || len(parts) == 0 { + // This shouldn't happen in theory. + panic(errors.AssertionFailedf("invalid builtin function name: %s", def.Name)) + } + + if len(parts) == 2 { + resolved[def.Name] = tree.QualifyBuiltinFunctionDefinition(def, parts[0]) + return + } + + resolvedName := catconstants.PgCatalogName + "." + def.Name + resolved[resolvedName] = tree.QualifyBuiltinFunctionDefinition(def, catconstants.PgCatalogName) + if def.AvailableOnPublicSchema { + resolvedName = catconstants.PublicSchemaName + "." + def.Name + resolved[resolvedName] = tree.QualifyBuiltinFunctionDefinition(def, catconstants.PublicSchemaName) + } +} + func registerBuiltin(name string, def builtinDefinition) { for i, overload := range def.overloads { fnCount := 0 diff --git a/pkg/sql/sem/builtins/all_builtins_test.go b/pkg/sql/sem/builtins/all_builtins_test.go index 2abbea9ffe5a..b1610c57585b 100644 --- a/pkg/sql/sem/builtins/all_builtins_test.go +++ b/pkg/sql/sem/builtins/all_builtins_test.go @@ -14,6 +14,7 @@ import ( "encoding/csv" "io" "os" + "strconv" "strings" "testing" @@ -163,3 +164,80 @@ func TestOverloadsVolatilityMatchesPostgres(t *testing.T) { } }) } + +func TestAddResolvedFuncDef(t *testing.T) { + defer leaktest.AfterTest(t)() + + testCases := []struct { + def *tree.FunctionDefinition + resolved map[string]*tree.ResolvedFunctionDefinition + }{ + { + def: &tree.FunctionDefinition{Name: "crdb_internal.fun", Definition: []*tree.Overload{{}, {}}}, + resolved: map[string]*tree.ResolvedFunctionDefinition{ + "crdb_internal.fun": { + Name: "crdb_internal.fun", + Overloads: []tree.QualifiedOverload{ + { + Schema: "crdb_internal", + Overload: &tree.Overload{}, + }, + { + Schema: "crdb_internal", + Overload: &tree.Overload{}, + }, + }, + }, + }, + }, + { + def: &tree.FunctionDefinition{Name: "fun", Definition: []*tree.Overload{{}}}, + resolved: map[string]*tree.ResolvedFunctionDefinition{ + "pg_catalog.fun": { + Name: "fun", + Overloads: []tree.QualifiedOverload{ + { + Schema: "pg_catalog", + Overload: &tree.Overload{}, + }, + }, + }, + }, + }, + { + def: &tree.FunctionDefinition{ + Name: "fun", + Definition: []*tree.Overload{{}}, + FunctionProperties: tree.FunctionProperties{AvailableOnPublicSchema: true}, + }, + resolved: map[string]*tree.ResolvedFunctionDefinition{ + "pg_catalog.fun": { + Name: "fun", + Overloads: []tree.QualifiedOverload{ + { + Schema: "pg_catalog", + Overload: &tree.Overload{}, + }, + }, + }, + "public.fun": { + Name: "fun", + Overloads: []tree.QualifiedOverload{ + { + Schema: "public", + Overload: &tree.Overload{}, + }, + }, + }, + }, + }, + } + + for i, tc := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + resolved := make(map[string]*tree.ResolvedFunctionDefinition) + addResolvedFuncDef(resolved, tc.def) + require.Equal(t, tc.resolved, resolved) + }) + } +} diff --git a/pkg/sql/sem/catid/BUILD.bazel b/pkg/sql/sem/catid/BUILD.bazel index a46d15068968..e9bf1880b56f 100644 --- a/pkg/sql/sem/catid/BUILD.bazel +++ b/pkg/sql/sem/catid/BUILD.bazel @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/sql/oidext", "//pkg/util", + "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", ], ) diff --git a/pkg/sql/sem/catid/ids.go b/pkg/sql/sem/catid/ids.go index 6e749306d4ea..2980fa754741 100644 --- a/pkg/sql/sem/catid/ids.go +++ b/pkg/sql/sem/catid/ids.go @@ -13,6 +13,7 @@ package catid import ( "github.com/cockroachdb/cockroach/pkg/sql/oidext" + "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -27,9 +28,34 @@ func (DescID) SafeValue() {} // TypeIDToOID converts a type descriptor ID into a type OID. func TypeIDToOID(id DescID) oid.Oid { + return idToUserDefinedOID(id) +} + +// FuncIDToOID converts a function descriptor ID into a function OID. +func FuncIDToOID(id DescID) oid.Oid { + return idToUserDefinedOID(id) +} + +func idToUserDefinedOID(id DescID) oid.Oid { return oid.Oid(id) + oidext.CockroachPredefinedOIDMax } +// UserDefinedOIDToID converts an oid to a descriptor id. Error is returned if +// the given oid is not user defined. +func UserDefinedOIDToID(oid oid.Oid) (DescID, error) { + if !IsOIDUserDefined(oid) { + return 0, errors.Newf("user-defined OID %d should be greater "+ + "than predefined Max: %d.", oid, oidext.CockroachPredefinedOIDMax) + } + return DescID(oid) - oidext.CockroachPredefinedOIDMax, nil +} + +// IsOIDUserDefined returns true if oid is greater than +// CockroachPredefinedOIDMax, otherwise false. +func IsOIDUserDefined(oid oid.Oid) bool { + return DescID(oid) > oidext.CockroachPredefinedOIDMax +} + // ColumnID is a custom type for Column IDs. type ColumnID uint32 diff --git a/pkg/sql/sem/tree/BUILD.bazel b/pkg/sql/sem/tree/BUILD.bazel index 4130da09c66a..dc68a0d81db2 100644 --- a/pkg/sql/sem/tree/BUILD.bazel +++ b/pkg/sql/sem/tree/BUILD.bazel @@ -177,6 +177,7 @@ go_test( "datum_test.go", "expr_test.go", "format_test.go", + "function_definition_test.go", "function_name_test.go", "indexed_vars_test.go", "interval_test.go", diff --git a/pkg/sql/sem/tree/function_definition.go b/pkg/sql/sem/tree/function_definition.go index 420e3dc6a6fb..ea45674af91d 100644 --- a/pkg/sql/sem/tree/function_definition.go +++ b/pkg/sql/sem/tree/function_definition.go @@ -13,6 +13,9 @@ package tree import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants" + "github.com/cockroachdb/cockroach/pkg/util/iterutil" + "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -20,7 +23,8 @@ import ( // overloads for a built-in function. // TODO(Chengxiong): Remove this struct entirely. Instead, use overloads from // function resolution or use "GetBuiltinProperties" if the need is to only look -// at builtin functions(there are such existing use cases). +// at builtin functions(there are such existing use cases). Also change "Name" +// of ResolvedFunctionDefinition to Name type. type FunctionDefinition struct { // Name is the short name of the function. Name string @@ -32,6 +36,28 @@ type FunctionDefinition struct { FunctionProperties } +// ResolvedFunctionDefinition is similar to FunctionDefinition but with all the +// overloads qualified with schema name. +type ResolvedFunctionDefinition struct { + // Name is the name of the function and not the name of the schema. And, it's + // not qualified. + Name string + + Overloads []QualifiedOverload +} + +// QualifiedOverload is a wrapper of Overload prefixed with a schema name. +// It indicates that the overload is defined with the specified schema. +type QualifiedOverload struct { + Schema string + *Overload +} + +// MakeQualifiedOverload creates a new QualifiedOverload. +func MakeQualifiedOverload(schema string, overload *Overload) QualifiedOverload { + return QualifiedOverload{Schema: schema, Overload: overload} +} + // FunctionProperties defines the properties of the built-in // functions that are common across all overloads. type FunctionProperties struct { @@ -166,6 +192,10 @@ func NewFunctionDefinition( // function definition resolution to interfaces defined in the SemaContext. var FunDefs map[string]*FunctionDefinition +// ResolvedBuiltinFuncDefs holds pre-allocated ResolvedFunctionDefinition +// instances. Keys of the map is schema qualified function names. +var ResolvedBuiltinFuncDefs map[string]*ResolvedFunctionDefinition + // OidToBuiltinName contains a map from the hashed OID of all builtin functions // to their name. We populate this from the pg_catalog.go file in the sql // package because of dependency issues: we can't use oidHasher from this file. @@ -179,20 +209,54 @@ func (fd *FunctionDefinition) Format(ctx *FmtCtx) { // String implements the Stringer interface. func (fd *FunctionDefinition) String() string { return AsString(fd) } -// TODO(Chengxiong): Remove this method after we moved the -// "UnsupportedWithIssue" check into function resolver implementation. -func (fd *FunctionDefinition) undefined() bool { - return fd.UnsupportedWithIssue != 0 +// Format implements the NodeFormatter interface. +func (fd *ResolvedFunctionDefinition) Format(ctx *FmtCtx) { + ctx.WriteString(fd.Name) +} + +// String implements the Stringer interface. +func (fd *ResolvedFunctionDefinition) String() string { return AsString(fd) } + +// MergeWith is used to merge two UDF definitions with same name. +func (fd *ResolvedFunctionDefinition) MergeWith( + another *ResolvedFunctionDefinition, +) (*ResolvedFunctionDefinition, error) { + if fd == nil { + return another, nil + } + if another == nil { + return fd, nil + } + + if fd.Name != another.Name { + return nil, errors.Newf("cannot merge function definition of %q with %q", fd.Name, another.Name) + } + + return &ResolvedFunctionDefinition{ + Name: fd.Name, + Overloads: combineOverloads(fd.Overloads, another.Overloads), + }, nil +} + +func combineOverloads(a, b []QualifiedOverload) []QualifiedOverload { + return append(append(make([]QualifiedOverload, 0, len(a)+len(b)), a...), b...) } // GetClass returns function class by checking each overload's Class and returns // the homogeneous Class value if all overloads are the same Class. Ambiguous // error is returned if there is any overload with different Class. -func (fd *FunctionDefinition) GetClass() (FunctionClass, error) { - if fd.undefined() { - return fd.Class, nil +// +// TODO(chengxiong,mgartner): make sure that, at places of the use cases of this +// method, function is resolved to one overload, so that we can get rid of this +// function and similar methods below. +func (fd *ResolvedFunctionDefinition) GetClass() (FunctionClass, error) { + ret := fd.Overloads[0].Class + for i := range fd.Overloads { + if fd.Overloads[i].Class != ret { + return 0, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function class on %s", fd.Name) + } } - return getFuncClass(fd.Name, fd.Definition) + return ret, nil } // GetReturnLabel returns function ReturnLabel by checking each overload and @@ -200,49 +264,104 @@ func (fd *FunctionDefinition) GetClass() (FunctionClass, error) { // Ambiguous error is returned if there is any overload has ReturnLabel of a // different length. This is good enough since we don't create UDF with // ReturnLabel. -func (fd *FunctionDefinition) GetReturnLabel() ([]string, error) { - if fd.undefined() { - return fd.ReturnLabels, nil +func (fd *ResolvedFunctionDefinition) GetReturnLabel() ([]string, error) { + ret := fd.Overloads[0].ReturnLabels + for i := range fd.Overloads { + if len(ret) != len(fd.Overloads[i].ReturnLabels) { + return nil, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function return label on %s", fd.Name) + } } - return getFuncReturnLabels(fd.Name, fd.Definition) + return ret, nil } // GetHasSequenceArguments returns function's HasSequenceArguments flag by // checking each overload's HasSequenceArguments flag. Ambiguous error is // returned if there is any overload has a different flag. -func (fd *FunctionDefinition) GetHasSequenceArguments() (bool, error) { - if fd.undefined() { - return fd.HasSequenceArguments, nil +func (fd *ResolvedFunctionDefinition) GetHasSequenceArguments() (bool, error) { + ret := fd.Overloads[0].HasSequenceArguments + for i := range fd.Overloads { + if ret != fd.Overloads[i].HasSequenceArguments { + return false, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function sequence argument on %s", fd.Name) + } } - return getHasSequenceArguments(fd.Name, fd.Definition) + return ret, nil } -func getFuncClass(fnName string, fns []*Overload) (FunctionClass, error) { - ret := fns[0].Class - for _, o := range fns { - if o.Class != ret { - return 0, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function class on %s", fnName) - } +// QualifyBuiltinFunctionDefinition qualified all overloads in a function +// definition with a schema name. Note that this function can only be used for +// builtin function. +func QualifyBuiltinFunctionDefinition( + def *FunctionDefinition, schema string, +) *ResolvedFunctionDefinition { + ret := &ResolvedFunctionDefinition{ + Name: def.Name, + Overloads: make([]QualifiedOverload, 0, len(def.Definition)), } - return ret, nil + for _, o := range def.Definition { + ret.Overloads = append( + ret.Overloads, + MakeQualifiedOverload(schema, o), + ) + } + return ret } -func getFuncReturnLabels(fnName string, fns []*Overload) ([]string, error) { - ret := fns[0].ReturnLabels - for _, o := range fns { - if len(ret) != len(o.ReturnLabels) { - return nil, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function return label on %s", fnName) - } +// GetBuiltinFuncDefinitionOrFail is similar to GetBuiltinFuncDefinition but +// returns an error if function is not found. +func GetBuiltinFuncDefinitionOrFail( + fName *FunctionName, searchPath SearchPath, +) (*ResolvedFunctionDefinition, error) { + def, err := GetBuiltinFuncDefinition(fName, searchPath) + if err != nil { + return nil, err } - return ret, nil + if def == nil { + return nil, pgerror.Newf(pgcode.UndefinedFunction, "unknown function: %s", fName.String()) + } + return def, nil } -func getHasSequenceArguments(fnName string, fns []*Overload) (bool, error) { - ret := fns[0].HasSequenceArguments - for _, o := range fns { - if ret != o.HasSequenceArguments { - return false, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function sequence argument on %s", fnName) +// GetBuiltinFuncDefinition search for a builtin function given a function name +// and a search path. If function name is prefixed, only the builtin functions +// in the specific schema are searched. Otherwise, all schemas on the given +// searchPath are searched. A nil is returned if no function is found. It's +// caller's choice to error out if function not found. +// +// In theory, this function returns an error only when the search path iterator +// errors which won't happen since the iterating function never errors out. But +// error is still checked and return from the function signature just in case +// we change the iterating function in the future. +func GetBuiltinFuncDefinition( + fName *FunctionName, searchPath SearchPath, +) (*ResolvedFunctionDefinition, error) { + if fName.ExplicitSchema { + return ResolvedBuiltinFuncDefs[fName.Schema()+"."+fName.Object()], nil + } + + // First try that if we can get function directly with the function name. + // There is a case where the part[0] of the name is a qualified string. + // TODO(Chengxiong): figure out why that could be an input. + if def, ok := ResolvedBuiltinFuncDefs[fName.Object()]; ok { + return def, nil + } + + // Then try if it's in pg_catalog. + if def, ok := ResolvedBuiltinFuncDefs[catconstants.PgCatalogName+"."+fName.Object()]; ok { + return def, nil + } + + // If not in pg_catalog, go through search path. + var resolvedDef *ResolvedFunctionDefinition + if err := searchPath.IterateSearchPath(func(schema string) error { + fullName := schema + "." + fName.Object() + if def, ok := ResolvedBuiltinFuncDefs[fullName]; ok { + resolvedDef = def + return iterutil.StopIteration() } + return nil + }); err != nil { + return nil, err } - return ret, nil + + return resolvedDef, nil } diff --git a/pkg/sql/sem/tree/function_definition_test.go b/pkg/sql/sem/tree/function_definition_test.go new file mode 100644 index 000000000000..6ab9c4010295 --- /dev/null +++ b/pkg/sql/sem/tree/function_definition_test.go @@ -0,0 +1,87 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tree_test + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestBuiltinFunctionResolver(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCases := []struct { + testName string + fnName tree.UnresolvedName + expectedSchema string + expectNoFound bool + }{ + { + testName: "not found", + fnName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"whathmm", "", "", ""}}, + expectNoFound: true, + }, + { + testName: "default to use pg_catalog schema", + fnName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"lower", "", "", ""}}, + expectedSchema: "pg_catalog", + }, + { + testName: "explicit to use pg_catalog schema", + fnName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"lower", "pg_catalog", "", ""}}, + expectedSchema: "pg_catalog", + }, + { + testName: "explicit to use public schema", + fnName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"st_makeline", "public", "", ""}}, + expectedSchema: "public", + }, + { + testName: "explicit to use public schema but not available", + fnName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"lower", "public", "", ""}}, + expectNoFound: true, + }, + { + testName: "explicit to use crdb_internal", + fnName: tree.UnresolvedName{NumParts: 2, Parts: tree.NameParts{"json_to_pb", "crdb_internal", "", ""}}, + expectedSchema: "crdb_internal", + }, + { + testName: "implicit to use crdb_internal", + fnName: tree.UnresolvedName{NumParts: 1, Parts: tree.NameParts{"json_to_pb", "", "", ""}}, + expectedSchema: "crdb_internal", + }, + } + + path := sessiondata.MakeSearchPath([]string{"crdb_internal"}) + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + fnName, err := tc.fnName.ToFunctionName() + require.NoError(t, err) + funcDef, err := tree.GetBuiltinFuncDefinition(fnName, &path) + require.NoError(t, err) + if tc.expectNoFound { + require.Nil(t, funcDef) + return + } + for _, o := range funcDef.Overloads { + require.Equal(t, tc.expectedSchema, o.Schema) + } + }) + } +} diff --git a/pkg/sql/sem/tree/function_name.go b/pkg/sql/sem/tree/function_name.go index fb7c61bb7309..8a62ca22caf8 100644 --- a/pkg/sql/sem/tree/function_name.go +++ b/pkg/sql/sem/tree/function_name.go @@ -11,10 +11,12 @@ package tree import ( + "context" "fmt" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" + "github.com/lib/pq/oid" ) // Function names are used in expressions in the FuncExpr node. @@ -35,7 +37,15 @@ type FunctionReferenceResolver interface { // the input of this method, so that we can try to narrow down the scope of // overloads a bit earlier and decrease the possibility of ambiguous error // on function properties. - ResolveFunction(name *UnresolvedName, path SearchPath) (*FunctionDefinition, error) + ResolveFunction( + ctx context.Context, name *UnresolvedName, path SearchPath, + ) (*ResolvedFunctionDefinition, error) + + // ResolveFunctionByOID looks up a function overload by using a given oid. + // Error is thrown if there is no function with the same oid. + ResolveFunctionByOID( + ctx context.Context, oid oid.Oid, + ) (*Overload, error) } // ResolvableFunctionReference implements the editable reference call of a @@ -62,7 +72,11 @@ func (ref *ResolvableFunctionReference) Resolve( // Use the default resolution logic if there is no resolver. fd, err = t.ResolveFunction(path) } else { - fd, err = resolver.ResolveFunction(t, path) + // TODO(Chengxiong): plumb a context through when fixing all use cases of + // ResolvableFunctionReference.Resolve in later commits. + // TODO(Chengxiong): fix ResolvableFunctionReference.Resolve to return + // ResolvedFunctionDefinition. + // fd, err = resolver.ResolveFunction(context.Background(), t, path) } if err != nil { return nil, err @@ -77,6 +91,9 @@ func (ref *ResolvableFunctionReference) Resolve( // WrapFunction creates a new ResolvableFunctionReference holding a pre-resolved // function from a built-in function name. Helper for grammar rules and // execbuilder. +// +// TODO(Chengxiong): get rid of FunctionDefinition entirely and use +// ResolvedFunctionDefinition instead. func WrapFunction(n string) ResolvableFunctionReference { fd, ok := FunDefs[n] if !ok { @@ -94,6 +111,8 @@ type FunctionReference interface { var _ FunctionReference = &UnresolvedName{} var _ FunctionReference = &FunctionDefinition{} +var _ FunctionReference = &ResolvedFunctionDefinition{} -func (*UnresolvedName) functionReference() {} -func (*FunctionDefinition) functionReference() {} +func (*UnresolvedName) functionReference() {} +func (*FunctionDefinition) functionReference() {} +func (*ResolvedFunctionDefinition) functionReference() {} diff --git a/pkg/sql/sem/tree/name_part.go b/pkg/sql/sem/tree/name_part.go index 69122960693a..b5065e5bf550 100644 --- a/pkg/sql/sem/tree/name_part.go +++ b/pkg/sql/sem/tree/name_part.go @@ -227,3 +227,13 @@ func (u *UnresolvedName) ToUnresolvedObjectName(idx AnnotationIdx) (*UnresolvedO idx, ) } + +// ToFunctionName converts an UnresolvedName to a FunctionName. +func (u *UnresolvedName) ToFunctionName() (*FunctionName, error) { + un, err := u.ToUnresolvedObjectName(NoAnnotation) + if err != nil { + return nil, err + } + fn := un.ToFunctionName() + return &fn, nil +} diff --git a/pkg/sql/sem/tree/overload.go b/pkg/sql/sem/tree/overload.go index 8f4b6c7c0637..c579842ce647 100644 --- a/pkg/sql/sem/tree/overload.go +++ b/pkg/sql/sem/tree/overload.go @@ -110,7 +110,8 @@ type Overload struct { AggregateFunc AggregateOverload WindowFunc WindowOverload - // Only one of the following six attributes can be set. + // Only one of the "Fn", "FnWithExprs", "Generate", "GeneratorWithExprs", + // "SQLFn" and "Body" attributes can be set. // Fn is the normal builtin implementation function. It's for functions that // take in Datums and return a Datum. @@ -133,9 +134,6 @@ type Overload struct { // statement which will be executed as a common table expression in the query. SQLFn SQLFnOverload - // Body is the SQL string body of a user-defined function. - Body string - // OnTypeCheck is incremented every time this overload is type checked. OnTypeCheck func() @@ -175,6 +173,19 @@ type Overload struct { // FunctionProperties are the properties of this overload. FunctionProperties + + // IsUDF is set to true when this is a user-defined function overload. + // Note: Body can be empty string even IsUDF is true. + IsUDF bool + // UDFContainsOnlySignature is only set to true for Overload signatures cached + // in a Schema descriptor, which means that the full UDF descriptor need to be + // fetched to get more info, e.g. function Body. + UDFContainsOnlySignature bool + // Body is the SQL string body of a user-defined function. + Body string + // ReturnSet is set to true when a user-defined function is defined to return + // a set of values. + ReturnSet bool } // params implements the overloadImpl interface. @@ -289,11 +300,18 @@ var _ TypeList = VariadicType{} // ArgTypes is very similar to ArgTypes except it allows keeping a string // name for each argument as well and using those when printing the // human-readable signature. +// TODO(chengxiong): change ArgTypes to []ArgType. type ArgTypes []struct { Name string Typ *types.T } +// ArgType encapsulate an argument name and type. +type ArgType struct { + Name string + Typ *types.T +} + // Match is part of the TypeList interface. func (a ArgTypes) Match(types []*types.T) bool { if len(types) != len(a) { diff --git a/pkg/sql/types/BUILD.bazel b/pkg/sql/types/BUILD.bazel index b7610f091b8c..9bcc4eebd8f1 100644 --- a/pkg/sql/types/BUILD.bazel +++ b/pkg/sql/types/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//pkg/sql/oidext", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", + "//pkg/sql/sem/catid", "//pkg/util/errorutil/unimplemented", "//pkg/util/protoutil", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go index 5558659ae4f4..9872796c589b 100644 --- a/pkg/sql/types/types.go +++ b/pkg/sql/types/types.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" @@ -1381,8 +1382,7 @@ func (t *T) UserDefined() bool { // IsOIDUserDefinedType returns whether or not o corresponds to a user // defined type. func IsOIDUserDefinedType(o oid.Oid) bool { - // Types with OIDs larger than the predefined max are user defined. - return o > oidext.CockroachPredefinedOIDMax + return catid.IsOIDUserDefined(o) } var familyNames = map[Family]string{