diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index c98dcdb1ec483..b5cd5ff84300f 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -52,6 +52,7 @@ go_library( "explain.go", "expr_to_pb.go", "expression.go", + "extension.go", "function_traits.go", "helper.go", "partition_pruner.go", @@ -66,6 +67,7 @@ go_library( deps = [ "//config", "//errno", + "//extension", "//kv", "//parser", "//parser/ast", @@ -97,6 +99,7 @@ go_library( "//util/parser", "//util/plancodec", "//util/printer", + "//util/sem", "//util/set", "//util/size", "//util/sqlexec", diff --git a/expression/builtin.go b/expression/builtin.go index 851b29196f71d..0ad5756493336 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -944,6 +944,13 @@ func GetBuiltinList() []string { } res = append(res, funcName) } + + extensionFuncs.Range(func(key, _ any) bool { + funcName := key.(string) + res = append(res, funcName) + return true + }) + slices.Sort(res) return res } diff --git a/expression/extension.go b/expression/extension.go new file mode 100644 index 0000000000000..7ee0c3057e49d --- /dev/null +++ b/expression/extension.go @@ -0,0 +1,185 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "strings" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sem" +) + +var extensionFuncs sync.Map + +func registerExtensionFunc(def *extension.FunctionDef) error { + if def == nil { + return errors.New("extension function def is nil") + } + + if def.Name == "" { + return errors.New("extension function name should not be empty") + } + + lowerName := strings.ToLower(def.Name) + if _, ok := funcs[lowerName]; ok { + return errors.Errorf("extension function name '%s' conflict with builtin", def.Name) + } + + class, err := newExtensionFuncClass(def) + if err != nil { + return err + } + + _, exist := extensionFuncs.LoadOrStore(lowerName, class) + if exist { + return errors.Errorf("duplicated extension function name '%s'", def.Name) + } + + return nil +} + +func removeExtensionFunc(name string) { + extensionFuncs.Delete(name) +} + +type extensionFuncClass struct { + baseFunctionClass + funcDef extension.FunctionDef + flen int +} + +func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, error) { + var flen int + switch def.EvalTp { + case types.ETString: + flen = mysql.MaxFieldVarCharLength + if def.EvalStringFunc == nil { + return nil, errors.New("eval function is nil") + } + case types.ETInt: + flen = mysql.MaxIntWidth + if def.EvalIntFunc == nil { + return nil, errors.New("eval function is nil") + } + default: + return nil, errors.Errorf("unsupported extension function ret type: '%v'", def.EvalTp) + } + + return &extensionFuncClass{ + baseFunctionClass: baseFunctionClass{def.Name, len(def.ArgTps), len(def.ArgTps)}, + flen: flen, + funcDef: *def, + }, nil +} + +func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.checkPrivileges(ctx); err != nil { + return nil, err + } + + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, c.funcDef.EvalTp, c.funcDef.ArgTps...) + if err != nil { + return nil, err + } + bf.tp.SetFlen(c.flen) + sig := &extensionFuncSig{bf, c.funcDef} + return sig, nil +} + +func (c *extensionFuncClass) checkPrivileges(ctx sessionctx.Context) error { + privs := c.funcDef.RequireDynamicPrivileges + if semPrivs := c.funcDef.SemRequireDynamicPrivileges; len(semPrivs) > 0 && sem.IsEnabled() { + privs = semPrivs + } + + if len(privs) == 0 { + return nil + } + + manager := privilege.GetPrivilegeManager(ctx) + activeRoles := ctx.GetSessionVars().ActiveRoles + + for _, priv := range privs { + if !manager.RequestDynamicVerification(activeRoles, priv, false) { + msg := priv + if !sem.IsEnabled() { + msg = "SUPER or " + msg + } + return errSpecificAccessDenied.GenWithStackByArgs(msg) + } + } + + return nil +} + +var _ extension.FunctionContext = &extensionFuncSig{} + +type extensionFuncSig struct { + baseBuiltinFunc + extension.FunctionDef +} + +func (b *extensionFuncSig) Clone() builtinFunc { + newSig := &extensionFuncSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.FunctionDef = b.FunctionDef + return newSig +} + +func (b *extensionFuncSig) evalString(row chunk.Row) (string, bool, error) { + if b.EvalTp == types.ETString { + return b.EvalStringFunc(b, row) + } + return b.baseBuiltinFunc.evalString(row) +} + +func (b *extensionFuncSig) evalInt(row chunk.Row) (int64, bool, error) { + if b.EvalTp == types.ETInt { + return b.EvalIntFunc(b, row) + } + return b.baseBuiltinFunc.evalInt(row) +} + +func (b *extensionFuncSig) EvalArgs(row chunk.Row) ([]types.Datum, error) { + if len(b.args) == 0 { + return nil, nil + } + + result := make([]types.Datum, 0, len(b.args)) + for _, arg := range b.args { + val, err := arg.Eval(row) + if err != nil { + return nil, err + } + result = append(result, val) + } + + return result, nil +} + +func init() { + extension.RegisterExtensionFunc = registerExtensionFunc + extension.RemoveExtensionFunc = removeExtensionFunc +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 2483d342b662b..9e3e343b2851c 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -193,6 +193,13 @@ func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType } } fc, ok := funcs[funcName] + if !ok { + if extFunc, exist := extensionFuncs.Load(funcName); exist { + fc = extFunc.(functionClass) + ok = true + } + } + if !ok { db := ctx.GetSessionVars().CurrentDB if db == "" { diff --git a/extension/BUILD.bazel b/extension/BUILD.bazel index 95ffeb73fd60f..abd5b7121b6b7 100644 --- a/extension/BUILD.bazel +++ b/extension/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "extension", srcs = [ "extensions.go", + "function.go", "manifest.go", "registry.go", "util.go", @@ -12,6 +13,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//sessionctx/variable", + "//types", "//util/chunk", "@com_github_pingcap_errors//:errors", ], @@ -21,15 +23,21 @@ go_test( name = "extension_test", srcs = [ "bootstrap_test.go", + "function_test.go", "main_test.go", "registry_test.go", ], embed = [":extension"], deps = [ + "//expression", + "//parser/auth", "//privilege/privileges", "//sessionctx/variable", "//testkit", "//testkit/testsetup", + "//types", + "//util/chunk", + "//util/sem", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@org_uber_go_goleak//:goleak", diff --git a/extension/extensionimpl/BUILD.bazel b/extension/extensionimpl/BUILD.bazel index 956e6db58d463..4719d7777b43c 100644 --- a/extension/extensionimpl/BUILD.bazel +++ b/extension/extensionimpl/BUILD.bazel @@ -11,5 +11,6 @@ go_library( "//kv", "//util/chunk", "//util/sqlexec", + "@com_github_pingcap_errors//:errors", ], ) diff --git a/extension/function.go b/extension/function.go new file mode 100644 index 0000000000000..d6ca6890b837e --- /dev/null +++ b/extension/function.go @@ -0,0 +1,48 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extension + +import ( + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" +) + +// FunctionContext is a interface to provide context to the custom function +type FunctionContext interface { + EvalArgs(row chunk.Row) ([]types.Datum, error) +} + +// FunctionDef is the definition for the custom function +type FunctionDef struct { + Name string + EvalTp types.EvalType + ArgTps []types.EvalType + // EvalStringFunc is the eval function when `EvalTp` is `types.ETString` + EvalStringFunc func(ctx FunctionContext, row chunk.Row) (string, bool, error) + // EvalIntFunc is the eval function when `EvalTp` is `types.ETInt` + EvalIntFunc func(ctx FunctionContext, row chunk.Row) (int64, bool, error) + // RequireDynamicPrivileges is the dynamic privileges needed to invoke the function + // If `RequireDynamicPrivileges` is empty, it means every one can invoke this function + RequireDynamicPrivileges []string + // SemRequireDynamicPrivileges is the dynamic privileges needed to invoke the function in sem mode + // If `SemRequireDynamicPrivileges` is empty, `DynamicPrivileges` will be used in sem mode + SemRequireDynamicPrivileges []string +} + +// RegisterExtensionFunc is to avoid dependency cycle +var RegisterExtensionFunc func(*FunctionDef) error + +// RemoveExtensionFunc is to avoid dependency cycle +var RemoveExtensionFunc func(string) diff --git a/extension/function_test.go b/extension/function_test.go new file mode 100644 index 0000000000000..7fb36600b0ac9 --- /dev/null +++ b/extension/function_test.go @@ -0,0 +1,310 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extension_test + +import ( + "fmt" + "sort" + "strings" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sem" + "github.com/stretchr/testify/require" +) + +var customFunc1 = &extension.FunctionDef{ + Name: "custom_func1", + EvalTp: types.ETString, + ArgTps: []types.EvalType{ + types.ETInt, + types.ETString, + }, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + args, err := ctx.EvalArgs(row) + if err != nil { + return "", false, err + } + + if args[1].GetString() == "error" { + return "", false, errors.New("custom error") + } + + return fmt.Sprintf("%d,%s", args[0].GetInt64(), args[1].GetString()), false, nil + }, +} + +var customFunc2 = &extension.FunctionDef{ + Name: "custom_func2", + EvalTp: types.ETInt, + ArgTps: []types.EvalType{ + types.ETInt, + types.ETInt, + }, + EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) { + args, err := ctx.EvalArgs(row) + if err != nil { + return 0, false, err + } + return args[0].GetInt64()*100 + args[1].GetInt64(), false, nil + }, +} + +func TestInvokeFunc(t *testing.T) { + defer func() { + extension.Reset() + }() + + extension.Reset() + orgFuncList := expression.GetBuiltinList() + checkFuncList(t, orgFuncList) + require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + customFunc2, + }))) + require.NoError(t, extension.Setup()) + checkFuncList(t, orgFuncList, "custom_func1", "custom_func2") + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select custom_func1(1, 'abc')").Check(testkit.Rows("1,abc")) + tk.MustQuery("select custom_func2(7, 8)").Check(testkit.Rows("708")) + require.EqualError(t, tk.QueryToErr("select custom_func1(1, 'error')"), "custom error") + require.EqualError(t, tk.ExecToErr("select custom_func1(1)"), "[expression:1582]Incorrect parameter count in the call to native function 'custom_func1'") + + extension.Reset() + checkFuncList(t, orgFuncList) + store2 := testkit.CreateMockStore(t) + tk2 := testkit.NewTestKit(t, store2) + tk2.MustExec("use test") + require.EqualError(t, tk2.ExecToErr("select custom_func1(1, 'abc')"), "[expression:1305]FUNCTION test.custom_func1 does not exist") + require.EqualError(t, tk2.ExecToErr("select custom_func2(1, 2)"), "[expression:1305]FUNCTION test.custom_func2 does not exist") +} + +func TestRegisterFunc(t *testing.T) { + defer func() { + extension.Reset() + }() + + // nil func + extension.Reset() + orgFuncList := expression.GetBuiltinList() + checkFuncList(t, orgFuncList) + require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + nil, + }))) + require.EqualError(t, extension.Setup(), "extension function def is nil") + checkFuncList(t, orgFuncList) + + // dup name with builtin + extension.Reset() + var def extension.FunctionDef + def = *customFunc1 + def.Name = "substring" + require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + &def, + }))) + require.EqualError(t, extension.Setup(), "extension function name 'substring' conflict with builtin") + checkFuncList(t, orgFuncList) + + // empty func name + extension.Reset() + def = *customFunc1 + def.Name = "" + require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{ + &def, + }))) + require.EqualError(t, extension.Setup(), "extension function name should not be empty") + checkFuncList(t, orgFuncList) + + // dup name with other func in one extension + extension.Reset() + require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + customFunc1, + }))) + require.EqualError(t, extension.Setup(), "duplicated extension function name 'custom_func1'") + checkFuncList(t, orgFuncList) + + // dup name with other func in different extension + extension.Reset() + require.NoError(t, extension.Register("test1", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + }))) + require.NoError(t, extension.Register("test2", extension.WithCustomFunctions([]*extension.FunctionDef{ + customFunc1, + }))) + require.EqualError(t, extension.Setup(), "duplicated extension function name 'custom_func1'") + checkFuncList(t, orgFuncList) +} + +func checkFuncList(t *testing.T, orgList []string, customFuncs ...string) { + for _, name := range orgList { + require.False(t, strings.HasPrefix(name, "custom_"), name) + } + + checkList := make([]string, 0, len(orgList)+len(customFuncs)) + checkList = append(checkList, orgList...) + checkList = append(checkList, customFuncs...) + sort.Strings(checkList) + require.Equal(t, checkList, expression.GetBuiltinList()) +} + +func TestFuncPrivilege(t *testing.T) { + defer func() { + extension.Reset() + sem.Disable() + }() + + extension.Reset() + require.NoError(t, extension.Register("test", + extension.WithCustomFunctions([]*extension.FunctionDef{ + { + Name: "custom_no_priv_func", + EvalTp: types.ETString, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + return "zzz", false, nil + }, + }, + { + Name: "custom_only_dyn_priv_func", + EvalTp: types.ETString, + RequireDynamicPrivileges: []string{"CUSTOM_DYN_PRIV_1"}, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + return "abc", false, nil + }, + }, + { + Name: "custom_only_sem_dyn_priv_func", + EvalTp: types.ETString, + SemRequireDynamicPrivileges: []string{"RESTRICTED_CUSTOM_DYN_PRIV_2"}, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + return "def", false, nil + }, + }, + { + Name: "custom_both_dyn_priv_func", + EvalTp: types.ETString, + RequireDynamicPrivileges: []string{"CUSTOM_DYN_PRIV_1"}, + SemRequireDynamicPrivileges: []string{"RESTRICTED_CUSTOM_DYN_PRIV_2"}, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + return "ghi", false, nil + }, + }, + }), + extension.WithCustomDynPrivs([]string{ + "CUSTOM_DYN_PRIV_1", + "RESTRICTED_CUSTOM_DYN_PRIV_2", + }), + )) + require.NoError(t, extension.Setup()) + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("create user u1@localhost") + + tk.MustExec("create user u2@localhost") + tk.MustExec("GRANT CUSTOM_DYN_PRIV_1 on *.* TO u2@localhost") + + tk.MustExec("create user u3@localhost") + tk.MustExec("GRANT RESTRICTED_CUSTOM_DYN_PRIV_2 on *.* TO u3@localhost") + + tk.MustExec("create user u4@localhost") + tk.MustExec("GRANT CUSTOM_DYN_PRIV_1, RESTRICTED_CUSTOM_DYN_PRIV_2 on *.* TO u4@localhost") + + tk1 := testkit.NewTestKit(t, store) + + // root has all privileges by default + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + + // u1 in non-sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + + // u2 in non-sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + + // u3 in non-sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + + // u4 in non-sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + + sem.Enable() + + // root in sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + + // u1 in sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + + // u2 in sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + + // u3 in sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + + // u4 in sem + require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil)) + tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz")) + tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) + tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) + tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) +} diff --git a/extension/manifest.go b/extension/manifest.go index ef7ea35104cbb..c01e95aef24b8 100644 --- a/extension/manifest.go +++ b/extension/manifest.go @@ -39,6 +39,13 @@ func WithCustomDynPrivs(privs []string) Option { } } +// WithCustomFunctions specifies custom functions +func WithCustomFunctions(funcs []*FunctionDef) Option { + return func(m *Manifest) { + m.funcs = funcs + } +} + // WithClose specifies the close function of an extension. // It will be invoked when `extension.Reset` is called func WithClose(fn func()) Option { @@ -79,6 +86,7 @@ type Manifest struct { sysVariables []*variable.SysVar dynPrivs []string bootstrap func(BootstrapContext) error + funcs []*FunctionDef close func() } @@ -156,6 +164,25 @@ func newManifestWithSetup(name string, factory func() ([]Option, error)) (_ *Man return nil, nil, err } } + + // setup functions + for i := range m.funcs { + def := m.funcs[i] + err = clearBuilder.DoWithCollectClear(func() (func(), error) { + if err := RegisterExtensionFunc(def); err != nil { + return nil, err + } + + return func() { + RemoveExtensionFunc(def.Name) + }, nil + }) + + if err != nil { + return nil, nil, err + } + } + return m, clearBuilder.Build(), nil }