diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 8989b3c05..6c532d5c6 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -2,6 +2,8 @@ package ca import ( "context" + "crypto/tls" + "net" "net/http" "strings" @@ -145,3 +147,54 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* Transport: transport, }, nil } + +// BootstrapListener is a helper function that using the given token returns a +// TLS listener which accepts connections from an inner listener and wraps each +// connection with Server. +// +// Without any extra option the server will be configured for mTLS, it will +// require and verify clients certificates, but options can be used to drop this +// requirement, the most common will be only verify the certs if given with +// ca.VerifyClientCertIfGiven(), or add extra CAs with +// ca.AddClientCA(*x509.Certificate). +// +// Usage: +// inner, err := net.Listen("tcp", ":443") +// if err != nil { +// return nil +// } +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// lis, err := ca.BootstrapListener(ctx, token, inner) +// if err != nil { +// return err +// } +// srv := grpc.NewServer() +// ... // register services +// srv.Serve(lis) +func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { + client, err := Bootstrap(token) + if err != nil { + return nil, err + } + + req, pk, err := CreateSignRequest(token) + if err != nil { + return nil, err + } + + sign, err := client.Sign(req) + if err != nil { + return nil, err + } + + // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs + options = append(options, AddRootsToCAs()) + + tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) + if err != nil { + return nil, err + } + + return tls.NewListener(inner, tlsConfig), nil +} diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 540b1b20f..800c44d26 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -8,13 +8,13 @@ import ( "net/http" "net/http/httptest" "reflect" + "sync" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" - "github.com/smallstep/cli/crypto/randutil" stepJOSE "github.com/smallstep/cli/jose" jose "gopkg.in/square/go-jose.v2" @@ -365,6 +365,7 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { + time.Sleep(1 * time.Second) // test with ca resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { @@ -532,3 +533,70 @@ func doReload(ca *CA) error { newCA.srv.Addr = ca.srv.Addr return ca.srv.Reload(newCA.srv) } + +func TestBootstrapListener(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + token := func() string { + return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { + token string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{token()}, false}, + {"fail", args{"bad-token"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := newLocalListener() + defer inner.Close() + lis, err := BootstrapListener(context.Background(), tt.args.token, inner) + if (err != nil) != tt.wantErr { + t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if lis != nil { + t.Errorf("BootstrapListener() = %v, want nil", lis) + } + return + } + wg := new(sync.WaitGroup) + go func() { + wg.Add(1) + http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + wg.Done() + }() + defer wg.Wait() + defer lis.Close() + + client, err := BootstrapClient(context.Background(), token()) + if err != nil { + t.Errorf("BootstrapClient() error = %v", err) + return + } + resp, err := client.Get("https://" + lis.Addr().String()) + if err != nil { + t.Errorf("client.Get() error = %v", err) + return + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.ReadAll() error = %v", err) + return + } + if string(b) != "ok" { + t.Errorf("client.Get() = %s, want ok", string(b)) + return + } + }) + } +} diff --git a/ca/client.go b/ca/client.go index 1ca682de0..2a0e750ab 100644 --- a/ca/client.go +++ b/ca/client.go @@ -12,7 +12,6 @@ import ( "crypto/x509/pkix" "encoding/hex" "encoding/json" - "encoding/pem" "io" "io/ioutil" "net/http" @@ -117,16 +116,10 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) { if err != nil { return nil, errors.Wrapf(err, "error reading %s", filename) } - block, _ := pem.Decode(data) - if block == nil { - return nil, errors.Errorf("error decoding %s", filename) - } - root, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, errors.Wrapf(err, "error parsing %s", filename) - } pool := x509.NewCertPool() - pool.AddCert(root) + if !pool.AppendCertsFromPEM(data) { + return nil, errors.Errorf("error parsing %s: no certificates found", filename) + } return getDefaultTransport(&tls.Config{ MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, diff --git a/ca/mutable_tls_config.go b/ca/mutable_tls_config.go new file mode 100644 index 000000000..031a99e95 --- /dev/null +++ b/ca/mutable_tls_config.go @@ -0,0 +1,109 @@ +package ca + +import ( + "crypto/tls" + "crypto/x509" + "sync" + + "github.com/smallstep/certificates/api" +) + +// mutableTLSConfig allows to use a tls.Config with mutable cert pools. +type mutableTLSConfig struct { + sync.RWMutex + config *tls.Config + clientCerts []*x509.Certificate + rootCerts []*x509.Certificate + mutClientCerts []*x509.Certificate + mutRootCerts []*x509.Certificate +} + +// newMutableTLSConfig creates a new mutableTLSConfig that will be later +// initialized with a tls.Config. +func newMutableTLSConfig() *mutableTLSConfig { + return &mutableTLSConfig{ + clientCerts: []*x509.Certificate{}, + rootCerts: []*x509.Certificate{}, + mutClientCerts: []*x509.Certificate{}, + mutRootCerts: []*x509.Certificate{}, + } +} + +// Init initializes the mutable tls.Config with the given tls.Config. +func (c *mutableTLSConfig) Init(base *tls.Config) { + c.Lock() + c.config = base.Clone() + c.Unlock() +} + +// TLSConfig returns the updated tls.Config it it has changed. It's used in the +// tls.Config GetConfigForClient. +func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) { + c.RLock() + config = c.config + c.RUnlock() + return +} + +// Reload reloads the tls.Config with the new CAs. +func (c *mutableTLSConfig) Reload() { + // Prepare new pools + c.RLock() + rootCAs := x509.NewCertPool() + clientCAs := x509.NewCertPool() + // Fixed certs + for _, cert := range c.rootCerts { + rootCAs.AddCert(cert) + } + for _, cert := range c.clientCerts { + clientCAs.AddCert(cert) + } + // Mutable certs + for _, cert := range c.mutRootCerts { + rootCAs.AddCert(cert) + } + for _, cert := range c.mutClientCerts { + clientCAs.AddCert(cert) + } + c.RUnlock() + + // Set new pool + c.Lock() + c.config.RootCAs = rootCAs + c.config.ClientCAs = clientCAs + c.mutRootCerts = []*x509.Certificate{} + c.mutClientCerts = []*x509.Certificate{} + c.Unlock() +} + +// AddImmutableClientCACert add an immutable cert to ClientCAs. +func (c *mutableTLSConfig) AddImmutableClientCACert(cert *x509.Certificate) { + c.Lock() + c.clientCerts = append(c.clientCerts, cert) + c.Unlock() +} + +// AddImmutableRootCACert add an immutable cert to RootCas. +func (c *mutableTLSConfig) AddImmutableRootCACert(cert *x509.Certificate) { + c.Lock() + c.rootCerts = append(c.rootCerts, cert) + c.Unlock() +} + +// AddClientCAs add mutable certs to ClientCAs. +func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) { + c.Lock() + for _, cert := range certs { + c.mutClientCerts = append(c.mutClientCerts, cert.Certificate) + } + c.Unlock() +} + +// AddRootCAs add mutable certs to RootCAs. +func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) { + c.Lock() + for _, cert := range certs { + c.mutRootCerts = append(c.mutRootCerts, cert.Certificate) + } + c.Unlock() +} diff --git a/ca/tls.go b/ca/tls.go index 31d1632ba..ef3af5487 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -21,13 +21,21 @@ import ( // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { - cert, err := TLSCertificate(sign, pk) + tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } + return tlsConfig, nil +} + +func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) { + cert, err := TLSCertificate(sign, pk) + if err != nil { + return nil, nil, err + } renewer, err := NewTLSRenewer(cert, nil) if err != nil { - return nil, err + return nil, nil, err } tlsConfig := getDefaultTLSConfig(sign) @@ -35,22 +43,20 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.RootCAs = pool - } - // Apply options if given - tlsCtx := newTLSOptionCtx(c, tlsConfig) + // Apply options and initialize mutable tls.Config + tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { - return nil, err + return nil, nil, err } // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { - return nil, err + return nil, nil, err } + // Use mutable tls.Config on renew + tr.DialTLS = c.buildDialTLS(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport @@ -58,7 +64,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Start renewer renewer.RunContext(ctx) - return tlsConfig, nil + return tlsConfig, tr, nil } // GetServerTLSConfig returns a tls.Config for server use configured with the @@ -82,25 +88,26 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.ClientCAs = pool - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - // Add RootCAs for refresh client - tlsConfig.RootCAs = pool - } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - // Apply options if given - tlsCtx := newTLSOptionCtx(c, tlsConfig) + // Apply options and initialize mutable tls.Config + tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, err } + // GetConfigForClient allows seamless root and federated roots rotation. + // If the return of the callback is not-nil, it will use the returned + // tls.Config instead of the default one. + tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) + // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { return nil, err } + // Use mutable tls.Config on renew + tr.DialTLS = c.buildDialTLS(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport @@ -113,17 +120,40 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, // Transport returns an http.Transport configured to use the client certificate from the sign response. func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) { - tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...) + _, tr, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } - return getDefaultTransport(tlsConfig) + return tr, nil +} + +// buildGetConfigForClient returns an implementation of GetConfigForClient +// callback in tls.Config. +// +// If the implementation returns a nil tls.Config, the original Config will be +// used, but if it's non-nil, the returned Config will be used to handle this +// connection. +func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) { + return func(*tls.ClientHelloInfo) (*tls.Config, error) { + return ctx.mutableConfig.TLSConfig(), nil + } +} + +// buildDialTLS returns an implementation of DialTLS callback in http.Transport. +func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { + return func(network, addr string) (net.Conn, error) { + return tls.DialWithDialer(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }, network, addr, ctx.mutableConfig.TLSConfig()) + } } // Certificate returns the server or client certificate from the sign response. func Certificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.ServerPEM.Certificate == nil { - return nil, errors.New("ca: certificate does not exists") + return nil, errors.New("ca: certificate does not exist") } return sign.ServerPEM.Certificate, nil } @@ -132,19 +162,19 @@ func Certificate(sign *api.SignResponse) (*x509.Certificate, error) { // response. func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.CaPEM.Certificate == nil { - return nil, errors.New("ca: certificate does not exists") + return nil, errors.New("ca: certificate does not exist") } return sign.CaPEM.Certificate, nil } // RootCertificate returns the root certificate from the sign response. func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) { - if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { - return nil, errors.New("ca: certificate does not exists") + if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { + return nil, errors.New("ca: certificate does not exist") } lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1] if len(lastChain) == 0 { - return nil, errors.New("ca: certificate does not exists") + return nil, errors.New("ca: certificate does not exist") } return lastChain[len(lastChain)-1], nil } @@ -178,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific return &cert, nil } -// getCertPool returns the transport x509.CertPool or the one from the sign -// request. -func getCertPool(sign *api.SignResponse) *x509.CertPool { - if root, err := RootCertificate(sign); err == nil { - pool := x509.NewCertPool() - pool.AddCert(root) - return pool - } - return nil -} - func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { if sign.TLSOptions != nil { return sign.TLSOptions.TLSConfig() diff --git a/ca/tls_options.go b/ca/tls_options.go index 47e2c6270..b3b2d0579 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -3,6 +3,8 @@ package ca import ( "crypto/tls" "crypto/x509" + + "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. @@ -10,16 +12,22 @@ type TLSOption func(ctx *TLSOptionCtx) error // TLSOptionCtx is the context modified on TLSOption methods. type TLSOptionCtx struct { - Client *Client - Config *tls.Config - OnRenewFunc []TLSOption + Client *Client + Config *tls.Config + Sign *api.SignResponse + OnRenewFunc []TLSOption + mutableConfig *mutableTLSConfig + hasRootCA bool + hasClientCA bool } // newTLSOptionCtx creates the TLSOption context. -func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { +func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx { return &TLSOptionCtx{ - Client: c, - Config: config, + Client: c, + Config: config, + Sign: sign, + mutableConfig: newMutableTLSConfig(), } } @@ -29,6 +37,44 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { return err } } + + // Initialize mutable config with the fully configured tls.Config + ctx.mutableConfig.Init(ctx.Config) + + // Build RootCAs and ClientCAs with given root certificate if necessary + if root, err := RootCertificate(ctx.Sign); err == nil { + if !ctx.hasRootCA { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + ctx.Config.RootCAs.AddCert(root) + ctx.mutableConfig.AddImmutableRootCACert(root) + } + + if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + ctx.Config.ClientCAs.AddCert(root) + ctx.mutableConfig.AddImmutableClientCACert(root) + } + } + + // Update tls.Config with mutable data + if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 { + ctx.Config.RootCAs = x509.NewCertPool() + } + if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 { + ctx.Config.ClientCAs = x509.NewCertPool() + } + // Add mutable certificates + for _, cert := range ctx.mutableConfig.mutRootCerts { + ctx.Config.RootCAs.AddCert(cert) + } + for _, cert := range ctx.mutableConfig.mutClientCerts { + ctx.Config.ClientCAs.AddCert(cert) + } + ctx.mutableConfig.Reload() return nil } @@ -38,6 +84,8 @@ func (ctx *TLSOptionCtx) applyRenew() error { return err } } + // Reload mutable config with the changes + ctx.mutableConfig.Reload() return nil } @@ -68,6 +116,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { ctx.Config.RootCAs = x509.NewCertPool() } ctx.Config.RootCAs.AddCert(cert) + ctx.mutableConfig.AddImmutableRootCACert(cert) return nil } } @@ -81,6 +130,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { ctx.Config.ClientCAs = x509.NewCertPool() } ctx.Config.ClientCAs.AddCert(cert) + ctx.mutableConfig.AddImmutableClientCACert(cert) return nil } } @@ -91,17 +141,14 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // // BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { + // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) - } + ctx.hasRootCA = true + ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -117,17 +164,14 @@ func AddRootsToRootCAs() TLSOption { // // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { + // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - } + ctx.hasClientCA = true + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -145,12 +189,7 @@ func AddFederationToRootCAs() TLSOption { if err != nil { return err } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) - } + ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -169,12 +208,7 @@ func AddFederationToClientCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - } + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -192,16 +226,10 @@ func AddRootsToCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - ctx.Config.RootCAs.AddCert(cert.Certificate) - } + ctx.hasRootCA = true + ctx.hasClientCA = true + ctx.mutableConfig.AddRootCAs(certs.Certificates) + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -219,15 +247,20 @@ func AddFederationToCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - ctx.Config.RootCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddRootCAs(certs.Certificates) + ctx.mutableConfig.AddClientCAs(certs.Certificates) } return nil } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index ceeea7dc6..a422799e8 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -9,6 +9,8 @@ import ( "reflect" "sort" "testing" + + "github.com/smallstep/certificates/api" ) func Test_newTLSOptionCtx(t *testing.T) { @@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) { type args struct { c *Client config *tls.Config + sign *api.SignResponse } tests := []struct { name string args args want *TLSOptionCtx }{ - {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}}, + {"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) { + if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) { t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want) } }) @@ -63,7 +66,8 @@ func TestTLSOptionCtx_apply(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: tt.fields.Config, + Config: tt.fields.Config, + mutableConfig: newMutableTLSConfig(), } if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr { t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr) @@ -82,7 +86,8 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := RequireAndVerifyClientCert()(ctx); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) @@ -105,7 +110,8 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := VerifyClientCertIfGiven()(ctx); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) @@ -136,7 +142,8 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := AddRootCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddRootCA() error = %v", err) @@ -167,7 +174,8 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := AddClientCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddClientCA() error = %v", err) @@ -219,14 +227,15 @@ func TestAddRootsToRootCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToRootCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) { t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want) } }) @@ -272,14 +281,15 @@ func TestAddRootsToClientCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToClientCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) { t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want) } }) @@ -332,10 +342,11 @@ func TestAddFederationToRootCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToRootCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -395,10 +406,11 @@ func TestAddFederationToClientCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToClientCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -451,14 +463,15 @@ func TestAddRootsToCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) { t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want) } }) @@ -511,10 +524,11 @@ func TestAddFederationToCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr) return }