diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index cf605132e92bd..9b1096619a79a 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -15,6 +15,7 @@ package executor import ( "context" + "crypto/tls" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" @@ -62,6 +63,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testExecSuite) TestShowProcessList(c *C) { // Compose schema. names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 632874180495b..7bc150a2b2bd9 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "crypto/tls" "fmt" . "github.com/pingcap/check" @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testSuite) TestExplainFor(c *C) { tkRoot := testkit.NewTestKitWithInit(c, s.store) tkUser := testkit.NewTestKitWithInit(c, s.store) diff --git a/executor/simple.go b/executor/simple.go index b34beef93a813..b690c4a694b7f 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { err = e.executeUse(x) case *ast.FlushStmt: err = e.executeFlush(x) + case *ast.AlterInstanceStmt: + err = e.executeAlterInstance(x) case *ast.BeginStmt: err = e.executeBegin(ctx, x) case *ast.CommitStmt: @@ -1093,6 +1096,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { return nil } +func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error { + if s.ReloadTLS { + logutil.Logger(context.Background()).Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError)) + sm := e.ctx.GetSessionManager() + tlsCfg, err := util.LoadTLSCertificates( + variable.SysVars["ssl_ca"].Value, + variable.SysVars["ssl_key"].Value, + variable.SysVars["ssl_cert"].Value, + ) + if err != nil { + if !s.NoRollbackOnError { + return err + } + logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'") + } + sm.UpdateTLSConfig(tlsCfg) + } + return nil +} + func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { h := domain.GetDomain(e.ctx).StatsHandle() err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID) diff --git a/go.mod b/go.mod index 84b36787ea014..6c82d9f0f6139 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd - github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5 + github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6 github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2 github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible github.com/pingcap/tipb v0.0.0-20191120045257-1b9900292ab6 diff --git a/go.sum b/go.sum index 62dd659803480..799ebdb21b4da 100644 --- a/go.sum +++ b/go.sum @@ -155,8 +155,8 @@ github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d h1:zTHgLr8+0LTEJmj github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5 h1:r2c8RQynYNGCFDWFPgo3TNx7Roq94STRcYTrtTg3JQ4= -github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6 h1:Xm46UzfGEzxovTaj/hhIX8Q+o/mL4iB6SbwktExvMAY= +github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2 h1:NL23b8tsg6M1QpSQedK14/Jx++QeyKL2rGiBvXAQVfA= github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2/go.mod h1:b4gaAPSxaVVtaB+EHamV4Nsv8JmTdjlw0cTKmp4+dRQ= github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible h1:I8HirWsu1MZp6t9G/g8yKCEjJJxtHooKakEgccvdJ4M= @@ -183,7 +183,6 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFd github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY2EPqg2NbXKuMHs5pXJB9hjj1fDHnF2vl28= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index e201d46783d26..c9019bc129d61 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -14,6 +14,7 @@ package infoschema_test import ( + "crypto/tls" "fmt" "os" "strconv" @@ -321,6 +322,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {} +func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {} + func (s *testTableSuite) TestSomeTables(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index bde5ea183e54e..d165e30eb2fb4 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -282,7 +282,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) { case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, - *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, + *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt, *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt: return b.buildSimple(node.(ast.StmtNode)) @@ -1385,6 +1385,9 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { p := &Simple{Statement: node} switch raw := node.(type) { + case *ast.AlterInstanceStmt: + err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err) case *ast.AlterUserStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) diff --git a/server/conn.go b/server/conn.go index 64896e8169a6d..b15155798e2af 100644 --- a/server/conn.go +++ b/server/conn.go @@ -495,23 +495,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } - if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil { - // The packet is a SSLRequest, let's switch to TLS. - if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil { - return err - } - // Read the following HandshakeResponse packet. - data, err = cc.readPacket() - if err != nil { - return err - } - if isOldVersion { - pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) - } else { - pos, err = parseHandshakeResponseHeader(ctx, &resp, data) - } - if err != nil { - return err + if resp.Capability&mysql.ClientSSL > 0 { + tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) + if tlsConfig != nil { + // The packet is a SSLRequest, let's switch to TLS. + if err = cc.upgradeToTLS(tlsConfig); err != nil { + return err + } + // Read the following HandshakeResponse packet. + data, err = cc.readPacket() + if err != nil { + return err + } + if isOldVersion { + pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) + } else { + pos, err = parseHandshakeResponseHeader(ctx, &resp, data) + } + if err != nil { + return err + } } } diff --git a/server/server.go b/server/server.go index 2339517125ef1..8d4d8c314601e 100644 --- a/server/server.go +++ b/server/server.go @@ -31,10 +31,8 @@ package server import ( "context" "crypto/tls" - "crypto/x509" "fmt" "io" - "io/ioutil" "math/rand" "net" "net/http" @@ -43,6 +41,7 @@ import ( "sync" "sync/atomic" "time" + "unsafe" // For pprof _ "net/http/pprof" @@ -104,7 +103,7 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | // Server is the MySQL protocol server type Server struct { cfg *config.Config - tlsConfig *tls.Config + tlsConfig unsafe.Pointer // *tls.Config driver IDriver listener net.Listener socket net.Listener @@ -203,7 +202,15 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { clients: make(map[uint32]*clientConn), stopListenerCh: make(chan struct{}, 1), } - s.loadTLSCertificates() + + tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + if err != nil { + logutil.Logger(context.Background()).Error("secure connection cert/key/ca load fail", zap.Error(err)) + } + logutil.Logger(context.Background()).Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) + setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig)) + setSystemTimeZoneVariable() s.capability = defaultCapability @@ -211,8 +218,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s.capability |= mysql.ClientSSL } - var err error - if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { @@ -252,52 +257,12 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { return s, nil } -func (s *Server) loadTLSCertificates() { - defer func() { - if s.tlsConfig != nil { - logutil.Logger(context.Background()).Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) - variable.SysVars["have_openssl"].Value = "YES" - variable.SysVars["have_ssl"].Value = "YES" - variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert - variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey - } else { - logutil.Logger(context.Background()).Warn("secure connection is not enabled") - } - }() - - if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 { - s.tlsConfig = nil - return - } - - tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey) - if err != nil { - logutil.Logger(context.Background()).Warn("load x509 failed", zap.Error(err)) - s.tlsConfig = nil - return - } - - // Try loading CA cert. - clientAuthPolicy := tls.NoClientCert - var certPool *x509.CertPool - if len(s.cfg.Security.SSLCA) > 0 { - caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA) - if err != nil { - logutil.Logger(context.Background()).Warn("read file failed", zap.Error(err)) - } else { - certPool = x509.NewCertPool() - if certPool.AppendCertsFromPEM(caCert) { - clientAuthPolicy = tls.VerifyClientCertIfGiven - } - variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA - } - } - s.tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - ClientCAs: certPool, - ClientAuth: clientAuthPolicy, - MinVersion: 0, - } +func setSSLVariable(ca, key, cert string) { + variable.SysVars["have_openssl"].Value = "YES" + variable.SysVars["have_ssl"].Value = "YES" + variable.SysVars["ssl_cert"].Value = cert + variable.SysVars["ssl_key"].Value = key + variable.SysVars["ssl_ca"].Value = ca } // Run runs the server. @@ -545,6 +510,15 @@ func (s *Server) Kill(connectionID uint64, query bool) { killConn(conn) } +// UpdateTLSConfig implements the SessionManager interface. +func (s *Server) UpdateTLSConfig(cfg *tls.Config) { + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg)) +} + +func (s *Server) getTLSConfig() *tls.Config { + return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig)) +} + func killConn(conn *clientConn) { sessVars := conn.ctx.GetSessionVars() atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) diff --git a/server/server_test.go b/server/server_test.go index c7ef21ae1a3b9..be82317d1c0dd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -17,6 +17,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/pingcap/errors" "io/ioutil" "net/http" "os" @@ -1012,10 +1013,26 @@ func runTestStmtCount(t *C) { } func runTestTLSConnection(t *C, overrider configOverrider) error { - db, err := sql.Open("mysql", getDSN(overrider)) + dsn := getDSN(overrider) + db, err := sql.Open("mysql", dsn) t.Assert(err, IsNil) defer db.Close() _, err = db.Exec("USE test") + if err != nil { + return errors.Annotate(err, "dsn:"+dsn) + } + return err +} + +func runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error { + db, err := sql.Open("mysql", getDSN(overrider)) + t.Assert(err, IsNil) + defer db.Close() + sql := "alter instance reload tls" + if errorNoRollback { + sql += " no rollback on error" + } + _, err = db.Exec(sql) return err } diff --git a/server/tidb_test.go b/server/tidb_test.go index 00d690589fbe6..f75900698c750 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -211,7 +211,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) { // generateCert generates a private key and a certificate in PEM format based on parameters. // If parentCert and parentCertKey is specified, the new certificate will be signed by the parentCert. // Otherwise, the new certificate will be self-signed and is a CA. -func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string) (*x509.Certificate, *rsa.PrivateKey, error) { +func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string, opts ...func(c *x509.Certificate)) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 528) if err != nil { return nil, nil, errors.Trace(err) @@ -228,6 +228,9 @@ func generateCert(sn int, commonName string, parentCert *x509.Certificate, paren ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, } + for _, opt := range opts { + opt(&template) + } var parent *x509.Certificate var priv *rsa.PrivateKey @@ -343,7 +346,7 @@ func (ts *TidbTestSuite) TestTLS(c *C) { time.Sleep(time.Millisecond * 100) err = runTestTLSConnection(c, connOverrider) // We should get ErrNoTLS. c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, mysql.ErrNoTLS.Error()) + c.Assert(errors.Cause(err).Error(), Equals, mysql.ErrNoTLS.Error()) server.Close() // Start the server with TLS but without CA, in this case the server will not verify client's certificate. @@ -404,6 +407,164 @@ func (ts *TidbTestSuite) TestTLS(c *C) { c.Assert(err, IsNil) runTestRegression(c, connOverrider, "TLSRegression") server.Close() + + c.Assert(util.IsTLSExpiredError(errors.New("unknown test")), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName}), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired}), IsTrue) + + _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert") + c.Assert(err, NotNil) + _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem") + c.Assert(err, NotNil) +} + +func (ts *TidbTestSuite) TestReloadTLS(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload.pem", "/tmp/server-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-reload.pem", "/tmp/client-cert-reload.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-certificate-reload", "/tmp/ca-cert-reload.pem", "/tmp/client-cert-reload.pem", "/tmp/client-key-reload.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-reload.pem") + os.Remove("/tmp/ca-cert-reload.pem") + + os.Remove("/tmp/server-key-reload.pem") + os.Remove("/tmp/server-cert-reload.pem") + os.Remove("/tmp/client-key-reload.pem") + os.Remove("/tmp/client-cert-reload.pem") + }() + + // try old cert used in startup configuration. + cfg := config.NewConfig() + cfg.Port = 4005 + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-reload.pem", + SSLCert: "/tmp/server-cert-reload.pem", + SSLKey: "/tmp/server-key-reload.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + // The client provides a valid certificate. + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + config.Addr = "localhost:4005" + } + err = runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + // try reload a valid cert. + tlsCfg := server.getTLSConfig() + cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + oldExpireTime := cert.NotAfter + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = time.Now().Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload2.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload2.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + config.Addr = "localhost:4005" + } + err = runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + config.Addr = "localhost:4005" + } + err = runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + tlsCfg = server.getTLSConfig() + cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + newExpireTime := cert.NotAfter + c.Assert(newExpireTime.After(oldExpireTime), IsTrue) + + // try reload a expired cert. + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload3.pem", "/tmp/server-cert-reload3.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = c.NotBefore.Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload3.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload3.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + config.Addr = "localhost:4005" + } + err = runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + config.Addr = "localhost:4005" + } + err = runTestTLSConnection(c, connOverrider) + c.Assert(err, NotNil) + c.Assert(util.IsTLSExpiredError(err), IsTrue, Commentf("real error is %+v", err)) + server.Close() +} + +func (ts *TidbTestSuite) TestErrorNoRollback(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-rollback.pem", "/tmp/ca-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-rollback.pem", "/tmp/server-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-rollback.pem", "/tmp/client-cert-rollback.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-cert-rollback-test", "/tmp/ca-cert-rollback.pem", "/tmp/client-cert-rollback.pem", "/tmp/client-key-rollback.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-rollback.pem") + os.Remove("/tmp/ca-cert-rollback.pem") + + os.Remove("/tmp/server-key-rollback.pem") + os.Remove("/tmp/server-cert-rollback.pem") + os.Remove("/tmp/client-key-rollback.pem") + os.Remove("/tmp/client-cert-rollback.pem") + }() + + cfg := config.NewConfig() + cfg.Port = 4006 + cfg.Status.ReportStatus = false + + // test reload tls fail with/without "error no rollback option" + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-rollback.pem", + SSLCert: "/tmp/server-cert-rollback.pem", + SSLKey: "/tmp/server-key-rollback.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-cert-rollback-test" + config.Addr = "localhost:4006" + } + err = runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + os.Remove("/tmp/server-key-rollback.pem") + err = runReloadTLS(c, connOverrider, false) + c.Assert(err, NotNil) + tlsCfg := server.getTLSConfig() + c.Assert(tlsCfg, NotNil) + err = runReloadTLS(c, connOverrider, true) + c.Assert(err, IsNil) + tlsCfg = server.getTLSConfig() + c.Assert(tlsCfg, IsNil) } func (ts *TidbTestSuite) TestClientWithCollation(c *C) { diff --git a/util/misc.go b/util/misc.go index 5161385d0e9c9..743b58fa24a58 100644 --- a/util/misc.go +++ b/util/misc.go @@ -16,8 +16,10 @@ package util import ( "context" "crypto/tls" + "crypto/x509" "crypto/x509/pkix" "fmt" + "io/ioutil" "runtime" "strconv" "strings" @@ -300,3 +302,50 @@ func init() { pkixTypeNameAttributes[value] = key } } + +// LoadTLSCertificates loads CA/KEY/CERT for special paths. +func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error) { + if len(cert) == 0 || len(key) == 0 { + return + } + + var tlsCert tls.Certificate + tlsCert, err = tls.LoadX509KeyPair(cert, key) + if err != nil { + logutil.Logger(context.Background()).Warn("load x509 failed", zap.Error(err)) + err = errors.Trace(err) + return + } + + // Try loading CA cert. + clientAuthPolicy := tls.NoClientCert + var certPool *x509.CertPool + if len(ca) > 0 { + var caCert []byte + caCert, err = ioutil.ReadFile(ca) + if err != nil { + logutil.Logger(context.Background()).Warn("read file failed", zap.Error(err)) + err = errors.Trace(err) + return + } + certPool = x509.NewCertPool() + if certPool.AppendCertsFromPEM(caCert) { + clientAuthPolicy = tls.VerifyClientCertIfGiven + } + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + ClientCAs: certPool, + ClientAuth: clientAuthPolicy, + } + return +} + +// IsTLSExpiredError checks error is caused by TLS expired. +func IsTLSExpiredError(err error) bool { + err = errors.Cause(err) + if inval, ok := err.(x509.CertificateInvalidError); !ok || inval.Reason != x509.Expired { + return false + } + return true +} diff --git a/util/processinfo.go b/util/processinfo.go index b09edf810b184..959855ae3e3bc 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -14,6 +14,7 @@ package util import ( + "crypto/tls" "fmt" "time" @@ -76,4 +77,5 @@ type SessionManager interface { ShowProcessList() map[uint64]*ProcessInfo GetProcessInfo(id uint64) (*ProcessInfo, bool) Kill(connectionID uint64, query bool) + UpdateTLSConfig(cfg *tls.Config) }