diff --git a/domain/domain_test.go b/domain/domain_test.go index fa0ab4cc4768b..fca22f7bda826 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -51,6 +51,179 @@ func sysMockFactory(dom *Domain) (pools.Resource, error) { return nil, nil } +<<<<<<< HEAD +======= +type mockEtcdBackend struct { + kv.Storage + pdAddrs []string +} + +func (mebd *mockEtcdBackend) EtcdAddrs() []string { + return mebd.pdAddrs +} +func (mebd *mockEtcdBackend) TLSConfig() *tls.Config { return nil } +func (mebd *mockEtcdBackend) StartGCWorker() error { + panic("not implemented") +} + +// ETCD use ip:port as unix socket address, however this address is invalid on windows. +// We have to skip some of the test in such case. +// https://github.com/etcd-io/etcd/blob/f0faa5501d936cd8c9f561bb9d1baca70eb67ab1/pkg/types/urls.go#L42 +func unixSocketAvailable() bool { + c, err := net.Listen("unix", "127.0.0.1:0") + if err == nil { + c.Close() + return true + } + return false +} + +func TestInfo(t *testing.T) { + if !unixSocketAvailable() { + return + } + defer testleak.AfterTestT(t)() + ddlLease := 80 * time.Millisecond + s, err := mockstore.NewMockTikvStore() + if err != nil { + t.Fatal(err) + } + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer clus.Terminate(t) + mockStore := &mockEtcdBackend{ + Storage: s, + pdAddrs: []string{clus.Members[0].GRPCAddr()}} + dom := NewDomain(mockStore, ddlLease, 0, mockFactory) + defer func() { + dom.Close() + s.Close() + }() + + cli := clus.RandClient() + dom.etcdClient = cli + // Mock new DDL and init the schema syncer with etcd client. + goCtx := context.Background() + dom.ddl = ddl.NewDDL( + goCtx, + ddl.WithEtcdClient(dom.GetEtcdClient()), + ddl.WithStore(s), + ddl.WithInfoHandle(dom.infoHandle), + ddl.WithLease(ddlLease), + ) + err = failpoint.Enable("github.com/pingcap/tidb/domain/MockReplaceDDL", `return(true)`) + if err != nil { + t.Fatal(err) + } + err = dom.Init(ddlLease, sysMockFactory) + if err != nil { + t.Fatal(err) + } + err = failpoint.Disable("github.com/pingcap/tidb/domain/MockReplaceDDL") + if err != nil { + t.Fatal(err) + } + + // Test for GetServerInfo and GetServerInfoByID. + ddlID := dom.ddl.GetID() + serverInfo, err := infosync.GetServerInfo() + if err != nil { + t.Fatal(err) + } + info, err := infosync.GetServerInfoByID(goCtx, ddlID) + if err != nil { + t.Fatal(err) + } + if serverInfo.ID != info.ID { + t.Fatalf("server self info %v, info %v", serverInfo, info) + } + _, err = infosync.GetServerInfoByID(goCtx, "not_exist_id") + if err == nil || (err != nil && err.Error() != "[info-syncer] get /tidb/server/info/not_exist_id failed") { + t.Fatal(err) + } + + // Test for GetAllServerInfo. + infos, err := infosync.GetAllServerInfo(goCtx) + if err != nil { + t.Fatal(err) + } + if len(infos) != 1 || infos[ddlID].ID != info.ID { + t.Fatalf("server one info %v, info %v", infos[ddlID], info) + } + + // Test the scene where syncer.Done() gets the information. + err = failpoint.Enable("github.com/pingcap/tidb/ddl/util/ErrorMockSessionDone", `return(true)`) + if err != nil { + t.Fatal(err) + } + <-dom.ddl.SchemaSyncer().Done() + err = failpoint.Disable("github.com/pingcap/tidb/ddl/util/ErrorMockSessionDone") + if err != nil { + t.Fatal(err) + } + time.Sleep(15 * time.Millisecond) + syncerStarted := false + for i := 0; i < 200; i++ { + if dom.SchemaValidator.IsStarted() { + syncerStarted = true + break + } + time.Sleep(5 * time.Millisecond) + } + if !syncerStarted { + t.Fatal("start syncer failed") + } + // Make sure loading schema is normal. + cs := &ast.CharsetOpt{ + Chs: "utf8", + Col: "utf8_bin", + } + ctx := mock.NewContext() + err = dom.ddl.CreateSchema(ctx, model.NewCIStr("aaa"), cs) + if err != nil { + t.Fatal(err) + } + err = dom.Reload() + if err != nil { + t.Fatal(err) + } + if dom.InfoSchema().SchemaMetaVersion() != 1 { + t.Fatalf("update schema version failed, ver %d", dom.InfoSchema().SchemaMetaVersion()) + } + + // Test for RemoveServerInfo. + dom.info.RemoveServerInfo() + infos, err = infosync.GetAllServerInfo(goCtx) + if err != nil || len(infos) != 0 { + t.Fatalf("err %v, infos %v", err, infos) + } +} + +type mockSessionManager struct { + PS []*util.ProcessInfo +} + +func (msm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { + ret := make(map[uint64]*util.ProcessInfo) + for _, item := range msm.PS { + ret[item.ID] = item + } + return ret +} + +func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { + for _, item := range msm.PS { + if item.ID == id { + return item, true + } + } + return &util.ProcessInfo{}, false +} + +func (msm *mockSessionManager) Kill(cid uint64, query bool) {} + +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {} + +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) func (*testSuite) TestT(c *C) { defer testleak.AfterTest(c)() store, err := mockstore.NewMockTikvStore() diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index cf605132e92bd..9b1096619a79a 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -15,6 +15,7 @@ package executor import ( "context" + "crypto/tls" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" @@ -62,6 +63,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testExecSuite) TestShowProcessList(c *C) { // Compose schema. names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 632874180495b..7bc150a2b2bd9 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "crypto/tls" "fmt" . "github.com/pingcap/check" @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testSuite) TestExplainFor(c *C) { tkRoot := testkit.NewTestKitWithInit(c, s.store) tkUser := testkit.NewTestKitWithInit(c, s.store) diff --git a/executor/simple.go b/executor/simple.go index a62b02430c4e9..58ab74b89fb03 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { err = e.executeUse(x) case *ast.FlushStmt: err = e.executeFlush(x) + case *ast.AlterInstanceStmt: + err = e.executeAlterInstance(x) case *ast.BeginStmt: err = e.executeBegin(ctx, x) case *ast.CommitStmt: @@ -1093,6 +1096,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { return nil } +func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error { + if s.ReloadTLS { + logutil.BgLogger().Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError)) + sm := e.ctx.GetSessionManager() + tlsCfg, err := util.LoadTLSCertificates( + variable.SysVars["ssl_ca"].Value, + variable.SysVars["ssl_key"].Value, + variable.SysVars["ssl_cert"].Value, + ) + if err != nil { + if !s.NoRollbackOnError { + return err + } + logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'") + } + sm.UpdateTLSConfig(tlsCfg) + } + return nil +} + func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { h := domain.GetDomain(e.ctx).StatsHandle() err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID) diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index e201d46783d26..e59dacf22a78b 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -321,6 +321,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {} +func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {} + func (s *testTableSuite) TestSomeTables(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index bde5ea183e54e..bfc99b6b32931 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -282,7 +282,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) { case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, - *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, + *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt, *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt: return b.buildSimple(node.(ast.StmtNode)) @@ -1385,6 +1385,15 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { p := &Simple{Statement: node} switch raw := node.(type) { +<<<<<<< HEAD +======= + case *ast.FlushStmt: + err := ErrSpecificAccessDenied.GenWithStackByArgs("RELOAD") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ReloadPriv, "", "", "", err) + case *ast.AlterInstanceStmt: + err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err) +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) case *ast.AlterUserStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) diff --git a/server/conn.go b/server/conn.go index 64896e8169a6d..b15155798e2af 100644 --- a/server/conn.go +++ b/server/conn.go @@ -495,23 +495,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } - if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil { - // The packet is a SSLRequest, let's switch to TLS. - if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil { - return err - } - // Read the following HandshakeResponse packet. - data, err = cc.readPacket() - if err != nil { - return err - } - if isOldVersion { - pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) - } else { - pos, err = parseHandshakeResponseHeader(ctx, &resp, data) - } - if err != nil { - return err + if resp.Capability&mysql.ClientSSL > 0 { + tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) + if tlsConfig != nil { + // The packet is a SSLRequest, let's switch to TLS. + if err = cc.upgradeToTLS(tlsConfig); err != nil { + return err + } + // Read the following HandshakeResponse packet. + data, err = cc.readPacket() + if err != nil { + return err + } + if isOldVersion { + pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) + } else { + pos, err = parseHandshakeResponseHeader(ctx, &resp, data) + } + if err != nil { + return err + } } } diff --git a/server/server.go b/server/server.go index 2339517125ef1..c2f8311cc57fd 100644 --- a/server/server.go +++ b/server/server.go @@ -31,13 +31,17 @@ package server import ( "context" "crypto/tls" - "crypto/x509" "fmt" "io" - "io/ioutil" "math/rand" "net" "net/http" +<<<<<<< HEAD +======= + "unsafe" + // For pprof + _ "net/http/pprof" +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) "os" "os/user" "sync" @@ -104,7 +108,7 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | // Server is the MySQL protocol server type Server struct { cfg *config.Config - tlsConfig *tls.Config + tlsConfig unsafe.Pointer // *tls.Config driver IDriver listener net.Listener socket net.Listener @@ -203,7 +207,16 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { clients: make(map[uint32]*clientConn), stopListenerCh: make(chan struct{}, 1), } - s.loadTLSCertificates() + + tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + if err != nil { + logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err)) + return nil, err + } + logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) + setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig)) + setSystemTimeZoneVariable() s.capability = defaultCapability @@ -211,8 +224,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s.capability |= mysql.ClientSSL } - var err error - if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { @@ -252,6 +263,7 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { return s, nil } +<<<<<<< HEAD func (s *Server) loadTLSCertificates() { defer func() { if s.tlsConfig != nil { @@ -298,6 +310,14 @@ func (s *Server) loadTLSCertificates() { ClientAuth: clientAuthPolicy, MinVersion: 0, } +======= +func setSSLVariable(ca, key, cert string) { + variable.SysVars["have_openssl"].Value = "YES" + variable.SysVars["have_ssl"].Value = "YES" + variable.SysVars["ssl_cert"].Value = cert + variable.SysVars["ssl_key"].Value = key + variable.SysVars["ssl_ca"].Value = ca +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) } // Run runs the server. @@ -545,6 +565,15 @@ func (s *Server) Kill(connectionID uint64, query bool) { killConn(conn) } +// UpdateTLSConfig implements the SessionManager interface. +func (s *Server) UpdateTLSConfig(cfg *tls.Config) { + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg)) +} + +func (s *Server) getTLSConfig() *tls.Config { + return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig)) +} + func killConn(conn *clientConn) { sessVars := conn.ctx.GetSessionVars() atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) diff --git a/server/server_test.go b/server/server_test.go index c7ef21ae1a3b9..bdb142c5297e9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -27,6 +27,11 @@ import ( "github.com/go-sql-driver/mysql" . "github.com/pingcap/check" +<<<<<<< HEAD +======= + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) "github.com/pingcap/log" tmysql "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" @@ -1011,11 +1016,32 @@ func runTestStmtCount(t *C) { }) } +<<<<<<< HEAD func runTestTLSConnection(t *C, overrider configOverrider) error { db, err := sql.Open("mysql", getDSN(overrider)) +======= +func (cli *testServerClient) runTestTLSConnection(t *C, overrider configOverrider) error { + dsn := cli.getDSN(overrider) + db, err := sql.Open("mysql", dsn) +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) t.Assert(err, IsNil) defer db.Close() _, err = db.Exec("USE test") + if err != nil { + return errors.Annotate(err, "dsn:"+dsn) + } + return err +} + +func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error { + db, err := sql.Open("mysql", cli.getDSN(overrider)) + t.Assert(err, IsNil) + defer db.Close() + sql := "alter instance reload tls" + if errorNoRollback { + sql += " no rollback on error" + } + _, err = db.Exec(sql) return err } diff --git a/server/tidb_test.go b/server/tidb_test.go index 00d690589fbe6..827c05db6381b 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -211,7 +211,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) { // generateCert generates a private key and a certificate in PEM format based on parameters. // If parentCert and parentCertKey is specified, the new certificate will be signed by the parentCert. // Otherwise, the new certificate will be self-signed and is a CA. -func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string) (*x509.Certificate, *rsa.PrivateKey, error) { +func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string, opts ...func(c *x509.Certificate)) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 528) if err != nil { return nil, nil, errors.Trace(err) @@ -228,6 +228,9 @@ func generateCert(sn int, commonName string, parentCert *x509.Certificate, paren ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, } + for _, opt := range opts { + opt(&template) + } var parent *x509.Certificate var priv *rsa.PrivateKey @@ -309,7 +312,11 @@ func (ts *TidbTestSuite) TestSystemTimeZone(c *C) { tk.MustQuery("select @@system_time_zone").Check(tz1) } +<<<<<<< HEAD func (ts *TidbTestSuite) TestTLS(c *C) { +======= +func (ts *tidbTestSerialSuite) TestTLS(c *C) { +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) // Generate valid TLS certificates. caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") c.Assert(err, IsNil) @@ -343,7 +350,7 @@ func (ts *TidbTestSuite) TestTLS(c *C) { time.Sleep(time.Millisecond * 100) err = runTestTLSConnection(c, connOverrider) // We should get ErrNoTLS. c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, mysql.ErrNoTLS.Error()) + c.Assert(errors.Cause(err).Error(), Equals, mysql.ErrNoTLS.Error()) server.Close() // Start the server with TLS but without CA, in this case the server will not verify client's certificate. @@ -404,6 +411,169 @@ func (ts *TidbTestSuite) TestTLS(c *C) { c.Assert(err, IsNil) runTestRegression(c, connOverrider, "TLSRegression") server.Close() + + c.Assert(util.IsTLSExpiredError(errors.New("unknown test")), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName}), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired}), IsTrue) + + _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert") + c.Assert(err, NotNil) + _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem") + c.Assert(err, NotNil) +} + +func (ts *tidbTestSerialSuite) TestReloadTLS(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload.pem", "/tmp/server-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-reload.pem", "/tmp/client-cert-reload.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-certificate-reload", "/tmp/ca-cert-reload.pem", "/tmp/client-cert-reload.pem", "/tmp/client-key-reload.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-reload.pem") + os.Remove("/tmp/ca-cert-reload.pem") + + os.Remove("/tmp/server-key-reload.pem") + os.Remove("/tmp/server-cert-reload.pem") + os.Remove("/tmp/client-key-reload.pem") + os.Remove("/tmp/client-cert-reload.pem") + }() + + // try old cert used in startup configuration. + cli := newTestServerClient() + cfg := config.NewConfig() + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-reload.pem", + SSLCert: "/tmp/server-cert-reload.pem", + SSLKey: "/tmp/server-key-reload.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + // The client provides a valid certificate. + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + // try reload a valid cert. + tlsCfg := server.getTLSConfig() + cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + oldExpireTime := cert.NotAfter + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = time.Now().Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload2.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload2.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + tlsCfg = server.getTLSConfig() + cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + newExpireTime := cert.NotAfter + c.Assert(newExpireTime.After(oldExpireTime), IsTrue) + + // try reload a expired cert. + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload3.pem", "/tmp/server-cert-reload3.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = c.NotBefore.Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload3.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload3.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, NotNil) + c.Assert(util.IsTLSExpiredError(err), IsTrue, Commentf("real error is %+v", err)) + server.Close() +} + +func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-rollback.pem", "/tmp/ca-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-rollback.pem", "/tmp/server-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-rollback.pem", "/tmp/client-cert-rollback.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-cert-rollback-test", "/tmp/ca-cert-rollback.pem", "/tmp/client-cert-rollback.pem", "/tmp/client-key-rollback.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-rollback.pem") + os.Remove("/tmp/ca-cert-rollback.pem") + + os.Remove("/tmp/server-key-rollback.pem") + os.Remove("/tmp/server-cert-rollback.pem") + os.Remove("/tmp/client-key-rollback.pem") + os.Remove("/tmp/client-cert-rollback.pem") + }() + + cli := newTestServerClient() + cfg := config.NewConfig() + cfg.Port = cli.port + cfg.Status.ReportStatus = false + + // test cannot startup with wrong tls config + cfg.Security = config.Security{ + SSLCA: "wrong path", + SSLCert: "wrong path", + SSLKey: "wrong path", + } + _, err = NewServer(cfg, ts.tidbdrv) + c.Assert(err, NotNil) + + // test reload tls fail with/without "error no rollback option" + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-rollback.pem", + SSLCert: "/tmp/server-cert-rollback.pem", + SSLKey: "/tmp/server-key-rollback.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-cert-rollback-test" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + os.Remove("/tmp/server-key-rollback.pem") + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, NotNil) + tlsCfg := server.getTLSConfig() + c.Assert(tlsCfg, NotNil) + err = cli.runReloadTLS(c, connOverrider, true) + c.Assert(err, IsNil) + tlsCfg = server.getTLSConfig() + c.Assert(tlsCfg, IsNil) } func (ts *TidbTestSuite) TestClientWithCollation(c *C) { diff --git a/util/misc.go b/util/misc.go index 5161385d0e9c9..92d1dd40375c7 100644 --- a/util/misc.go +++ b/util/misc.go @@ -16,8 +16,10 @@ package util import ( "context" "crypto/tls" + "crypto/x509" "crypto/x509/pkix" "fmt" + "io/ioutil" "runtime" "strconv" "strings" @@ -300,3 +302,67 @@ func init() { pkixTypeNameAttributes[value] = key } } +<<<<<<< HEAD +======= + +// SequenceSchema is implemented by infoSchema and used by sequence function in expression package. +// Otherwise calling information schema will cause import cycle problem. +type SequenceSchema interface { + SequenceByName(schema, sequence model.CIStr) (SequenceTable, error) +} + +// SequenceTable is implemented by tableCommon, and it is specialised in handling sequence operation. +// Otherwise calling table will cause import cycle problem. +type SequenceTable interface { + GetSequenceID() int64 + GetSequenceNextVal(ctx interface{}, dbName, seqName string) (int64, error) + SetSequenceVal(ctx interface{}, newVal int64, dbName, seqName string) (int64, bool, error) +} + +// LoadTLSCertificates loads CA/KEY/CERT for special paths. +func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error) { + if len(cert) == 0 || len(key) == 0 { + return + } + + var tlsCert tls.Certificate + tlsCert, err = tls.LoadX509KeyPair(cert, key) + if err != nil { + logutil.BgLogger().Warn("load x509 failed", zap.Error(err)) + err = errors.Trace(err) + return + } + + // Try loading CA cert. + clientAuthPolicy := tls.NoClientCert + var certPool *x509.CertPool + if len(ca) > 0 { + var caCert []byte + caCert, err = ioutil.ReadFile(ca) + if err != nil { + logutil.BgLogger().Warn("read file failed", zap.Error(err)) + err = errors.Trace(err) + return + } + certPool = x509.NewCertPool() + if certPool.AppendCertsFromPEM(caCert) { + clientAuthPolicy = tls.VerifyClientCertIfGiven + } + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + ClientCAs: certPool, + ClientAuth: clientAuthPolicy, + } + return +} + +// IsTLSExpiredError checks error is caused by TLS expired. +func IsTLSExpiredError(err error) bool { + err = errors.Cause(err) + if inval, ok := err.(x509.CertificateInvalidError); !ok || inval.Reason != x509.Expired { + return false + } + return true +} +>>>>>>> 5c68d53... *: support reload tls used by mysql protocol in place (#14749) diff --git a/util/processinfo.go b/util/processinfo.go index b09edf810b184..959855ae3e3bc 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -14,6 +14,7 @@ package util import ( + "crypto/tls" "fmt" "time" @@ -76,4 +77,5 @@ type SessionManager interface { ShowProcessList() map[uint64]*ProcessInfo GetProcessInfo(id uint64) (*ProcessInfo, bool) Kill(connectionID uint64, query bool) + UpdateTLSConfig(cfg *tls.Config) }