From a7cf662124e4d7158e0ea4002753dd4f9fe13f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Mon, 17 Jun 2024 15:28:15 +0800 Subject: [PATCH 1/2] This is an automated cherry-pick of #54032 Signed-off-by: ti-chi-bot --- pkg/server/tests/servertestkit/testkit.go | 170 ++++++++++++++++++++++ privilege/privileges/privileges.go | 35 +++-- server/conn.go | 13 +- server/mock_conn.go | 10 ++ server/tidb_test.go | 73 ++++++++++ 5 files changed, 289 insertions(+), 12 deletions(-) create mode 100644 pkg/server/tests/servertestkit/testkit.go diff --git a/pkg/server/tests/servertestkit/testkit.go b/pkg/server/tests/servertestkit/testkit.go new file mode 100644 index 0000000000000..d1efdf4cd37f3 --- /dev/null +++ b/pkg/server/tests/servertestkit/testkit.go @@ -0,0 +1,170 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package servertestkit + +import ( + "context" + "database/sql" + "sync" + "testing" + + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + srv "github.com/pingcap/tidb/pkg/server" + "github.com/pingcap/tidb/pkg/server/internal/testserverclient" + "github.com/pingcap/tidb/pkg/server/internal/testutil" + "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/store/mockstore" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/util/cpuprofile" + "github.com/pingcap/tidb/pkg/util/topsql/collector/mock" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/stretchr/testify/require" + "go.opencensus.io/stats/view" +) + +// TidbTestSuite is a test suite for tidb +type TidbTestSuite struct { + *testserverclient.TestServerClient + Tidbdrv *srv.TiDBDriver + Server *srv.Server + Domain *domain.Domain + Store kv.Storage +} + +// CreateTidbTestSuite creates a test suite for tidb +func CreateTidbTestSuite(t *testing.T) *TidbTestSuite { + cfg := util.NewTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = 0 + cfg.Status.RecordDBLabel = true + cfg.Performance.TCPKeepAlive = true + return CreateTidbTestSuiteWithCfg(t, cfg) +} + +// CreateTidbTestSuiteWithCfg creates a test suite for tidb with config +func CreateTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *TidbTestSuite { + ts := &TidbTestSuite{TestServerClient: testserverclient.NewTestServerClient()} + + // setup tidbTestSuite + var err error + ts.Store, err = mockstore.NewMockStore() + session.DisableStats4Test() + require.NoError(t, err) + ts.Domain, err = session.BootstrapSession(ts.Store) + require.NoError(t, err) + ts.Tidbdrv = srv.NewTiDBDriver(ts.Store) + + srv.RunInGoTestChan = make(chan struct{}) + server, err := srv.NewServer(cfg, ts.Tidbdrv) + require.NoError(t, err) + + ts.Server = server + ts.Server.SetDomain(ts.Domain) + ts.Domain.InfoSyncer().SetSessionManager(ts.Server) + go func() { + err := ts.Server.Run(nil) + require.NoError(t, err) + }() + <-srv.RunInGoTestChan + ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) + ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) + ts.WaitUntilServerOnline() + + t.Cleanup(func() { + if ts.Domain != nil { + ts.Domain.Close() + } + if ts.Server != nil { + ts.Server.Close() + } + if ts.Store != nil { + require.NoError(t, ts.Store.Close()) + } + view.Stop() + }) + return ts +} + +type tidbTestTopSQLSuite struct { + *TidbTestSuite +} + +// CreateTidbTestTopSQLSuite creates a test suite for top-sql test. +func CreateTidbTestTopSQLSuite(t *testing.T) *tidbTestTopSQLSuite { + base := CreateTidbTestSuite(t) + + ts := &tidbTestTopSQLSuite{base} + + // Initialize global variable for top-sql test. + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + dbt := testkit.NewDBTestKit(t, db) + topsqlstate.GlobalState.PrecisionSeconds.Store(1) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(2) + dbt.MustExec("set @@global.tidb_top_sql_max_time_series_count=5;") + + require.NoError(t, cpuprofile.StartCPUProfiler()) + t.Cleanup(func() { + cpuprofile.StopCPUProfiler() + topsqlstate.GlobalState.PrecisionSeconds.Store(topsqlstate.DefTiDBTopSQLPrecisionSeconds) + topsqlstate.GlobalState.ReportIntervalSeconds.Store(topsqlstate.DefTiDBTopSQLReportIntervalSeconds) + view.Stop() + }) + return ts +} + +// TestCase is to run the test case for top-sql test. +func (ts *tidbTestTopSQLSuite) TestCase(t *testing.T, mc *mock.TopSQLCollector, execFn func(db *sql.DB), checkFn func()) { + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + wg.Add(1) + go func() { + defer wg.Done() + ts.loopExec(ctx, t, execFn) + }() + + checkFn() + cancel() + wg.Wait() + mc.Reset() +} + +func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { + db, err := sql.Open("mysql", ts.GetDSN()) + require.NoError(t, err, "Error connecting") + defer func() { + err := db.Close() + require.NoError(t, err) + }() + dbt := testkit.NewDBTestKit(t, db) + dbt.MustExec("use topsql;") + for { + select { + case <-ctx.Done(): + return + default: + } + fn(db) + } +} diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 544a0e21ac259..63827dae590e4 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -542,6 +542,28 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse logutil.BgLogger().Error("check claims failed", zap.Error(err)) return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } +<<<<<<< HEAD:privilege/privileges/privileges.go +======= + } else if record.AuthPlugin == mysql.AuthLDAPSASL { + if err = ldap.LDAPSASLAuthImpl.AuthLDAPSASL(authUser, pwd, authentication, authConn); err != nil { + // though the pwd stores only `dn` for LDAP SASL, it could be unsafe to print it out. + // for example, someone may alter the auth plugin name but forgot to change the password... + logutil.BgLogger().Warn("verify through LDAP SASL failed", zap.String("username", user.Username), zap.Error(err)) + return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } + } else if record.AuthPlugin == mysql.AuthLDAPSimple { + if err = ldap.LDAPSimpleAuthImpl.AuthLDAPSimple(authUser, pwd, authentication); err != nil { + logutil.BgLogger().Warn("verify through LDAP Simple failed", zap.String("username", user.Username), zap.Error(err)) + return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } + } else if record.AuthPlugin == mysql.AuthSocket { + if string(authentication) != authUser && string(authentication) != pwd { + logutil.BgLogger().Error("Failed socket auth", zap.String("authUser", authUser), + zap.String("socket_user", string(authentication)), + zap.String("authentication_string", pwd)) + return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } +>>>>>>> 72d22d60fca (privilege: fix `auth_socket` bug, should only allow os user name to login (#54032)):pkg/privilege/privileges/privileges.go } else if len(pwd) > 0 && len(authentication) > 0 { switch record.AuthPlugin { // NOTE: If the checking of the clear-text password fails, please set `info.FailedDueToWrongPassword = true`. @@ -567,22 +589,13 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse info.FailedDueToWrongPassword = true return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } - case mysql.AuthSocket: - if string(authentication) != authUser && string(authentication) != pwd { - logutil.BgLogger().Error("Failed socket auth", zap.String("authUser", authUser), - zap.String("socket_user", string(authentication)), - zap.String("authentication_string", pwd)) - return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } default: logutil.BgLogger().Error("unknown authentication plugin", zap.String("authUser", authUser), zap.String("plugin", record.AuthPlugin)) return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } } else if len(pwd) > 0 || len(authentication) > 0 { - if record.AuthPlugin != mysql.AuthSocket { - info.FailedDueToWrongPassword = true - return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } + info.FailedDueToWrongPassword = true + return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } // Login a locked account is not allowed. diff --git a/server/conn.go b/server/conn.go index b65538e999f88..d7ee24f532c00 100644 --- a/server/conn.go +++ b/server/conn.go @@ -892,6 +892,9 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e return nil } +// mockOSUserForAuthSocketTest should only be used in test +var mockOSUserForAuthSocketTest atomic.Pointer[string] + // Check if the Authentication Plugin of the server, client and user configuration matches func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) { // Open a context unless this was done before. @@ -940,7 +943,15 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon if err != nil { return nil, err } - return []byte(user.Username), nil + uname := user.Username + + if intest.InTest { + if p := mockOSUserForAuthSocketTest.Load(); p != nil { + uname = *p + } + } + + return []byte(uname), nil } if len(userplugin) == 0 { // No user plugin set, assuming MySQL Native Password diff --git a/server/mock_conn.go b/server/mock_conn.go index 36ad6503db812..7bd5fa53f11e2 100644 --- a/server/mock_conn.go +++ b/server/mock_conn.go @@ -127,3 +127,13 @@ func CreateMockConn(t *testing.T, server *Server) MockConn { t: t, } } + +// MockOSUserForAuthSocket mocks the OS user for AUTH_SOCKET plugin +func MockOSUserForAuthSocket(uname string) { + mockOSUserForAuthSocketTest.Store(&uname) +} + +// ClearOSUserForAuthSocket clears the mocked OS user for AUTH_SOCKET plugin +func ClearOSUserForAuthSocket() { + mockOSUserForAuthSocketTest.Store(nil) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index 259942d929a52..a7e81db9cdd35 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -34,6 +34,10 @@ import ( "net/http" "os" "path/filepath" +<<<<<<< HEAD:server/tidb_test.go +======= + "runtime" +>>>>>>> 72d22d60fca (privilege: fix `auth_socket` bug, should only allow os user name to login (#54032)):pkg/server/tests/commontest/tidb_test.go "strings" "sync" "sync/atomic" @@ -3229,3 +3233,72 @@ func TestLoadData(t *testing.T) { ts.runTestLoadDataReplace(t) ts.runTestLoadDataReplaceNonclusteredPK(t) } + +func TestAuthSocket(t *testing.T) { + defer server2.ClearOSUserForAuthSocket() + + cfg := util2.NewTestConfig() + cfg.Socket = filepath.Join(t.TempDir(), "authsock.sock") + cfg.Port = 0 + cfg.Status.StatusPort = 0 + ts := servertestkit.CreateTidbTestSuiteWithCfg(t, cfg) + ts.WaitUntilServerCanConnect() + + ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("CREATE USER 'u1'@'%' IDENTIFIED WITH auth_socket;") + dbt.MustExec("CREATE USER 'u2'@'%' IDENTIFIED WITH auth_socket AS 'sockuser'") + dbt.MustExec("CREATE USER 'sockuser'@'%' IDENTIFIED WITH auth_socket;") + }) + + // network login should be denied + for _, uname := range []string{"u1", "u2", "u3"} { + server2.MockOSUserForAuthSocket(uname) + db, err := sql.Open("mysql", ts.GetDSN(func(config *mysql.Config) { + config.User = uname + })) + require.NoError(t, err) + _, err = db.Conn(context.TODO()) + require.EqualError(t, + err, + fmt.Sprintf("Error 1045 (28000): Access denied for user '%s'@'127.0.0.1' (using password: NO)", uname), + ) + require.NoError(t, db.Close()) + } + + socketAuthConf := func(user string) func(*mysql.Config) { + return func(config *mysql.Config) { + config.User = user + config.Net = "unix" + config.Addr = cfg.Socket + config.DBName = "" + } + } + + server2.MockOSUserForAuthSocket("sockuser") + + // mysql username that is different with the OS user should be rejected. + db, err := sql.Open("mysql", ts.GetDSN(socketAuthConf("u1"))) + require.NoError(t, err) + _, err = db.Conn(context.TODO()) + require.EqualError(t, err, "Error 1045 (28000): Access denied for user 'u1'@'localhost' (using password: YES)") + require.NoError(t, db.Close()) + + // mysql username that is the same with the OS user should be accepted. + ts.RunTests(t, socketAuthConf("sockuser"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "sockuser@%") + }) + + // When a user is created with `IDENTIFIED WITH auth_socket AS ...`. + // It should be accepted when username or as string is the same with OS user. + ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "u2@%") + }) + + server2.MockOSUserForAuthSocket("u2") + ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + rows := dbt.MustQuery("select current_user();") + ts.CheckRows(t, rows, "u2@%") + }) +} From efff1d5e24756e180752f53a0d13eeed9466b48d Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Mon, 17 Jun 2024 18:25:15 +0800 Subject: [PATCH 2/2] resolve conflict --- pkg/server/tests/servertestkit/testkit.go | 170 ---------------------- privilege/privileges/privileges.go | 15 -- server/conn.go | 4 +- server/mock_conn.go | 10 -- server/tidb_test.go | 60 ++++---- 5 files changed, 35 insertions(+), 224 deletions(-) delete mode 100644 pkg/server/tests/servertestkit/testkit.go diff --git a/pkg/server/tests/servertestkit/testkit.go b/pkg/server/tests/servertestkit/testkit.go deleted file mode 100644 index d1efdf4cd37f3..0000000000000 --- a/pkg/server/tests/servertestkit/testkit.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package servertestkit - -import ( - "context" - "database/sql" - "sync" - "testing" - - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - srv "github.com/pingcap/tidb/pkg/server" - "github.com/pingcap/tidb/pkg/server/internal/testserverclient" - "github.com/pingcap/tidb/pkg/server/internal/testutil" - "github.com/pingcap/tidb/pkg/server/internal/util" - "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/store/mockstore" - "github.com/pingcap/tidb/pkg/testkit" - "github.com/pingcap/tidb/pkg/util/cpuprofile" - "github.com/pingcap/tidb/pkg/util/topsql/collector/mock" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/stretchr/testify/require" - "go.opencensus.io/stats/view" -) - -// TidbTestSuite is a test suite for tidb -type TidbTestSuite struct { - *testserverclient.TestServerClient - Tidbdrv *srv.TiDBDriver - Server *srv.Server - Domain *domain.Domain - Store kv.Storage -} - -// CreateTidbTestSuite creates a test suite for tidb -func CreateTidbTestSuite(t *testing.T) *TidbTestSuite { - cfg := util.NewTestConfig() - cfg.Port = 0 - cfg.Status.ReportStatus = true - cfg.Status.StatusPort = 0 - cfg.Status.RecordDBLabel = true - cfg.Performance.TCPKeepAlive = true - return CreateTidbTestSuiteWithCfg(t, cfg) -} - -// CreateTidbTestSuiteWithCfg creates a test suite for tidb with config -func CreateTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *TidbTestSuite { - ts := &TidbTestSuite{TestServerClient: testserverclient.NewTestServerClient()} - - // setup tidbTestSuite - var err error - ts.Store, err = mockstore.NewMockStore() - session.DisableStats4Test() - require.NoError(t, err) - ts.Domain, err = session.BootstrapSession(ts.Store) - require.NoError(t, err) - ts.Tidbdrv = srv.NewTiDBDriver(ts.Store) - - srv.RunInGoTestChan = make(chan struct{}) - server, err := srv.NewServer(cfg, ts.Tidbdrv) - require.NoError(t, err) - - ts.Server = server - ts.Server.SetDomain(ts.Domain) - ts.Domain.InfoSyncer().SetSessionManager(ts.Server) - go func() { - err := ts.Server.Run(nil) - require.NoError(t, err) - }() - <-srv.RunInGoTestChan - ts.Port = testutil.GetPortFromTCPAddr(server.ListenAddr()) - ts.StatusPort = testutil.GetPortFromTCPAddr(server.StatusListenerAddr()) - ts.WaitUntilServerOnline() - - t.Cleanup(func() { - if ts.Domain != nil { - ts.Domain.Close() - } - if ts.Server != nil { - ts.Server.Close() - } - if ts.Store != nil { - require.NoError(t, ts.Store.Close()) - } - view.Stop() - }) - return ts -} - -type tidbTestTopSQLSuite struct { - *TidbTestSuite -} - -// CreateTidbTestTopSQLSuite creates a test suite for top-sql test. -func CreateTidbTestTopSQLSuite(t *testing.T) *tidbTestTopSQLSuite { - base := CreateTidbTestSuite(t) - - ts := &tidbTestTopSQLSuite{base} - - // Initialize global variable for top-sql test. - db, err := sql.Open("mysql", ts.GetDSN()) - require.NoError(t, err) - defer func() { - err := db.Close() - require.NoError(t, err) - }() - - dbt := testkit.NewDBTestKit(t, db) - topsqlstate.GlobalState.PrecisionSeconds.Store(1) - topsqlstate.GlobalState.ReportIntervalSeconds.Store(2) - dbt.MustExec("set @@global.tidb_top_sql_max_time_series_count=5;") - - require.NoError(t, cpuprofile.StartCPUProfiler()) - t.Cleanup(func() { - cpuprofile.StopCPUProfiler() - topsqlstate.GlobalState.PrecisionSeconds.Store(topsqlstate.DefTiDBTopSQLPrecisionSeconds) - topsqlstate.GlobalState.ReportIntervalSeconds.Store(topsqlstate.DefTiDBTopSQLReportIntervalSeconds) - view.Stop() - }) - return ts -} - -// TestCase is to run the test case for top-sql test. -func (ts *tidbTestTopSQLSuite) TestCase(t *testing.T, mc *mock.TopSQLCollector, execFn func(db *sql.DB), checkFn func()) { - var wg sync.WaitGroup - ctx, cancel := context.WithCancel(context.Background()) - wg.Add(1) - go func() { - defer wg.Done() - ts.loopExec(ctx, t, execFn) - }() - - checkFn() - cancel() - wg.Wait() - mc.Reset() -} - -func (ts *tidbTestTopSQLSuite) loopExec(ctx context.Context, t *testing.T, fn func(db *sql.DB)) { - db, err := sql.Open("mysql", ts.GetDSN()) - require.NoError(t, err, "Error connecting") - defer func() { - err := db.Close() - require.NoError(t, err) - }() - dbt := testkit.NewDBTestKit(t, db) - dbt.MustExec("use topsql;") - for { - select { - case <-ctx.Done(): - return - default: - } - fn(db) - } -} diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 63827dae590e4..94ca09159485c 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -542,20 +542,6 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse logutil.BgLogger().Error("check claims failed", zap.Error(err)) return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } -<<<<<<< HEAD:privilege/privileges/privileges.go -======= - } else if record.AuthPlugin == mysql.AuthLDAPSASL { - if err = ldap.LDAPSASLAuthImpl.AuthLDAPSASL(authUser, pwd, authentication, authConn); err != nil { - // though the pwd stores only `dn` for LDAP SASL, it could be unsafe to print it out. - // for example, someone may alter the auth plugin name but forgot to change the password... - logutil.BgLogger().Warn("verify through LDAP SASL failed", zap.String("username", user.Username), zap.Error(err)) - return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } - } else if record.AuthPlugin == mysql.AuthLDAPSimple { - if err = ldap.LDAPSimpleAuthImpl.AuthLDAPSimple(authUser, pwd, authentication); err != nil { - logutil.BgLogger().Warn("verify through LDAP Simple failed", zap.String("username", user.Username), zap.Error(err)) - return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } } else if record.AuthPlugin == mysql.AuthSocket { if string(authentication) != authUser && string(authentication) != pwd { logutil.BgLogger().Error("Failed socket auth", zap.String("authUser", authUser), @@ -563,7 +549,6 @@ func (p *UserPrivileges) ConnectionVerification(user *auth.UserIdentity, authUse zap.String("authentication_string", pwd)) return info, ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) } ->>>>>>> 72d22d60fca (privilege: fix `auth_socket` bug, should only allow os user name to login (#54032)):pkg/privilege/privileges/privileges.go } else if len(pwd) > 0 && len(authentication) > 0 { switch record.AuthPlugin { // NOTE: If the checking of the clear-text password fails, please set `info.FailedDueToWrongPassword = true`. diff --git a/server/conn.go b/server/conn.go index d7ee24f532c00..4599c1609d3bb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -945,11 +945,11 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon } uname := user.Username - if intest.InTest { + failpoint.Inject("MockOSUserForAuthSocket", func() { if p := mockOSUserForAuthSocketTest.Load(); p != nil { uname = *p } - } + }) return []byte(uname), nil } diff --git a/server/mock_conn.go b/server/mock_conn.go index 7bd5fa53f11e2..36ad6503db812 100644 --- a/server/mock_conn.go +++ b/server/mock_conn.go @@ -127,13 +127,3 @@ func CreateMockConn(t *testing.T, server *Server) MockConn { t: t, } } - -// MockOSUserForAuthSocket mocks the OS user for AUTH_SOCKET plugin -func MockOSUserForAuthSocket(uname string) { - mockOSUserForAuthSocketTest.Store(&uname) -} - -// ClearOSUserForAuthSocket clears the mocked OS user for AUTH_SOCKET plugin -func ClearOSUserForAuthSocket() { - mockOSUserForAuthSocketTest.Store(nil) -} diff --git a/server/tidb_test.go b/server/tidb_test.go index a7e81db9cdd35..0394a8bdd25f0 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -34,10 +34,6 @@ import ( "net/http" "os" "path/filepath" -<<<<<<< HEAD:server/tidb_test.go -======= - "runtime" ->>>>>>> 72d22d60fca (privilege: fix `auth_socket` bug, should only allow os user name to login (#54032)):pkg/server/tests/commontest/tidb_test.go "strings" "sync" "sync/atomic" @@ -87,6 +83,15 @@ type tidbTestSuite struct { } func createTidbTestSuite(t *testing.T) *tidbTestSuite { + cfg := newTestConfig() + cfg.Port = 0 + cfg.Status.ReportStatus = true + cfg.Status.StatusPort = 0 + cfg.Performance.TCPKeepAlive = true + return createTidbTestSuiteWithCfg(t, cfg) +} + +func createTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *tidbTestSuite { ts := &tidbTestSuite{testServerClient: newTestServerClient()} // setup tidbTestSuite @@ -97,11 +102,6 @@ func createTidbTestSuite(t *testing.T) *tidbTestSuite { ts.domain, err = session.BootstrapSession(ts.store) require.NoError(t, err) ts.tidbdrv = NewTiDBDriver(ts.store) - cfg := newTestConfig() - cfg.Port = ts.port - cfg.Status.ReportStatus = true - cfg.Status.StatusPort = ts.statusPort - cfg.Performance.TCPKeepAlive = true RunInGoTestChan = make(chan struct{}) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) @@ -3235,16 +3235,20 @@ func TestLoadData(t *testing.T) { } func TestAuthSocket(t *testing.T) { - defer server2.ClearOSUserForAuthSocket() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/MockOSUserForAuthSocket", "return(true)")) + defer func() { + mockOSUserForAuthSocketTest.Store(nil) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/MockOSUserForAuthSocket")) + }() - cfg := util2.NewTestConfig() + cfg := newTestConfig() cfg.Socket = filepath.Join(t.TempDir(), "authsock.sock") cfg.Port = 0 cfg.Status.StatusPort = 0 - ts := servertestkit.CreateTidbTestSuiteWithCfg(t, cfg) - ts.WaitUntilServerCanConnect() + ts := createTidbTestSuiteWithCfg(t, cfg) + ts.waitUntilServerCanConnect() - ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + ts.runTests(t, nil, func(dbt *testkit.DBTestKit) { dbt.MustExec("CREATE USER 'u1'@'%' IDENTIFIED WITH auth_socket;") dbt.MustExec("CREATE USER 'u2'@'%' IDENTIFIED WITH auth_socket AS 'sockuser'") dbt.MustExec("CREATE USER 'sockuser'@'%' IDENTIFIED WITH auth_socket;") @@ -3252,15 +3256,15 @@ func TestAuthSocket(t *testing.T) { // network login should be denied for _, uname := range []string{"u1", "u2", "u3"} { - server2.MockOSUserForAuthSocket(uname) - db, err := sql.Open("mysql", ts.GetDSN(func(config *mysql.Config) { + mockOSUserForAuthSocketTest.Store(&uname) + db, err := sql.Open("mysql", ts.getDSN(func(config *mysql.Config) { config.User = uname })) require.NoError(t, err) _, err = db.Conn(context.TODO()) require.EqualError(t, err, - fmt.Sprintf("Error 1045 (28000): Access denied for user '%s'@'127.0.0.1' (using password: NO)", uname), + fmt.Sprintf("Error 1045: Access denied for user '%s'@'127.0.0.1' (using password: NO)", uname), ) require.NoError(t, db.Close()) } @@ -3274,31 +3278,33 @@ func TestAuthSocket(t *testing.T) { } } - server2.MockOSUserForAuthSocket("sockuser") + mockOSUser := "sockuser" + mockOSUserForAuthSocketTest.Store(&mockOSUser) // mysql username that is different with the OS user should be rejected. - db, err := sql.Open("mysql", ts.GetDSN(socketAuthConf("u1"))) + db, err := sql.Open("mysql", ts.getDSN(socketAuthConf("u1"))) require.NoError(t, err) _, err = db.Conn(context.TODO()) - require.EqualError(t, err, "Error 1045 (28000): Access denied for user 'u1'@'localhost' (using password: YES)") + require.EqualError(t, err, "Error 1045: Access denied for user 'u1'@'localhost' (using password: YES)") require.NoError(t, db.Close()) // mysql username that is the same with the OS user should be accepted. - ts.RunTests(t, socketAuthConf("sockuser"), func(dbt *testkit.DBTestKit) { + ts.runTests(t, socketAuthConf("sockuser"), func(dbt *testkit.DBTestKit) { rows := dbt.MustQuery("select current_user();") - ts.CheckRows(t, rows, "sockuser@%") + ts.checkRows(t, rows, "sockuser@%") }) // When a user is created with `IDENTIFIED WITH auth_socket AS ...`. // It should be accepted when username or as string is the same with OS user. - ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + ts.runTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { rows := dbt.MustQuery("select current_user();") - ts.CheckRows(t, rows, "u2@%") + ts.checkRows(t, rows, "u2@%") }) - server2.MockOSUserForAuthSocket("u2") - ts.RunTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { + mockOSUser = "u2" + mockOSUserForAuthSocketTest.Store(&mockOSUser) + ts.runTests(t, socketAuthConf("u2"), func(dbt *testkit.DBTestKit) { rows := dbt.MustQuery("select current_user();") - ts.CheckRows(t, rows, "u2@%") + ts.checkRows(t, rows, "u2@%") }) }