diff --git a/pkg/bootstrap/versions/upgrade_strategy.go b/pkg/bootstrap/versions/upgrade_strategy.go index da647885e9999..be3ecb8a54895 100644 --- a/pkg/bootstrap/versions/upgrade_strategy.go +++ b/pkg/bootstrap/versions/upgrade_strategy.go @@ -25,6 +25,16 @@ import ( "github.com/matrixorigin/matrixone/pkg/util/executor" ) +const ( + // system tenant default role ID and user ID + sysAdminRoleID = 0 + sysRootID = 0 + + // general tenant default role ID and user ID + accountAdminRoleID = 2 + accountAdminUserID = 2 +) + const ( T_any = "ANY" T_bool = "BOOL" @@ -131,6 +141,13 @@ type UpgradeEntry struct { // Upgrade entity execution upgrade entrance func (u *UpgradeEntry) Upgrade(txn executor.TxnExecutor, accountId uint32) error { + userId := uint32(sysRootID) + roleId := uint32(sysAdminRoleID) + if accountId != catalog.System_Account { + userId = accountAdminUserID + roleId = accountAdminRoleID + } + exist, err := u.CheckFunc(txn, accountId) if err != nil { getLogger(txn.Txn().TxnOptions().CN).Error("execute upgrade entry check error", zap.Error(err), zap.String("upgrade entry", u.String())) @@ -142,7 +159,7 @@ func (u *UpgradeEntry) Upgrade(txn executor.TxnExecutor, accountId uint32) error } else { // 1. First, judge whether there is prefix sql if u.PreSql != "" { - res, err := txn.Exec(u.PreSql, executor.StatementOption{}.WithAccountID(accountId)) + res, err := txn.Exec(u.PreSql, executor.StatementOption{}.WithAccountID(accountId).WithUserID(userId).WithRoleID(roleId)) if err != nil { getLogger(txn.Txn().TxnOptions().CN).Error("execute upgrade entry pre-sql error", zap.Error(err), zap.String("upgrade entry", u.String())) return err @@ -151,7 +168,7 @@ func (u *UpgradeEntry) Upgrade(txn executor.TxnExecutor, accountId uint32) error } // 2. Second, Execute upgrade sql - res, err := txn.Exec(u.UpgSql, executor.StatementOption{}.WithAccountID(accountId)) + res, err := txn.Exec(u.UpgSql, executor.StatementOption{}.WithAccountID(accountId).WithUserID(userId).WithRoleID(roleId)) if err != nil { getLogger(txn.Txn().TxnOptions().CN).Error("execute upgrade entry sql error", zap.Error(err), zap.String("upgrade entry", u.String())) return err @@ -160,7 +177,7 @@ func (u *UpgradeEntry) Upgrade(txn executor.TxnExecutor, accountId uint32) error // 2. Third, after the upgrade is completed, judge whether there is post-sql if u.PostSql != "" { - res, err = txn.Exec(u.PostSql, executor.StatementOption{}.WithAccountID(accountId)) + res, err = txn.Exec(u.PostSql, executor.StatementOption{}.WithAccountID(accountId).WithUserID(userId).WithRoleID(roleId)) if err != nil { getLogger(txn.Txn().TxnOptions().CN).Error("execute upgrade entry post-sql error", zap.Error(err), zap.String("upgrade entry", u.String())) return err diff --git a/pkg/bootstrap/versions/v2_0_1/upgrade_test.go b/pkg/bootstrap/versions/v2_0_1/upgrade_test.go index c5c1766580fba..5488911b2db85 100644 --- a/pkg/bootstrap/versions/v2_0_1/upgrade_test.go +++ b/pkg/bootstrap/versions/v2_0_1/upgrade_test.go @@ -215,8 +215,8 @@ func Test_HandleTenantUpgrade(t *testing.T) { return executor.Result{}, nil }, txnOperator) - upg_mo_user_add_password_last_changed.Upgrade(executor, uint32(0)) - upg_mo_user_add_lock_time.Upgrade(executor, uint32(0)) + upg_mo_user_add_password_last_changed.Upgrade(executor, uint32(1)) + upg_mo_user_add_lock_time.Upgrade(executor, uint32(1)) }, ) runtime.RunTest( diff --git a/pkg/sql/compile/sql_executor.go b/pkg/sql/compile/sql_executor.go index c88bf3f483fe8..f1867fd4590e3 100644 --- a/pkg/sql/compile/sql_executor.go +++ b/pkg/sql/compile/sql_executor.go @@ -22,7 +22,6 @@ import ( "go.uber.org/zap" - "github.com/matrixorigin/matrixone/pkg/catalog" "github.com/matrixorigin/matrixone/pkg/common/buffer" "github.com/matrixorigin/matrixone/pkg/common/moerr" "github.com/matrixorigin/matrixone/pkg/common/mpool" @@ -241,27 +240,26 @@ func (exec *txnExecutor) Exec( sql string, statementOption executor.StatementOption, ) (executor.Result, error) { - - //----------------------------------------------------------------------------------------- // NOTE: This code is to restore tenantID information in the Context when temporarily switching tenants // so that it can be restored to its original state after completing the task. - recoverAccount := func(exec *txnExecutor, accId uint32) { - exec.ctx = context.WithValue(exec.ctx, defines.TenantIDKey{}, accId) - } - + var originCtx context.Context if statementOption.HasAccountID() { - originAccountID := catalog.System_Account - if v := exec.ctx.Value(defines.TenantIDKey{}); v != nil { - originAccountID = v.(uint32) + // save the current context + originCtx = exec.ctx + // switch tenantID + exec.ctx = context.WithValue(exec.ctx, defines.TenantIDKey{}, statementOption.AccountID()) + if statementOption.HasUserID() { + exec.ctx = context.WithValue(exec.ctx, defines.UserIDKey{}, statementOption.UserID()) } - - exec.ctx = context.WithValue(exec.ctx, - defines.TenantIDKey{}, - statementOption.AccountID()) - // NOTE: Restore AccountID information in context.Context - defer recoverAccount(exec, originAccountID) + if statementOption.HasRoleID() { + exec.ctx = context.WithValue(exec.ctx, defines.RoleIDKey{}, statementOption.RoleID()) + } + defer func() { + // restore context at the end of the function + exec.ctx = originCtx + }() } - //----------------------------------------------------------------------------------------- + //------------------------------------------------------------------------------------------------------------------ if statementOption.IgnoreForeignKey() { exec.ctx = context.WithValue(exec.ctx, defines.IgnoreForeignKey{}, diff --git a/pkg/sql/compile/sql_executor_context_test.go b/pkg/sql/compile/sql_executor_context_test.go index 44096b6dade29..1ac3dd0ab4254 100644 --- a/pkg/sql/compile/sql_executor_context_test.go +++ b/pkg/sql/compile/sql_executor_context_test.go @@ -17,12 +17,12 @@ package compile import ( "testing" - "github.com/matrixorigin/matrixone/pkg/common/mpool" - "github.com/matrixorigin/matrixone/pkg/testutil" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "github.com/matrixorigin/matrixone/pkg/common/mpool" + "github.com/matrixorigin/matrixone/pkg/testutil" + mock_frontend "github.com/matrixorigin/matrixone/pkg/frontend/test" "github.com/matrixorigin/matrixone/pkg/sql/plan" ) diff --git a/pkg/tests/txnexecutor/txn_executor_test.go b/pkg/tests/txnexecutor/txn_executor_test.go new file mode 100644 index 0000000000000..3d9a683119db1 --- /dev/null +++ b/pkg/tests/txnexecutor/txn_executor_test.go @@ -0,0 +1,49 @@ +// Copyright 2024 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txnexecutor + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/matrixorigin/matrixone/pkg/embed" + "github.com/matrixorigin/matrixone/pkg/tests/testutils" + "github.com/matrixorigin/matrixone/pkg/util/executor" +) + +func Test_TxnExecutorExec(t *testing.T) { + c, err := embed.NewCluster(embed.WithCNCount(1)) + require.NoError(t, err) + require.NoError(t, c.Start()) + + svc, err := c.GetCNService(0) + require.NoError(t, err) + + exec := testutils.GetSQLExecutor(svc) + require.NotNil(t, exec) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + err = exec.ExecTxn(ctx, func(txn executor.TxnExecutor) error { + _, err = txn.Exec("select count(*) from mo_catalog.mo_tables", executor.StatementOption{}.WithAccountID(1).WithUserID(2).WithRoleID(2)) + require.NoError(t, err) + return nil + }, executor.Options{}.WithWaitCommittedLogApplied()) + require.NoError(t, err) +} diff --git a/pkg/util/executor/options.go b/pkg/util/executor/options.go index 2c601a6a5ec6f..d89ead9c62efe 100644 --- a/pkg/util/executor/options.go +++ b/pkg/util/executor/options.go @@ -169,6 +169,32 @@ func (opts StatementOption) HasAccountID() bool { return opts.accountId > 0 } +func (opts StatementOption) WithRoleID(roleID uint32) StatementOption { + opts.roleId = roleID + return opts +} + +func (opts StatementOption) RoleID() uint32 { + return opts.roleId +} + +func (opts StatementOption) HasRoleID() bool { + return opts.roleId > 0 +} + +func (opts StatementOption) WithUserID(userID uint32) StatementOption { + opts.userId = userID + return opts +} + +func (opts StatementOption) UserID() uint32 { + return opts.userId +} + +func (opts StatementOption) HasUserID() bool { + return opts.userId > 0 +} + func (opts StatementOption) WithDisableLog() StatementOption { opts.disableLog = true return opts diff --git a/pkg/util/executor/types.go b/pkg/util/executor/types.go index 09709fb4b77ef..1aca13449ac26 100644 --- a/pkg/util/executor/types.go +++ b/pkg/util/executor/types.go @@ -71,6 +71,8 @@ type Options struct { type StatementOption struct { waitPolicy lock.WaitPolicy accountId uint32 + roleId uint32 + userId uint32 disableLog bool ignoreForeignKey bool }