From 18df49eed5359e5fe9cf748963e3c8ed2fc2987d Mon Sep 17 00:00:00 2001 From: lychung83 Date: Mon, 29 Oct 2018 11:00:49 -0700 Subject: [PATCH] refresh SSL certs early --- proxy/proxy/client.go | 73 ++++++++++++++++++++++++++---------- proxy/proxy/client_test.go | 76 ++++++++++++++++++++++++++++++++------ 2 files changed, 117 insertions(+), 32 deletions(-) diff --git a/proxy/proxy/client.go b/proxy/proxy/client.go index 35c4dca42..0fab2087e 100644 --- a/proxy/proxy/client.go +++ b/proxy/proxy/client.go @@ -32,10 +32,13 @@ const ( keepAlivePeriod = time.Minute ) -// errNotCached is returned when the instance was not found in the Client's -// cache. It is an internal detail and is not actually ever returned to the -// user. -var errNotCached = errors.New("instance was not found in cache") +var ( + // errNotCached is returned when the instance was not found in the Client's + // cache. It is an internal detail and is not actually ever returned to the + // user. + errNotCached = errors.New("instance was not found in cache") + refreshCertBuffer = 30 * time.Second +) // Conn represents a connection from a client to a specific instance. type Conn struct { @@ -78,9 +81,12 @@ type Client struct { // The cfgCache holds the most recent connection configuration keyed by // instance. Relevant functions are refreshCfg and cachedCfg. It is - // protected by cfgL. + // protected by cacheL. cfgCache map[string]cacheEntry - cfgL sync.RWMutex + cacheL sync.RWMutex + + // refreshCfgL prevents multiple goroutines from contacting the Cloud SQL API at once. + refreshCfgL sync.Mutex // MaxConnections is the maximum number of connections to establish // before refusing new connections. 0 means no limit. @@ -147,32 +153,35 @@ func (c *Client) handleConn(conn Conn) { // address as well as construct a new tls.Config to connect to the instance. It // caches the result. 
func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err error) { - c.cfgL.Lock() - defer c.cfgL.Unlock() + c.refreshCfgL.Lock() + defer c.refreshCfgL.Unlock() throttle := c.RefreshCfgThrottle if throttle == 0 { throttle = DefaultRefreshCfgThrottle } - if old := c.cfgCache[instance]; time.Since(old.lastRefreshed) < throttle { + c.cacheL.Lock() + if c.cfgCache == nil { + c.cfgCache = make(map[string]cacheEntry) + } + old, oldok := c.cfgCache[instance] + c.cacheL.Unlock() + if oldok && time.Since(old.lastRefreshed) < throttle { logging.Errorf("Throttling refreshCfg(%s): it was only called %v ago", instance, time.Since(old.lastRefreshed)) // Refresh was called too recently, just reuse the result. return old.addr, old.cfg, old.err } - if c.cfgCache == nil { - c.cfgCache = make(map[string]cacheEntry) - } - defer func() { + c.cacheL.Lock() c.cfgCache[instance] = cacheEntry{ lastRefreshed: time.Now(), - - err: err, - addr: addr, - cfg: cfg, + err: err, + addr: addr, + cfg: cfg, } + c.cacheL.Unlock() }() mycert, err := c.Certs.Local(instance) @@ -201,9 +210,29 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err InsecureSkipVerify: true, VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs), } + + expire := mycert.Leaf.NotAfter + now := time.Now() + timeToRefresh := expire.Sub(now) - refreshCertBuffer + if timeToRefresh <= 0 { + err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", now, expire) + logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err) + return "", nil, err + } + go c.refreshCertAfter(instance, timeToRefresh) + return fmt.Sprintf("%s:%d", addr, c.Port), cfg, nil } +// refreshCertAfter refreshes the ephemeral certificate of the instance after timeToRefresh.
+func (c *Client) refreshCertAfter(instance string, timeToRefresh time.Duration) { + <-time.After(timeToRefresh) + logging.Verbosef("ephemeral certificate for instance %s will expire soon, refreshing now.", instance) + if _, _, err := c.refreshCfg(instance); err != nil { + logging.Errorf("failed to refresh the ephemeral certificate for %s before expering: %v", instance, err) + } +} + // genVerifyPeerCertificateFunc creates a VerifyPeerCertificate func that verifies that the peer // certificate is in the cert pool. We need to define our own because of our sketchy non-standard // CNs. @@ -230,13 +259,17 @@ func genVerifyPeerCertificateFunc(instanceName string, pool *x509.CertPool) func } } +func isExpired(cfg *tls.Config) bool { + return time.Now().After(cfg.Certificates[0].Leaf.NotAfter) +} + func (c *Client) cachedCfg(instance string) (string, *tls.Config) { - c.cfgL.RLock() + c.cacheL.RLock() ret, ok := c.cfgCache[instance] - c.cfgL.RUnlock() + c.cacheL.RUnlock() // Don't waste time returning an expired/invalid cert. 
- if !ok || ret.err != nil || time.Now().After(ret.cfg.Certificates[0].Leaf.NotAfter) { + if !ok || ret.err != nil || isExpired(ret.cfg) { return "", nil } return ret.addr, ret.cfg diff --git a/proxy/proxy/client_test.go b/proxy/proxy/client_test.go index a59ad1e1c..a3beaaf35 100644 --- a/proxy/proxy/client_test.go +++ b/proxy/proxy/client_test.go @@ -28,7 +28,10 @@ import ( const instance = "instance-name" -var errFakeDial = errors.New("this error is returned by the dialer") +var ( + errFakeDial = errors.New("this error is returned by the dialer") + forever = time.Date(9999, 0, 0, 0, 0, 0, 0, time.UTC) +) type fakeCerts struct { sync.Mutex @@ -36,7 +39,8 @@ type fakeCerts struct { } type blockingCertSource struct { - values map[string]*fakeCerts + values map[string]*fakeCerts + validUntil time.Time } func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) { @@ -48,11 +52,10 @@ func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) { v.called++ v.Unlock() - validUntil, _ := time.Parse("2006", "9999") // Returns a cert which is valid forever. 
return tls.Certificate{ Leaf: &x509.Certificate{ - NotAfter: validUntil, + NotAfter: cs.validUntil, }, }, nil } @@ -67,7 +70,9 @@ func TestClientCache(t *testing.T) { Certs: &blockingCertSource{ map[string]*fakeCerts{ instance: b, - }}, + }, + forever, + }, Dialer: func(string, string) (net.Conn, error) { return nil, errFakeDial }, @@ -92,7 +97,9 @@ func TestConcurrentRefresh(t *testing.T) { Certs: &blockingCertSource{ map[string]*fakeCerts{ instance: b, - }}, + }, + forever, + }, Dialer: func(string, string) (net.Conn, error) { return nil, errFakeDial }, @@ -131,7 +138,9 @@ func TestMaximumConnectionsCount(t *testing.T) { b := &fakeCerts{} certSource := blockingCertSource{ - map[string]*fakeCerts{}} + map[string]*fakeCerts{}, + forever, + } firstDialExited := make(chan struct{}) c := &Client{ Certs: &certSource, @@ -190,28 +199,71 @@ func TestShutdownTerminatesEarly(t *testing.T) { Certs: &blockingCertSource{ map[string]*fakeCerts{ instance: b, - }}, + }, + forever, + }, Dialer: func(string, string) (net.Conn, error) { return nil, nil }, } - shutdown := make(chan bool, 1) go func() { c.Shutdown(1) shutdown <- true }() - shutdownFinished := false - // In case the code is actually broken and the client doesn't shut down quickly, don't cause the test to hang until it times out. 
select { case <-time.After(100 * time.Millisecond): case shutdownFinished = <-shutdown: } - if !shutdownFinished { t.Errorf("shutdown should have completed quickly because there are no active connections") } +} + +func TestRefreshTimer(t *testing.T) { + oldRefreshCertBuffer := refreshCertBuffer + defer func() { + refreshCertBuffer = oldRefreshCertBuffer + }() + refreshCertBuffer = time.Second + timeToExpire := 5 * time.Second + b := &fakeCerts{} + certCreated := time.Now() + c := &Client{ + Certs: &blockingCertSource{ + map[string]*fakeCerts{ + instance: b, + }, + certCreated.Add(timeToExpire), + }, + Dialer: func(string, string) (net.Conn, error) { + return nil, errFakeDial + }, + RefreshCfgThrottle: 20 * time.Millisecond, + } + // Call Dial to cache the cert. + if _, err := c.Dial(instance); err != errFakeDial { + t.Fatalf("Dial(%s) failed: %v", instance, err) + } + c.cacheL.Lock() + cfg, ok := c.cfgCache[instance] + c.cacheL.Unlock() + if !ok { + t.Fatalf("expected instance to be cached") + } + + time.Sleep(timeToExpire - time.Since(certCreated)) + // Check if cert was refreshed in the background, without calling Dial again. + c.cacheL.Lock() + newCfg, ok := c.cfgCache[instance] + c.cacheL.Unlock() + if !ok { + t.Fatalf("expected instance to be cached") + } + if !newCfg.lastRefreshed.After(cfg.lastRefreshed) { + t.Error("expected cert to be refreshed.") + } }