Skip to content

Commit

Permalink
extension: provide more informations in extension.FunctionContext (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Nov 10, 2022
1 parent a20b70f commit f51227c
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 11 deletions.
18 changes: 18 additions & 0 deletions expression/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sem"
Expand Down Expand Up @@ -183,6 +185,22 @@ func (b *extensionFuncSig) EvalArgs(row chunk.Row) ([]types.Datum, error) {
return result, nil
}

func (b *extensionFuncSig) ConnectionInfo() *variable.ConnectionInfo {
return b.ctx.GetSessionVars().ConnectionInfo
}

func (b *extensionFuncSig) User() *auth.UserIdentity {
return b.ctx.GetSessionVars().User
}

func (b *extensionFuncSig) ActiveRoles() []*auth.RoleIdentity {
return b.ctx.GetSessionVars().ActiveRoles
}

func (b *extensionFuncSig) CurrentDB() string {
return b.ctx.GetSessionVars().CurrentDB
}

func init() {
extension.RegisterExtensionFunc = registerExtensionFunc
extension.RemoveExtensionFunc = removeExtensionFunc
Expand Down
8 changes: 7 additions & 1 deletion extension/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ import (
"context"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
)

// FunctionContext is a interface to provide context to the custom function
// FunctionContext is an interface to provide context to the custom function
type FunctionContext interface {
context.Context
User() *auth.UserIdentity
ActiveRoles() []*auth.RoleIdentity
CurrentDB() string
ConnectionInfo() *variable.ConnectionInfo
EvalArgs(row chunk.Row) ([]types.Datum, error)
}

Expand Down
73 changes: 63 additions & 10 deletions extension/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -68,10 +69,64 @@ var customFunc2 = &extension.FunctionDef{
},
}

func TestInvokeFunc(t *testing.T) {
defer func() {
extension.Reset()
}()
func TestExtensionFuncCtx(t *testing.T) {
defer extension.Reset()
extension.Reset()

invoked := false
var user *auth.UserIdentity
var currentDB string
var activeRoles []*auth.RoleIdentity
var conn *variable.ConnectionInfo

require.NoError(t, extension.Register("test", extension.WithCustomFunctions([]*extension.FunctionDef{
{
Name: "custom_get_ctx",
EvalTp: types.ETString,
EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) {
require.False(t, invoked)
invoked = true
user = ctx.User()
currentDB = ctx.CurrentDB()
activeRoles = ctx.ActiveRoles()
conn = ctx.ConnectionInfo()
return "done", false, nil
},
},
})))

store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create user u1@localhost")
tk.MustExec("create role r1")
tk.MustExec("grant r1 to u1@localhost")
tk.MustExec("grant ALL ON test.* to u1@localhost")

tk1 := testkit.NewTestKit(t, store)
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil))
tk1.MustExec("set role r1")
tk1.MustExec("use test")
tk1.Session().GetSessionVars().ConnectionInfo = &variable.ConnectionInfo{
ConnectionID: 12345,
User: "u1",
}

tk1.MustQuery("select custom_get_ctx()").Check(testkit.Rows("done"))

require.True(t, invoked)
require.NotNil(t, user)
require.Equal(t, *tk1.Session().GetSessionVars().User, *user)
require.Equal(t, "test", currentDB)
require.NotNil(t, conn)
require.Equal(t, *tk1.Session().GetSessionVars().ConnectionInfo, *conn)
require.Equal(t, 1, len(activeRoles))
require.Equal(t, auth.RoleIdentity{Username: "r1", Hostname: "%"}, *activeRoles[0])
}

func TestInvokeExtensionFunc(t *testing.T) {
defer extension.Reset()
extension.Reset()

extension.Reset()
orgFuncList := expression.GetBuiltinList()
Expand Down Expand Up @@ -99,7 +154,7 @@ func TestInvokeFunc(t *testing.T) {
require.EqualError(t, tk2.ExecToErr("select custom_func2(1, 2)"), "[expression:1305]FUNCTION test.custom_func2 does not exist")
}

func TestFuncDynamicArgLen(t *testing.T) {
func TestExtensionFuncDynamicArgLen(t *testing.T) {
defer extension.Reset()
extension.Reset()

Expand Down Expand Up @@ -142,10 +197,8 @@ func TestFuncDynamicArgLen(t *testing.T) {
require.EqualError(t, tk.ExecToErr("select dynamic_arg_func(1, 2)"), expectedErrMsg)
}

func TestRegisterFunc(t *testing.T) {
defer func() {
extension.Reset()
}()
func TestRegisterExtensionFunc(t *testing.T) {
defer extension.Reset()

// nil func
extension.Reset()
Expand Down Expand Up @@ -213,7 +266,7 @@ func checkFuncList(t *testing.T, orgList []string, customFuncs ...string) {
require.Equal(t, checkList, expression.GetBuiltinList())
}

func TestFuncPrivilege(t *testing.T) {
func TestExtensionFuncPrivilege(t *testing.T) {
defer func() {
extension.Reset()
sem.Disable()
Expand Down

0 comments on commit f51227c

Please sign in to comment.