From 8d9647dee55c12ef54ddcb8a787c19488a324e19 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Mon, 1 Nov 2021 16:46:51 +0800 Subject: [PATCH] privileges: fix create temporary tables privilege (#29279) --- parser/mysql/privs.go | 2 +- planner/core/common_plans.go | 3 +- planner/core/optimizer.go | 71 +++++++++++ planner/core/planbuilder.go | 13 +- planner/optimize.go | 3 +- privilege/privileges/privileges_test.go | 161 +++++++++++++++++++++++- 6 files changed, 246 insertions(+), 7 deletions(-) diff --git a/parser/mysql/privs.go b/parser/mysql/privs.go index 7ed874f1017b8..93e5db579ad1e 100644 --- a/parser/mysql/privs.go +++ b/parser/mysql/privs.go @@ -255,7 +255,7 @@ const ( CreateRolePriv // DropRolePriv is the privilege to drop a role. DropRolePriv - + // CreateTMPTablePriv is the privilege to create a temporary table. CreateTMPTablePriv LockTablesPriv CreateRoutinePriv diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 5187342cb7fa4..1377c404f85ba 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -365,7 +365,8 @@ func (e *Execute) handleExecuteBuilderOption(sctx sessionctx.Context, func (e *Execute) checkPreparedPriv(ctx context.Context, sctx sessionctx.Context, preparedObj *CachedPrepareStmt, is infoschema.InfoSchema) error { if pm := privilege.GetPrivilegeManager(sctx); pm != nil { - if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, preparedObj.VisitInfos); err != nil { + visitInfo := VisitInfo4PrivCheck(is, preparedObj.PreparedAst.Stmt, preparedObj.VisitInfos) + if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil { return err } } diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index d9a81baea3c77..7cc352c9b8447 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/lock" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/privilege" @@ -120,6 +121,76 @@ func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs [ return nil } +// VisitInfo4PrivCheck generates privilege check infos because privilege check of local temporary tables is different +// with normal tables. `CREATE` statement needs `CREATE TEMPORARY TABLE` privilege from the database, and subsequent +// statements do not need any privileges. +func VisitInfo4PrivCheck(is infoschema.InfoSchema, node ast.Node, vs []visitInfo) (privVisitInfo []visitInfo) { + if node == nil { + return vs + } + + switch stmt := node.(type) { + case *ast.CreateTableStmt: + privVisitInfo = make([]visitInfo, 0, len(vs)) + for _, v := range vs { + if v.privilege == mysql.CreatePriv { + if stmt.TemporaryKeyword == ast.TemporaryLocal { + // `CREATE TEMPORARY TABLE` privilege is required from the database, not the table. + newVisitInfo := v + newVisitInfo.privilege = mysql.CreateTMPTablePriv + newVisitInfo.table = "" + privVisitInfo = append(privVisitInfo, newVisitInfo) + } else { + // If both the normal table and temporary table already exist, we need to check the privilege. + privVisitInfo = append(privVisitInfo, v) + } + } else { + // `CREATE TABLE LIKE tmp` or `CREATE TABLE FROM SELECT tmp` in the future. + if needCheckTmpTablePriv(is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + case *ast.DropTableStmt: + // Dropping a local temporary table doesn't need any privileges. + if stmt.IsView { + privVisitInfo = vs + } else { + privVisitInfo = make([]visitInfo, 0, len(vs)) + if stmt.TemporaryKeyword != ast.TemporaryLocal { + for _, v := range vs { + if needCheckTmpTablePriv(is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + } + case *ast.GrantStmt, *ast.DropSequenceStmt, *ast.DropPlacementPolicyStmt: + // Some statements ignore local temporary tables, so they should check the privileges on normal tables. + privVisitInfo = vs + default: + privVisitInfo = make([]visitInfo, 0, len(vs)) + for _, v := range vs { + if needCheckTmpTablePriv(is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + return +} + +func needCheckTmpTablePriv(is infoschema.InfoSchema, v visitInfo) bool { + if v.db != "" && v.table != "" { + // Other statements on local temporary tables except `CREATE` do not check any privileges. + tb, err := is.TableByName(model.NewCIStr(v.db), model.NewCIStr(v.table)) + // If the table doesn't exist, we do not report errors to avoid leaking the existence of the table. + if err == nil && tb.Meta().TempTableType == model.TempTableLocal { + return false + } + } + return true +} + // CheckTableLock checks the table lock. func CheckTableLock(ctx sessionctx.Context, is infoschema.InfoSchema, vs []visitInfo) error { if !config.TableLockEnabled() { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 4033bc7ff055c..eb437309b8f45 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3868,8 +3868,15 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err } } if b.ctx.GetSessionVars().User != nil { - authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername, - b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L) + // This is tricky here: we always need the visitInfo because it's not only used in privilege checks, and we + // must pass the table name. However, the privilege check is towards the database. We'll deal with it later. + if v.TemporaryKeyword == ast.TemporaryLocal { + authErr = ErrDBaccessDenied.GenWithStackByArgs(b.ctx.GetSessionVars().User.AuthUsername, + b.ctx.GetSessionVars().User.AuthHostname, v.Table.Schema.L) + } else { + authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername, + b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L) + } } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreatePriv, v.Table.Schema.L, v.Table.Name.L, "", authErr) @@ -3936,7 +3943,7 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err "", "", authErr) case *ast.DropIndexStmt: if b.ctx.GetSessionVars().User != nil { - authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEx", b.ctx.GetSessionVars().User.AuthUsername, + authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEX", b.ctx.GetSessionVars().User.AuthUsername, b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.IndexPriv, v.Table.Schema.L, diff --git a/planner/optimize.go b/planner/optimize.go index 67cde3f459b56..e8536e8dbed34 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -345,7 +345,8 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in // we need the table information to check privilege, which is collected // into the visitInfo in the logical plan builder. if pm := privilege.GetPrivilegeManager(sctx); pm != nil { - if err := plannercore.CheckPrivilege(activeRoles, pm, builder.GetVisitInfo()); err != nil { + visitInfo := plannercore.VisitInfo4PrivCheck(is, node, builder.GetVisitInfo()) + if err := plannercore.CheckPrivilege(activeRoles, pm, visitInfo); err != nil { return nil, nil, 0, err } } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 66da786ec9d29..a85a33e41eda9 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -26,6 +26,7 @@ import ( "strings" "testing" + "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/auth" @@ -2666,7 +2667,7 @@ func TestGrantCreateTmpTables(t *testing.T) { tk.MustExec("CREATE TABLE create_tmp_table_table (a int)") tk.MustExec("GRANT CREATE TEMPORARY TABLES on create_tmp_table_db.* to u1") tk.MustExec("GRANT CREATE TEMPORARY TABLES on *.* to u1") - // Must set a session user to avoid null pointer dereferencing + // Must set a session user to avoid null pointer dereference tk.Session().Auth(&auth.UserIdentity{ Username: "root", Hostname: "localhost", @@ -2678,6 +2679,164 @@ func TestGrantCreateTmpTables(t *testing.T) { tk.MustExec("DROP DATABASE create_tmp_table_db") } +func TestCreateTmpTablesPriv(t *testing.T) { + t.Parallel() + store, clean := newStore(t) + defer clean() + + createStmt := "CREATE TEMPORARY TABLE test.tmp(id int)" + dropStmt := "DROP TEMPORARY TABLE IF EXISTS test.tmp" + + tk := testkit.NewTestKit(t, store) + tk.MustExec(dropStmt) + tk.MustExec("CREATE TABLE test.t(id int primary key)") + tk.MustExec("CREATE SEQUENCE test.tmp") + tk.MustExec("CREATE USER vcreate, vcreate_tmp, vcreate_tmp_all") + tk.MustExec("GRANT CREATE, USAGE ON test.* TO vcreate") + tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON test.* TO vcreate_tmp") + tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON *.* TO vcreate_tmp_all") + + tk.Session().Auth(&auth.UserIdentity{Username: "vcreate", Hostname: "localhost"}, nil, nil) + err := tk.ExecToErr(createStmt) + require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate'@'%' to database 'test'") + tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil) + tk.MustExec(createStmt) + tk.MustExec(dropStmt) + tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp_all", Hostname: "localhost"}, nil, nil) + // TODO: issue #29280 to be fixed. + //err = tk.ExecToErr(createStmt) + //require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate_tmp_all'@'%' to database 'test'") + + tests := []struct { + sql string + errcode int + }{ + { + sql: "create temporary table tmp(id int primary key)", + }, + { + sql: "insert into tmp value(1)", + }, + { + sql: "insert into tmp value(1) on duplicate key update id=1", + }, + { + sql: "replace tmp values(1)", + }, + { + sql: "insert into tmp select * from t", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "update tmp set id=1 where id=1", + }, + { + sql: "update tmp t1, t t2 set t1.id=t2.id where t1.id=t2.id", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "delete from tmp where id=1", + }, + { + sql: "delete t1 from tmp t1 join t t2 where t1.id=t2.id", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "select * from tmp where id=1", + }, + { + sql: "select * from tmp where id in (1,2)", + }, + { + sql: "select * from tmp", + }, + { + sql: "select * from tmp join t where tmp.id=t.id", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "(select * from tmp) union (select * from t)", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "create temporary table tmp1 like t", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "create table tmp(id int primary key)", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "create table t(id int primary key)", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "analyze table tmp", + }, + { + sql: "analyze table tmp, t", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "show create table tmp", + }, + // TODO: issue #29281 to be fixed. + //{ + // sql: "show create table t", + // errcode: mysql.ErrTableaccessDenied, + //}, + { + sql: "drop sequence tmp", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "alter table tmp add column c1 char(10)", + errcode: errno.ErrUnsupportedDDLOperation, + }, + { + sql: "truncate table tmp", + }, + { + sql: "drop temporary table t", + errcode: mysql.ErrBadTable, + }, + { + sql: "drop table t", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "drop table t, tmp", + errcode: mysql.ErrTableaccessDenied, + }, + { + sql: "drop temporary table tmp", + }, + } + + tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil) + tk.MustExec("use test") + tk.MustExec(dropStmt) + for _, test := range tests { + if test.errcode == 0 { + tk.MustExec(test.sql) + } else { + tk.MustGetErrCode(test.sql, test.errcode) + } + } + + // TODO: issue #29282 to be fixed. + //for i, test := range tests { + // preparedStmt := fmt.Sprintf("prepare stmt%d from '%s'", i, test.sql) + // executeStmt := fmt.Sprintf("execute stmt%d", i) + // tk.MustExec(preparedStmt) + // if test.errcode == 0 { + // tk.MustExec(executeStmt) + // } else { + // tk.MustGetErrCode(executeStmt, test.errcode) + // } + //} +} + func TestRevokeSecondSyntax(t *testing.T) { t.Parallel() store, clean := newStore(t)