Skip to content

Commit

Permalink
refresh SSL certs early
Browse files Browse the repository at this point in the history
  • Loading branch information
lychung83 committed Oct 29, 2018
1 parent ac834ce commit 18df49e
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 32 deletions.
73 changes: 53 additions & 20 deletions proxy/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ const (
keepAlivePeriod = time.Minute
)

// errNotCached is returned when the instance was not found in the Client's
// cache. It is an internal detail and is not actually ever returned to the
// user.
var errNotCached = errors.New("instance was not found in cache")
var (
// errNotCached is returned when the instance was not found in the Client's
// cache. It is an internal detail and is not actually ever returned to the
// user.
errNotCached = errors.New("instance was not found in cache")
refreshCertBuffer = 30 * time.Second
)

// Conn represents a connection from a client to a specific instance.
type Conn struct {
Expand Down Expand Up @@ -78,9 +81,12 @@ type Client struct {

// The cfgCache holds the most recent connection configuration keyed by
// instance. Relevant functions are refreshCfg and cachedCfg. It is
// protected by cfgL.
// protected by cacheL.
cfgCache map[string]cacheEntry
cfgL sync.RWMutex
cacheL sync.RWMutex

// refreshCfgL prevents multiple goroutines from contacting the Cloud SQL API at once.
refreshCfgL sync.Mutex

// MaxConnections is the maximum number of connections to establish
// before refusing new connections. 0 means no limit.
Expand Down Expand Up @@ -147,32 +153,35 @@ func (c *Client) handleConn(conn Conn) {
// address as well as construct a new tls.Config to connect to the instance. It
// caches the result.
func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err error) {
c.cfgL.Lock()
defer c.cfgL.Unlock()
c.refreshCfgL.Lock()
defer c.refreshCfgL.Unlock()

throttle := c.RefreshCfgThrottle
if throttle == 0 {
throttle = DefaultRefreshCfgThrottle
}

if old := c.cfgCache[instance]; time.Since(old.lastRefreshed) < throttle {
c.cacheL.Lock()
if c.cfgCache == nil {
c.cfgCache = make(map[string]cacheEntry)
}
old, oldok := c.cfgCache[instance]
c.cacheL.Unlock()
if oldok && time.Since(old.lastRefreshed) < throttle {
logging.Errorf("Throttling refreshCfg(%s): it was only called %v ago", instance, time.Since(old.lastRefreshed))
// Refresh was called too recently, just reuse the result.
return old.addr, old.cfg, old.err
}

if c.cfgCache == nil {
c.cfgCache = make(map[string]cacheEntry)
}

defer func() {
c.cacheL.Lock()
c.cfgCache[instance] = cacheEntry{
lastRefreshed: time.Now(),

err: err,
addr: addr,
cfg: cfg,
err: err,
addr: addr,
cfg: cfg,
}
c.cacheL.Unlock()
}()

mycert, err := c.Certs.Local(instance)
Expand Down Expand Up @@ -201,9 +210,29 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
InsecureSkipVerify: true,
VerifyPeerCertificate: genVerifyPeerCertificateFunc(name, certs),
}

expire := mycert.Leaf.NotAfter
now := time.Now()
timeToRefresh := expire.Sub(now) - refreshCertBuffer
if timeToRefresh <= 0 {
err = fmt.Errorf("new ephemeral certificate expires too soon: current time: %v, certificate expires: %v", expire, now)
logging.Errorf("ephemeral certificate (%+v) error: %v", mycert, err)
return "", nil, err
}
go c.refreshCertAfter(instance, timeToRefresh)

return fmt.Sprintf("%s:%d", addr, c.Port), cfg, nil
}

// refreshCertAfter refreshes the epehemeral certificate of the instance after timeToRefresh.
func (c *Client) refreshCertAfter(instance string, timeToRefresh time.Duration) {
<-time.After(timeToRefresh)
logging.Verbosef("ephemeral certificate for instance %s will expire soon, refreshing now.", instance)
if _, _, err := c.refreshCfg(instance); err != nil {
logging.Errorf("failed to refresh the ephemeral certificate for %s before expering: %v", instance, err)
}
}

// genVerifyPeerCertificateFunc creates a VerifyPeerCertificate func that verifies that the peer
// certificate is in the cert pool. We need to define our own because of our sketchy non-standard
// CNs.
Expand All @@ -230,13 +259,17 @@ func genVerifyPeerCertificateFunc(instanceName string, pool *x509.CertPool) func
}
}

func isExpired(cfg *tls.Config) bool {
return time.Now().After(cfg.Certificates[0].Leaf.NotAfter)
}

func (c *Client) cachedCfg(instance string) (string, *tls.Config) {
c.cfgL.RLock()
c.cacheL.RLock()
ret, ok := c.cfgCache[instance]
c.cfgL.RUnlock()
c.cacheL.RUnlock()

// Don't waste time returning an expired/invalid cert.
if !ok || ret.err != nil || time.Now().After(ret.cfg.Certificates[0].Leaf.NotAfter) {
if !ok || ret.err != nil || isExpired(ret.cfg) {
return "", nil
}
return ret.addr, ret.cfg
Expand Down
76 changes: 64 additions & 12 deletions proxy/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,19 @@ import (

const instance = "instance-name"

var errFakeDial = errors.New("this error is returned by the dialer")
var (
errFakeDial = errors.New("this error is returned by the dialer")
forever = time.Date(9999, 0, 0, 0, 0, 0, 0, time.UTC)
)

type fakeCerts struct {
sync.Mutex
called int
}

type blockingCertSource struct {
values map[string]*fakeCerts
values map[string]*fakeCerts
validUntil time.Time
}

func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
Expand All @@ -48,11 +52,10 @@ func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
v.called++
v.Unlock()

validUntil, _ := time.Parse("2006", "9999")
// Returns a cert which is valid forever.
return tls.Certificate{
Leaf: &x509.Certificate{
NotAfter: validUntil,
NotAfter: cs.validUntil,
},
}, nil
}
Expand All @@ -67,7 +70,9 @@ func TestClientCache(t *testing.T) {
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
}},
},
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
Expand All @@ -92,7 +97,9 @@ func TestConcurrentRefresh(t *testing.T) {
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
}},
},
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
Expand Down Expand Up @@ -131,7 +138,9 @@ func TestMaximumConnectionsCount(t *testing.T) {

b := &fakeCerts{}
certSource := blockingCertSource{
map[string]*fakeCerts{}}
map[string]*fakeCerts{},
forever,
}
firstDialExited := make(chan struct{})
c := &Client{
Certs: &certSource,
Expand Down Expand Up @@ -190,28 +199,71 @@ func TestShutdownTerminatesEarly(t *testing.T) {
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
}},
},
forever,
},
Dialer: func(string, string) (net.Conn, error) {
return nil, nil
},
}

shutdown := make(chan bool, 1)
go func() {
c.Shutdown(1)
shutdown <- true
}()

shutdownFinished := false

// In case the code is actually broken and the client doesn't shut down quickly, don't cause the test to hang until it times out.
select {
case <-time.After(100 * time.Millisecond):
case shutdownFinished = <-shutdown:
}

if !shutdownFinished {
t.Errorf("shutdown should have completed quickly because there are no active connections")
}
}

func TestRefreshTimer(t *testing.T) {
oldRefreshCertBuffer := refreshCertBuffer
defer func() {
refreshCertBuffer = oldRefreshCertBuffer
}()
refreshCertBuffer = time.Second

timeToExpire := 5 * time.Second
b := &fakeCerts{}
certCreated := time.Now()
c := &Client{
Certs: &blockingCertSource{
map[string]*fakeCerts{
instance: b,
},
certCreated.Add(timeToExpire),
},
Dialer: func(string, string) (net.Conn, error) {
return nil, errFakeDial
},
RefreshCfgThrottle: 20 * time.Millisecond,
}
// Call Dial to cache the cert.
if _, err := c.Dial(instance); err != errFakeDial {
t.Fatalf("Dial(%s) failed: %v", instance, err)
}
c.cacheL.Lock()
cfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}

time.Sleep(timeToExpire - time.Since(certCreated))
// Check if cert was refreshed in the background, without calling Dial again.
c.cacheL.Lock()
newCfg, ok := c.cfgCache[instance]
c.cacheL.Unlock()
if !ok {
t.Fatalf("expected instance to be cached")
}
if !newCfg.lastRefreshed.After(cfg.lastRefreshed) {
t.Error("expected cert to be refreshed.")
}
}

0 comments on commit 18df49e

Please sign in to comment.