Skip to content

Commit

Permalink
Merge pull request #27 from smallstep/mariano/renew-pool
Browse files Browse the repository at this point in the history
SDK should update certificate pools safely
  • Loading branch information
maraino authored Feb 7, 2019
2 parents 7e43402 + e0fff4d commit 262a9d0
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 126 deletions.
53 changes: 53 additions & 0 deletions ca/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package ca

import (
"context"
"crypto/tls"
"net"
"net/http"
"strings"

Expand Down Expand Up @@ -145,3 +147,54 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*
Transport: transport,
}, nil
}

// BootstrapListener is a helper function that using the given token returns a
// TLS listener which accepts connections from an inner listener and wraps each
// connection with Server.
//
// Without any extra option the server will be configured for mTLS, it will
// require and verify clients certificates, but options can be used to drop this
// requirement, the most common will be only verify the certs if given with
// ca.VerifyClientCertIfGiven(), or add extra CAs with
// ca.AddClientCA(*x509.Certificate).
//
// Usage:
// inner, err := net.Listen("tcp", ":443")
// if err != nil {
// return nil
// }
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// lis, err := ca.BootstrapListener(ctx, token, inner)
// if err != nil {
// return err
// }
// srv := grpc.NewServer()
// ... // register services
// srv.Serve(lis)
func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) {
client, err := Bootstrap(token)
if err != nil {
return nil, err
}

req, pk, err := CreateSignRequest(token)
if err != nil {
return nil, err
}

sign, err := client.Sign(req)
if err != nil {
return nil, err
}

// Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs
options = append(options, AddRootsToCAs())

tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
if err != nil {
return nil, err
}

return tls.NewListener(inner, tlsConfig), nil
}
70 changes: 69 additions & 1 deletion ca/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"sync"
"testing"
"time"

"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"

"github.com/smallstep/cli/crypto/randutil"
stepJOSE "github.com/smallstep/cli/jose"
jose "gopkg.in/square/go-jose.v2"
Expand Down Expand Up @@ -365,6 +365,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {

// doTest does a request that requires mTLS
doTest := func(client *http.Client) error {
time.Sleep(1 * time.Second)
// test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil {
Expand Down Expand Up @@ -532,3 +533,70 @@ func doReload(ca *CA) error {
newCA.srv.Addr = ca.srv.Addr
return ca.srv.Reload(newCA.srv)
}

func TestBootstrapListener(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()
token := func() string {
return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
}
type args struct {
token string
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{token()}, false},
{"fail", args{"bad-token"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inner := newLocalListener()
defer inner.Close()
lis, err := BootstrapListener(context.Background(), tt.args.token, inner)
if (err != nil) != tt.wantErr {
t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
if lis != nil {
t.Errorf("BootstrapListener() = %v, want nil", lis)
}
return
}
wg := new(sync.WaitGroup)
go func() {
wg.Add(1)
http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}))
wg.Done()
}()
defer wg.Wait()
defer lis.Close()

client, err := BootstrapClient(context.Background(), token())
if err != nil {
t.Errorf("BootstrapClient() error = %v", err)
return
}
resp, err := client.Get("https://" + lis.Addr().String())
if err != nil {
t.Errorf("client.Get() error = %v", err)
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.ReadAll() error = %v", err)
return
}
if string(b) != "ok" {
t.Errorf("client.Get() = %s, want ok", string(b))
return
}
})
}
}
13 changes: 3 additions & 10 deletions ca/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"crypto/x509/pkix"
"encoding/hex"
"encoding/json"
"encoding/pem"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -117,16 +116,10 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
block, _ := pem.Decode(data)
if block == nil {
return nil, errors.Errorf("error decoding %s", filename)
}
root, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", filename)
}
pool := x509.NewCertPool()
pool.AddCert(root)
if !pool.AppendCertsFromPEM(data) {
return nil, errors.Errorf("error parsing %s: no certificates found", filename)
}
return getDefaultTransport(&tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
Expand Down
109 changes: 109 additions & 0 deletions ca/mutable_tls_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package ca

import (
"crypto/tls"
"crypto/x509"
"sync"

"github.com/smallstep/certificates/api"
)

// mutableTLSConfig allows to use a tls.Config with mutable cert pools.
type mutableTLSConfig struct {
sync.RWMutex
config *tls.Config
clientCerts []*x509.Certificate
rootCerts []*x509.Certificate
mutClientCerts []*x509.Certificate
mutRootCerts []*x509.Certificate
}

// newMutableTLSConfig creates a new mutableTLSConfig that will be later
// initialized with a tls.Config.
func newMutableTLSConfig() *mutableTLSConfig {
return &mutableTLSConfig{
clientCerts: []*x509.Certificate{},
rootCerts: []*x509.Certificate{},
mutClientCerts: []*x509.Certificate{},
mutRootCerts: []*x509.Certificate{},
}
}

// Init initializes the mutable tls.Config with the given tls.Config.
func (c *mutableTLSConfig) Init(base *tls.Config) {
c.Lock()
c.config = base.Clone()
c.Unlock()
}

// TLSConfig returns the updated tls.Config it it has changed. It's used in the
// tls.Config GetConfigForClient.
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
c.RLock()
config = c.config
c.RUnlock()
return
}

// Reload reloads the tls.Config with the new CAs.
func (c *mutableTLSConfig) Reload() {
// Prepare new pools
c.RLock()
rootCAs := x509.NewCertPool()
clientCAs := x509.NewCertPool()
// Fixed certs
for _, cert := range c.rootCerts {
rootCAs.AddCert(cert)
}
for _, cert := range c.clientCerts {
clientCAs.AddCert(cert)
}
// Mutable certs
for _, cert := range c.mutRootCerts {
rootCAs.AddCert(cert)
}
for _, cert := range c.mutClientCerts {
clientCAs.AddCert(cert)
}
c.RUnlock()

// Set new pool
c.Lock()
c.config.RootCAs = rootCAs
c.config.ClientCAs = clientCAs
c.mutRootCerts = []*x509.Certificate{}
c.mutClientCerts = []*x509.Certificate{}
c.Unlock()
}

// AddImmutableClientCACert add an immutable cert to ClientCAs.
func (c *mutableTLSConfig) AddImmutableClientCACert(cert *x509.Certificate) {
c.Lock()
c.clientCerts = append(c.clientCerts, cert)
c.Unlock()
}

// AddImmutableRootCACert add an immutable cert to RootCas.
func (c *mutableTLSConfig) AddImmutableRootCACert(cert *x509.Certificate) {
c.Lock()
c.rootCerts = append(c.rootCerts, cert)
c.Unlock()
}

// AddClientCAs add mutable certs to ClientCAs.
func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) {
c.Lock()
for _, cert := range certs {
c.mutClientCerts = append(c.mutClientCerts, cert.Certificate)
}
c.Unlock()
}

// AddRootCAs add mutable certs to RootCAs.
func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) {
c.Lock()
for _, cert := range certs {
c.mutRootCerts = append(c.mutRootCerts, cert.Certificate)
}
c.Unlock()
}
Loading

0 comments on commit 262a9d0

Please sign in to comment.