Skip to content

Commit

Permalink
sql: implement function resolver for general purpose
Browse files Browse the repository at this point in the history
This commit modifies the function resolver interface to
have two methods. One to resolve overloads by function name.
Another to get function overload by OID. An implementation
of the new interface is added by extending the schema
resolver. A cache of ResolvedFunctionDefinition for builtin
functions tree.ResolvedBuiltinFuncDefs is constructed at
init time to avoid allocating for builtin functions.

Use cases of function resolution will be changed to use the
new interface in following commits.

Release note: None
  • Loading branch information
chengxiong-ruan committed Aug 4, 2022
1 parent 60c3679 commit ddabd18
Show file tree
Hide file tree
Showing 30 changed files with 1,193 additions and 89 deletions.
1 change: 1 addition & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ go_test(
"explain_bundle_test.go",
"explain_test.go",
"explain_tree_test.go",
"function_resolver_test.go",
"grant_revoke_test.go",
"grant_role_test.go",
"index_mutation_test.go",
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/catalog/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ type FunctionDescriptor interface {

// FuncDesc returns the function's underlying protobuf descriptor.
FuncDesc() *descpb.FunctionDescriptor

// ToOverload converts the function descriptor to tree.Overload object which
// can be used for execution.
ToOverload() (ret *tree.Overload, err error)
}

// FilterDescriptorState inspects the state of a given descriptor and returns an
Expand Down
7 changes: 7 additions & 0 deletions pkg/sql/catalog/funcdesc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ go_library(
"//pkg/sql/pgwire/pgerror",
"//pkg/sql/privilege",
"//pkg/sql/schemachanger/scpb",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/tree",
"//pkg/sql/sem/volatility",
"//pkg/sql/types",
"//pkg/util/hlc",
"//pkg/util/protoutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid",
],
)

Expand All @@ -45,8 +48,12 @@ go_test(
"//pkg/sql/catalog/tabledesc",
"//pkg/sql/catalog/typedesc",
"//pkg/sql/privilege",
"//pkg/sql/sem/tree",
"//pkg/sql/sem/volatility",
"//pkg/sql/types",
"//pkg/util/leaktest",
"@com_github_lib_pq//oid",
"@com_github_stretchr_testify//require",
],
)

Expand Down
77 changes: 77 additions & 0 deletions pkg/sql/catalog/funcdesc/func_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catid"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/hlc"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)

var _ catalog.Descriptor = (*immutable)(nil)
Expand Down Expand Up @@ -462,3 +466,76 @@ func (desc *immutable) ContainsUserDefinedTypes() bool {
}
return desc.ReturnType.Type.UserDefined()
}

func (desc *immutable) ToOverload() (ret *tree.Overload, err error) {
ret = &tree.Overload{
Oid: catid.FuncIDToOID(desc.ID),
ReturnType: tree.FixedReturnType(desc.ReturnType.Type),
ReturnSet: desc.ReturnType.ReturnSet,
Body: desc.FunctionBody,
IsUDF: true,
}

argTypes := make(tree.ArgTypes, 0, len(desc.Args))
for _, arg := range desc.Args {
argTypes = append(
argTypes,
tree.ArgType{Name: arg.Name, Typ: arg.Type},
)
}
ret.Types = argTypes
ret.Volatility, err = desc.getOverloadVolatility()
if err != nil {
return nil, err
}
ret.NullableArgs, err = desc.getOverloadNullableArgs()
if err != nil {
return nil, err
}

return ret, nil
}

func (desc *immutable) getOverloadVolatility() (volatility.V, error) {
var ret volatility.V
switch desc.Volatility {
case catpb.Function_VOLATILE:
ret = volatility.Volatile
case catpb.Function_STABLE:
ret = volatility.Stable
case catpb.Function_IMMUTABLE:
ret = volatility.Immutable
default:
return 0, errors.Newf("unknown volatility")
}
if desc.LeakProof {
if desc.Volatility != catpb.Function_IMMUTABLE {
return 0, errors.Newf("function %d is leakproof but not immutable", desc.ID)
}
ret = volatility.Leakproof
}
return ret, nil
}

func (desc *immutable) getOverloadNullableArgs() (bool, error) {
switch desc.NullInputBehavior {
case catpb.Function_CALLED_ON_NULL_INPUT:
return true, nil
case catpb.Function_RETURNS_NULL_ON_NULL_INPUT, catpb.Function_STRICT:
return false, nil
default:
return false, errors.Newf("unknown null input behavior")
}
}

// UserDefinedFunctionOIDToID converts a UDF OID into a descriptor ID. OID of a
// UDF must be greater CockroachPredefinedOIDMax. The function returns an error
// if the given OID is less than or equal to CockroachPredefinedOIDMax.
func UserDefinedFunctionOIDToID(oid oid.Oid) (descpb.ID, error) {
return catid.UserDefinedOIDToID(oid)
}

// IsOIDUserDefinedFunc returns true if an oid is a user-defined function oid.
func IsOIDUserDefinedFunc(oid oid.Oid) bool {
return catid.IsOIDUserDefined(oid)
}
152 changes: 152 additions & 0 deletions pkg/sql/catalog/funcdesc/func_desc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package funcdesc_test
import (
"context"
"fmt"
"strconv"
"testing"

"github.com/cockroachdb/cockroach/pkg/clusterversion"
Expand All @@ -28,8 +29,12 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/lib/pq/oid"
"github.com/stretchr/testify/require"
)

func TestValidateFuncDesc(t *testing.T) {
Expand Down Expand Up @@ -462,3 +467,150 @@ func TestValidateFuncDesc(t *testing.T) {
}
}
}

func TestToOverload(t *testing.T) {
testCases := []struct {
desc descpb.FunctionDescriptor
expected tree.Overload
err string
}{
{
// Test all fields are properly valued.
desc: descpb.FunctionDescriptor{
ID: 1,
Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}},
ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true},
LeakProof: true,
Volatility: catpb.Function_IMMUTABLE,
NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT,
FunctionBody: "ANY QUERIES",
},
expected: tree.Overload{
Oid: oid.Oid(100001),
Types: tree.ArgTypes{
{Name: "arg1", Typ: types.Int},
},
ReturnType: tree.FixedReturnType(types.Int),
ReturnSet: true,
Volatility: volatility.Leakproof,
Body: "ANY QUERIES",
IsUDF: true,
},
},
{
// Test ReturnSet matters.
desc: descpb.FunctionDescriptor{
ID: 1,
Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}},
ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: false},
LeakProof: true,
Volatility: catpb.Function_IMMUTABLE,
NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT,
FunctionBody: "ANY QUERIES",
},
expected: tree.Overload{
Oid: oid.Oid(100001),
Types: tree.ArgTypes{
{Name: "arg1", Typ: types.Int},
},
ReturnType: tree.FixedReturnType(types.Int),
ReturnSet: false,
Volatility: volatility.Leakproof,
Body: "ANY QUERIES",
IsUDF: true,
},
},
{
// Test Volatility matters.
desc: descpb.FunctionDescriptor{
ID: 1,
Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}},
ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true},
LeakProof: false,
Volatility: catpb.Function_STABLE,
NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT,
FunctionBody: "ANY QUERIES",
},
expected: tree.Overload{
Oid: oid.Oid(100001),
Types: tree.ArgTypes{
{Name: "arg1", Typ: types.Int},
},
ReturnType: tree.FixedReturnType(types.Int),
ReturnSet: true,
Volatility: volatility.Stable,
Body: "ANY QUERIES",
IsUDF: true,
},
},
{
// Test NullableArgs matters.
desc: descpb.FunctionDescriptor{
ID: 1,
Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}},
ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true},
LeakProof: true,
Volatility: catpb.Function_IMMUTABLE,
NullInputBehavior: catpb.Function_CALLED_ON_NULL_INPUT,
FunctionBody: "ANY QUERIES",
},
expected: tree.Overload{
Oid: oid.Oid(100001),
Types: tree.ArgTypes{
{Name: "arg1", Typ: types.Int},
},
ReturnType: tree.FixedReturnType(types.Int),
ReturnSet: true,
Volatility: volatility.Leakproof,
Body: "ANY QUERIES",
IsUDF: true,
NullableArgs: true,
},
},
{
// Test failure on non-immutable but leakproof function.
desc: descpb.FunctionDescriptor{
ID: 1,
Args: []descpb.FunctionDescriptor_Argument{{Name: "arg1", Type: types.Int}},
ReturnType: descpb.FunctionDescriptor_ReturnType{Type: types.Int, ReturnSet: true},
LeakProof: true,
Volatility: catpb.Function_STABLE,
NullInputBehavior: catpb.Function_RETURNS_NULL_ON_NULL_INPUT,
FunctionBody: "ANY QUERIES",
},
expected: tree.Overload{
Oid: oid.Oid(100001),
Types: tree.ArgTypes{
{Name: "arg1", Typ: types.Int},
},
ReturnType: tree.FixedReturnType(types.Int),
ReturnSet: true,
Volatility: volatility.Leakproof,
Body: "ANY QUERIES",
IsUDF: true,
},
err: "function 1 is leakproof but not immutable",
},
}

for i, tc := range testCases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
desc := funcdesc.NewBuilder(&tc.desc).BuildImmutable().(catalog.FunctionDescriptor)
overload, err := desc.ToOverload()
if tc.err == "" {
require.NoError(t, err)
} else {
require.Equal(t, tc.err, err.Error())
return
}

returnType := overload.ReturnType([]tree.TypedExpr{})
expectedReturnType := tc.expected.ReturnType([]tree.TypedExpr{})
require.Equal(t, expectedReturnType, returnType)
// Set ReturnType(which is function) to nil for easier equality check.
overload.ReturnType = nil
tc.expected.ReturnType = nil
require.Equal(t, tc.expected, *overload)
})
}
}
14 changes: 13 additions & 1 deletion pkg/sql/catalog/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

package catalog

import "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

// SchemaDescriptor encapsulates the basic
type SchemaDescriptor interface {
Expand All @@ -29,6 +32,15 @@ type SchemaDescriptor interface {

// GetFunction returns a list of function overloads given a name.
GetFunction(name string) (descpb.SchemaDescriptor_Function, bool)

// GetResolvedFuncDefinition returns a ResolvedFunctionDefinition given a
// function name. This is needed by function resolution and expression type
// checking during which candidate function overloads are searched for the
// best match. Only function signatures are needed during this process. Schema
// stores all the signatures of the functions created under it and this method
// returns a collection of overloads with the same function name, each
// overload is prefixed with the same schema name.
GetResolvedFuncDefinition(name string) (*tree.ResolvedFunctionDefinition, bool)
}

// ResolvedSchemaKind is an enum that represents what kind of schema
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/catalog/schemadesc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ go_library(
"//pkg/sql/privilege",
"//pkg/sql/schemachanger/scpb",
"//pkg/sql/sem/catconstants",
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/tree",
"//pkg/sql/types",
"//pkg/util/hlc",
"//pkg/util/log",
"//pkg/util/protoutil",
Expand Down
Loading

0 comments on commit ddabd18

Please sign in to comment.