diff --git a/DEPS.bzl b/DEPS.bzl index 0814096870f6..4050acef18c7 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -1327,8 +1327,8 @@ def go_deps(): name = "com_github_golang_snappy", build_file_proto_mode = "disable_global", importpath = "github.com/golang/snappy", - sum = "h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw=", - version = "v0.0.2", + sum = "h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=", + version = "v0.0.3", ) go_repository( name = "com_github_gomodule_redigo", diff --git a/go.mod b/go.mod index 27dfb0f84dbf..daaede2e2cdb 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,7 @@ require ( github.com/golang-commonmark/puny v0.0.0-20180910110745-050be392d8b8 // indirect github.com/golang/geo v0.0.0-20200319012246-673a6f80352d github.com/golang/protobuf v1.4.2 - github.com/golang/snappy v0.0.2 + github.com/golang/snappy v0.0.3 github.com/google/btree v1.0.0 github.com/google/flatbuffers v1.11.0 github.com/google/go-cmp v0.5.2 diff --git a/go.sum b/go.sum index 19053bf3770f..8e00b6dd08bb 100644 --- a/go.sum +++ b/go.sum @@ -448,8 +448,8 @@ github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0 github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= -github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= diff --git a/pkg/security/auto_tls_init.go b/pkg/security/auto_tls_init.go index cf53c6c9f95e..68a9ca6588e5 100644 --- a/pkg/security/auto_tls_init.go +++ b/pkg/security/auto_tls_init.go @@ -26,8 +26,8 @@ import ( ) // TODO(aaron-crl): This shared a name and purpose with the value in -// pkg/security and should be consolidated. -const defaultKeySize = 4096 +// pkg/cli/cert.go and should be consolidated. +const defaultKeySize = 2048 // notBeforeMargin provides a window to compensate for potential clock skew. const notBeforeMargin = time.Second * 30 @@ -128,7 +128,7 @@ func CreateCACertAndKey( // CreateServiceCertAndKey creates a cert/key pair signed by the provided CA. // This is a utility function to help with cluster auto certificate generation. func CreateServiceCertAndKey( - lifespan time.Duration, service, hostname string, caCertPEM []byte, caKeyPEM []byte, + lifespan time.Duration, service string, hostnames []string, caCertPEM []byte, caKeyPEM []byte, ) (certPEM []byte, keyPEM []byte, err error) { notBefore := timeutil.Now().Add(-notBeforeMargin) notAfter := timeutil.Now().Add(lifespan) @@ -183,11 +183,13 @@ func CreateServiceCertAndKey( // Attempt to parse hostname as IP, if successful add it as an IP // otherwise presume it is a DNS name. // TODO(aaron-crl): Pass these values via config object. - ip := net.ParseIP(hostname) - if ip != nil { - serviceCert.IPAddresses = []net.IP{ip} - } else { - serviceCert.DNSNames = []string{hostname} + for _, hostname := range hostnames { + ip := net.ParseIP(hostname) + if ip != nil { + serviceCert.IPAddresses = []net.IP{ip} + } else { + serviceCert.DNSNames = []string{hostname} + } } servicePrivKey, err := rsa.GenerateKey(rand.Reader, defaultKeySize) @@ -223,5 +225,5 @@ func CreateServiceCertAndKey( return nil, nil, err } - return serviceCertBlock.Bytes(), servicePrivKeyPEM.Bytes(), err + return serviceCertBlock.Bytes(), servicePrivKeyPEM.Bytes(), nil } diff --git a/pkg/security/auto_tls_init_test.go b/pkg/security/auto_tls_init_test.go index 9848aeda3a1b..b66c67904e55 100644 --- a/pkg/security/auto_tls_init_test.go +++ b/pkg/security/auto_tls_init_test.go @@ -40,7 +40,7 @@ func TestDummyCreateServiceCertAndKey(t *testing.T) { _, _, err = security.CreateServiceCertAndKey( time.Minute, "test Service cert generation", - "localhost", + []string{"localhost", "127.0.0.1"}, caCert, caKey, ) diff --git a/pkg/server/auto_tls_init.go b/pkg/server/auto_tls_init.go index 6ed600f15665..748a359d4f43 100644 --- a/pkg/server/auto_tls_init.go +++ b/pkg/server/auto_tls_init.go @@ -17,6 +17,7 @@ package server import ( + "encoding/pem" "io/ioutil" "os" "time" @@ -27,9 +28,19 @@ import ( "github.com/cockroachdb/errors/oserror" ) -// Define default certificate lifespan of 366 days -// TODO(aaron-crl): Put this in the config map. -const initLifespan = time.Minute * 60 * 24 * 366 +// TODO(aaron-crl): This is an exact copy from `pkg/cli/cert.go` and should +// be refactored to share consts. +// We use 366 days on certificate lifetimes to at least match X years, +// otherwise leap years risk putting us just under. +const defaultCALifetime = 10 * 366 * 24 * time.Hour // ten years +const defaultCertLifetime = 5 * 366 * 24 * time.Hour // five years + +// Service Name Strings for autogenerated certificates. +const serviceNameInterNode = "InterNode Service" +const serviceNameUserAuth = "User Auth Service" +const serviceNameSQL = "SQL Service" +const serviceNameRPC = "RPC Service" +const serviceNameUI = "UI Service" // CertificateBundle manages the collection of certificates used by a // CockroachDB node. @@ -126,9 +137,10 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( serviceKeyPath string, caCertPath string, caKeyPath string, - initLifespan time.Duration, + serviceCertLifespan time.Duration, + caCertLifespan time.Duration, serviceName string, - hostname string, + hostnames []string, ) error { var err error @@ -164,7 +176,7 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( } } else if oserror.IsNotExist(err) { // CA cert does not yet exist, create it and its key. - err = sb.createServiceCA(caCertPath, caKeyPath, initLifespan, serviceName) + err = sb.createServiceCA(caCertPath, caKeyPath, caCertLifespan, serviceName) if err != nil { return errors.Wrap( err, "failed to create Service CA", @@ -173,11 +185,11 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( } // CA cert and key should now be loaded, create service cert and key. - var hostCert, hostKey []byte - hostCert, hostKey, err = security.CreateServiceCertAndKey( - initLifespan, + //var hostCert, hostKey []byte + sb.HostCertificate, sb.HostKey, err = security.CreateServiceCertAndKey( + serviceCertLifespan, serviceName, - hostname, + hostnames, sb.CACertificate, sb.CAKey, ) @@ -187,12 +199,12 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( ) } - err = writeCertificateFile(serviceCertPath, hostCert) + err = writeCertificateFile(serviceCertPath, sb.HostCertificate, false) if err != nil { return err } - err = writeKeyFile(serviceKeyPath, hostKey) + err = writeKeyFile(serviceKeyPath, sb.HostKey, false) if err != nil { return err } @@ -210,12 +222,12 @@ func (sb *ServiceCertificateBundle) createServiceCA( return } - err = writeCertificateFile(caCertPath, sb.CACertificate) + err = writeCertificateFile(caCertPath, sb.CACertificate, false) if err != nil { return } - err = writeKeyFile(caKeyPath, sb.CAKey) + err = writeKeyFile(caKeyPath, sb.CAKey, false) if err != nil { return } @@ -238,29 +250,42 @@ func loadKeyFile(keyPath string) (key []byte, err error) { } // Simple wrapper to make it easier to store certs somewhere else later. -// This function will error if a file already exists at certPath. -func writeCertificateFile(certPath string, certPEM []byte) error { - if _, err := os.Stat(certPath); err == nil { - return errors.Newf("found existing certfile at: %q", certPath) - } else if !oserror.IsNotExist(err) { - return errors.Wrapf(err, - "problem writing keyfile at: %q", certPath) +// Unless overwrite is true, this function will error if a file alread exists +// at certFilePath. +// TODO(aaron-crl): This was lifted from 'pkg/security' and modified. It might +// make sense to refactor these calls back to 'pkg/security' rather than +// maintain these functions. +func writeCertificateFile(certFilePath string, certificatePEMBytes []byte, overwrite bool) error { + // Validate that we are about to write a cert. And reshape for common + // security.WritePEMToFile(). + // TODO(aaron-crl): Validate this is actually a cert. + caCert, _ := pem.Decode(certificatePEMBytes) + if nil == caCert { + return errors.New("failed to parse valid PEM from certificatePEMBytes") } + // TODO(aaron-crl): Add logging here. - return ioutil.WriteFile(certPath, certPEM, 0600) + return security.WritePEMToFile(certFilePath, 0600, overwrite, caCert) } // Simple wrapper to make it easier to store certs somewhere else later. -// This function will error if a file alread exists at keyPath. -func writeKeyFile(keyPath string, keyPEM []byte) error { - if _, err := os.Stat(keyPath); err == nil { - return errors.Newf("found existing keyfile at: %q", keyPath) - } else if !oserror.IsNotExist(err) { - return errors.Wrapf(err, - "problem writing keyfile at: %q", keyPath) +// Unless overwrite is true, this function will error if a file alread exists +// at keyFilePath. +// TODO(aaron-crl): This was lifted from 'pkg/security' and modified. It might +// make sense to refactor these calls back to 'pkg/security' rather than +// maintain these functions. +func writeKeyFile(keyFilePath string, keyPEMBytes []byte, overwrite bool) error { + // Validate that we are about to write a key and reshape for common + // security.WritePEMToFile(). + // TODO(aaron-crl): Validate this is actually a key. + + keyBlock, _ := pem.Decode(keyPEMBytes) + if keyBlock == nil { + return errors.New("failed to parse valid PEM from certificatePEMBytes") } + // TODO(aaron-crl): Add logging here. - return ioutil.WriteFile(keyPath, keyPEM, 0600) + return security.WritePEMToFile(keyFilePath, 0600, overwrite, keyBlock) } // InitializeFromConfig is called by the node creating certificates for the @@ -288,9 +313,10 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { cl.NodeKeyPath(), cl.CACertPath(), cl.CAKeyPath(), - initLifespan, - "InterNode Service", - c.Addr, + defaultCertLifetime, + defaultCALifetime, + serviceNameInterNode, + []string{c.Addr, c.AdvertiseAddr}, ) if err != nil { return errors.Wrap(err, @@ -298,13 +324,11 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { } // Initialize User auth certificates. - // TODO(aaron-crl): Double check that we want to do this. It seems - // like this is covered by the interface certificates? err = b.UserAuth.loadOrCreateUserAuthCACertAndKey( cl.ClientCACertPath(), cl.ClientCAKeyPath(), - initLifespan, - "User Authentication", + defaultCALifetime, + serviceNameUserAuth, ) if err != nil { return errors.Wrap(err, @@ -317,9 +341,11 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { cl.SQLServiceKeyPath(), cl.SQLServiceCACertPath(), cl.SQLServiceCAKeyPath(), - initLifespan, - "SQL Service", - c.SQLAddr, + defaultCertLifetime, + defaultCALifetime, + serviceNameSQL, + // TODO(aaron-crl): Add RPC variable to config or SplitSQLAddr. + []string{c.SQLAddr, c.SQLAdvertiseAddr}, ) if err != nil { return errors.Wrap(err, @@ -332,9 +358,11 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { cl.RPCServiceKeyPath(), cl.RPCServiceCACertPath(), cl.RPCServiceCAKeyPath(), - initLifespan, - "RPC Service", - c.SQLAddr, // TODO(aaron-crl): Add RPC variable to config. + defaultCertLifetime, + defaultCALifetime, + serviceNameRPC, + // TODO(aaron-crl): Add RPC variable to config. + []string{c.SQLAddr, c.SQLAdvertiseAddr}, ) if err != nil { return errors.Wrap(err, @@ -347,9 +375,10 @@ func (b *CertificateBundle) InitializeFromConfig(c base.Config) error { cl.UIKeyPath(), cl.UICACertPath(), cl.UICAKeyPath(), - initLifespan, - "AdminUI Service", - c.HTTPAddr, + defaultCertLifetime, + defaultCALifetime, + serviceNameUI, + []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, ) if err != nil { return errors.Wrap(err, @@ -426,14 +455,14 @@ func (b *CertificateBundle) InitializeNodeFromBundle(c base.Config) error { // error if it fails to write a file to disk. func (sb *ServiceCertificateBundle) writeCAOrFail(certPath string, keyPath string) (err error) { if sb.CACertificate != nil { - err = writeCertificateFile(certPath, sb.CACertificate) + err = writeCertificateFile(certPath, sb.CACertificate, false) if err != nil { return } } if sb.CAKey != nil { - err = writeKeyFile(keyPath, sb.CAKey) + err = writeKeyFile(keyPath, sb.CAKey, false) if err != nil { return } @@ -442,21 +471,188 @@ func (sb *ServiceCertificateBundle) writeCAOrFail(certPath string, keyPath strin return } -// copyOnlyCAs is a helper function to only populate the CA portion of -// a ServiceCertificateBundle -func (sb *ServiceCertificateBundle) copyOnlyCAs(destBundle *ServiceCertificateBundle) { - destBundle.CACertificate = sb.CACertificate - destBundle.CAKey = sb.CAKey +func (sb *ServiceCertificateBundle) loadCACertAndKeyIfExists( + certPath string, keyPath string, +) error { + // TODO(aaron-crl): Possibly add a warning to the log that a CA was not + // found. + err := sb.loadCACertAndKey(certPath, keyPath) + if oserror.IsNotExist(err) { + return nil + } + return err } -// ToPeerInitBundle populates a bundle of initialization certificate CAs (only). -// This function is expected to serve any node providing a init bundle to a -// joining or starting peer. -func (b *CertificateBundle) ToPeerInitBundle() (pb CertificateBundle) { - b.InterNode.copyOnlyCAs(&pb.InterNode) - b.UserAuth.copyOnlyCAs(&pb.UserAuth) - b.SQLService.copyOnlyCAs(&pb.SQLService) - b.RPCService.copyOnlyCAs(&pb.RPCService) - b.AdminUIService.copyOnlyCAs(&pb.AdminUIService) - return +// collectLocalCABundle will load any CA certs and keys present on disk. It +// will skip any CA's where the certificate is not found. Any other read errors +// including permissions result in an error. +func collectLocalCABundle(c base.Config) (CertificateBundle, error) { + cl := security.MakeCertsLocator(c.SSLCertsDir) + var b CertificateBundle + var err error + + err = b.InterNode.loadCACertAndKeyIfExists(cl.CACertPath(), cl.CAKeyPath()) + if err != nil { + return b, errors.Wrap( + err, "error loading InterNode CA cert and/or key") + } + + err = b.UserAuth.loadCACertAndKeyIfExists( + cl.ClientCACertPath(), cl.ClientCAKeyPath()) + if err != nil { + return b, errors.Wrap( + err, "error loading UserAuth CA cert and/or key") + } + + err = b.SQLService.loadCACertAndKeyIfExists( + cl.SQLServiceCACertPath(), cl.SQLServiceCAKeyPath()) + if err != nil { + return b, errors.Wrap( + err, "error loading SQL CA cert and/or key") + } + err = b.RPCService.loadCACertAndKeyIfExists( + cl.RPCServiceCACertPath(), cl.RPCServiceCAKeyPath()) + if err != nil { + return b, errors.Wrap( + err, "error loading RPC CA cert and/or key") + } + + err = b.AdminUIService.loadCACertAndKeyIfExists( + cl.UICACertPath(), cl.UICAKeyPath()) + if err != nil { + return b, errors.Wrap( + err, "error loading AdminUI CA cert and/or key") + } + + return b, nil +} + +// rotateGeneratedCertsOnDisk will generate and replace interface certificates +// where a corresponding CA cert and key are found. This function does not +// restart any services or cause the node to restart. That must be triggered +// after this function is successfully run. +// Service certs are written as they are generated but will return on first +// error. This is not seen as harmful as the rotation command may be rerun +// manually after rotation errors are corrected without negatively impacting +// any interface. All existing interfaces will again receive a new +// certificate/key pair. +func rotateGeneratedCerts(c base.Config) error { + cl := security.MakeCertsLocator(c.SSLCertsDir) + + // Fail fast if we can't load the CAs. + b, err := collectLocalCABundle(c) + if err != nil { + return errors.Wrap( + err, "failed to load local CAs for certificate rotation") + } + + // Rotate InterNode Certs. + if b.InterNode.CACertificate != nil { + err = b.InterNode.rotateServiceCert( + cl.NodeCertPath(), + cl.NodeKeyPath(), + defaultCertLifetime, + serviceNameInterNode, + []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + ) + if err != nil { + return errors.Wrap(err, "failed to rotate InterNode cert") + } + } + + // TODO(aaron-crl): Should we rotate UserAuth Certs. + + // Rotate SQLService Certs. + if b.SQLService.CACertificate != nil { + err = b.SQLService.rotateServiceCert( + cl.SQLServiceCertPath(), + cl.SQLServiceKeyPath(), + defaultCertLifetime, + serviceNameSQL, + []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + ) + if err != nil { + return errors.Wrap(err, "failed to rotate SQLService cert") + } + } + + // Rotate RPCService Certs. + if b.RPCService.CACertificate != nil { + err = b.RPCService.rotateServiceCert( + cl.RPCServiceCertPath(), + cl.RPCServiceKeyPath(), + defaultCertLifetime, + serviceNameRPC, + []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + ) + if err != nil { + return errors.Wrap(err, "failed to rotate RPCService cert") + } + } + + // Rotate AdminUIService Certs. + if b.AdminUIService.CACertificate != nil { + err = b.AdminUIService.rotateServiceCert( + cl.UICertPath(), + cl.UIKeyPath(), + defaultCertLifetime, + serviceNameUI, + []string{c.HTTPAddr, c.HTTPAdvertiseAddr}, + ) + if err != nil { + return errors.Wrap(err, "failed to rotate AdminUIService cert") + } + } + + return nil +} + +// rotateServiceCert will generate a new service certificate for the provided +// hostnames and path signed by the ca at the supplied paths. It will only +// succeed if it is able to generate these and OVERWRITE an exist file. +func (sb *ServiceCertificateBundle) rotateServiceCert( + certPath string, + keyPath string, + serviceCertLifespan time.Duration, + serviceString string, + hostnames []string, +) error { + // generate + certPEM, keyPEM, err := security.CreateServiceCertAndKey( + serviceCertLifespan, + serviceString, + hostnames, + sb.CACertificate, + sb.CAKey, + ) + if err != nil { + return errors.Wrapf( + err, "failed to rotate certs for %q", serviceString) + } + + // Check to make sure we're about to overwrite a file. + if _, err := os.Stat(certPath); err == nil { + err = writeCertificateFile(certPath, certPEM, true) + if err != nil { + return errors.Wrapf( + err, "failed to rotate certs for %q", serviceString) + } + } else { + return errors.Wrapf( + err, "failed to rotate certs for %q", serviceString) + } + + // Check to make sure we're about to overwrite a file. + if _, err := os.Stat(certPath); err == nil { + err = writeKeyFile(keyPath, keyPEM, true) + if err != nil { + return errors.Wrapf( + err, "failed to rotate certs for %q", serviceString) + } + } else { + return errors.Wrapf( + err, "failed to rotate certs for %q", serviceString) + } + + return nil } diff --git a/pkg/server/auto_tls_init_test.go b/pkg/server/auto_tls_init_test.go index 0a74688b081e..8643bcab7bfa 100644 --- a/pkg/server/auto_tls_init_test.go +++ b/pkg/server/auto_tls_init_test.go @@ -11,17 +11,20 @@ package server import ( + "bytes" + "io" "io/ioutil" "os" "testing" + "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/util/leaktest" ) -// TestDummyInitializeFromConfig is a placeholder for actual testing functions. -// TODO(aaron-crl): [tests] write unit tests. -func TestDummyInitializeFromConfig(t *testing.T) { +// TestInitializeFromConfig is a placeholder for actual testing functions. +func TestInitializeFromConfig(t *testing.T) { defer leaktest.AfterTest(t)() // Create a temp dir for all certificate tests. @@ -37,9 +40,19 @@ func TestDummyInitializeFromConfig(t *testing.T) { err = certBundle.InitializeFromConfig(cfg) if err != nil { - t.Fatalf("expected err=nil, got: %s", err) + t.Fatalf("expected err=nil, got: %q", err) + } + + // Verify certs written to disk match certs in bundles. + bundleFromDisk, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) } + // Compare each set of certs and keys to those loaded from disk. + compareBundleCaCerts(t, bundleFromDisk, certBundle, true) + compareBundleServiceCerts(t, bundleFromDisk, certBundle, true) + // Remove temp directory now that we are done with it. err = os.RemoveAll(tempDir) if err != nil { @@ -48,6 +61,143 @@ func TestDummyInitializeFromConfig(t *testing.T) { } +func loadAllCertsFromDisk(cfg base.Config) (CertificateBundle, error) { + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + bundleFromDisk, err := collectLocalCABundle(cfg) + if err != nil { + return bundleFromDisk, err + } + + err = bundleFromDisk.InterNode.loadOrCreateServiceCertificates( + cl.NodeCertPath(), cl.NodeKeyPath(), "", "", 0, 0, "", []string{}, + ) + if err != nil { + return bundleFromDisk, err + } + + // TODO(aaron-crl): Figure out how to handle client auth case. + //bundleFromDisk.UserAuth.loadOrCreateServiceCertificates( + // cl.ClientCertPath(), cl.ClientKeyPath(), "", "", 0, "", []string{}, + //) + err = bundleFromDisk.SQLService.loadOrCreateServiceCertificates( + cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), "", "", 0, 0, "", []string{}, + ) + if err != nil { + return bundleFromDisk, err + } + + err = bundleFromDisk.RPCService.loadOrCreateServiceCertificates( + cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), "", "", 0, 0, "", []string{}, + ) + if err != nil { + return bundleFromDisk, err + } + + err = bundleFromDisk.AdminUIService.loadOrCreateServiceCertificates( + cl.UICertPath(), cl.UIKeyPath(), "", "", 0, 0, "", []string{}, + ) + if err != nil { + return bundleFromDisk, err + } + + return bundleFromDisk, nil +} + +func certCompareHelper(t *testing.T, desireEqual bool) func([]byte, []byte, string) { + if desireEqual { + return func(b1 []byte, b2 []byte, certName string) { + if !bytes.Equal(b1, b2) { + t.Fatalf("bytes for %s not equal", certName) + } + } + } + return func(b1 []byte, b2 []byte, certName string) { + if bytes.Equal(b1, b2) && b1 != nil { + t.Fatalf("bytes for %s were equal", certName) + } + } +} + +func compareBundleCaCerts( + t *testing.T, cb1 CertificateBundle, cb2 CertificateBundle, desireEqual bool, +) { + cmp := certCompareHelper(t, desireEqual) + // Compare InterNode CA cert and key. + cmp( + cb1.InterNode.CACertificate, + cb2.InterNode.CACertificate, serviceNameInterNode+" CA cert") + cmp( + cb1.InterNode.CAKey, + cb2.InterNode.CAKey, serviceNameInterNode+" CA key") + + // Compare UserAuth CA cert and key. + cmp( + cb1.UserAuth.CACertificate, + cb2.UserAuth.CACertificate, serviceNameUserAuth+" CA cert") + cmp( + cb1.UserAuth.CAKey, + cb2.UserAuth.CAKey, serviceNameUserAuth+" CA key") + + // Compare SQL CA cert and key. + cmp( + cb1.SQLService.CACertificate, + cb2.SQLService.CACertificate, serviceNameSQL+" CA cert") + cmp( + cb1.SQLService.CAKey, + cb2.SQLService.CAKey, serviceNameSQL+" CA key") + + // Compare RPC CA cert and key. + cmp( + cb1.RPCService.CACertificate, + cb2.RPCService.CACertificate, serviceNameRPC+" CA cert") + cmp( + cb1.RPCService.CAKey, + cb2.RPCService.CAKey, serviceNameRPC+" CA key") + + // Compare UI CA cert and key. + cmp( + cb1.AdminUIService.CACertificate, + cb2.AdminUIService.CACertificate, serviceNameUI+" CA cert") + cmp( + cb1.AdminUIService.CAKey, + cb2.AdminUIService.CAKey, serviceNameUI+" CA key") + +} + +func compareBundleServiceCerts( + t *testing.T, cb1 CertificateBundle, cb2 CertificateBundle, desireEqual bool, +) { + cmp := certCompareHelper(t, desireEqual) + + cmp( + cb1.InterNode.HostCertificate, + cb2.InterNode.HostCertificate, serviceNameInterNode+" Host cert") + cmp( + cb1.InterNode.HostKey, + cb2.InterNode.HostKey, serviceNameInterNode+" Host key") + + cmp( + cb1.SQLService.HostCertificate, + cb2.SQLService.HostCertificate, serviceNameSQL+" Host cert") + cmp( + cb1.SQLService.HostKey, + cb2.SQLService.HostKey, serviceNameSQL+" Host key") + + cmp( + cb1.RPCService.HostCertificate, + cb2.RPCService.HostCertificate, serviceNameRPC+" Host cert") + cmp( + cb1.RPCService.HostKey, + cb2.RPCService.HostKey, serviceNameRPC+" Host key") + + cmp( + cb1.AdminUIService.HostCertificate, + cb2.AdminUIService.HostCertificate, serviceNameUI+" Host cert") + cmp( + cb1.AdminUIService.HostKey, + cb2.AdminUIService.HostKey, serviceNameUI+" Host key") +} + // TestDummyInitializeNodeFromBundle is a placeholder for actual testing functions. // TODO(aaron-crl): [tests] write unit tests. func TestDummyInitializeNodeFromBundle(t *testing.T) { @@ -58,6 +208,11 @@ func TestDummyInitializeNodeFromBundle(t *testing.T) { if err != nil { t.Fatalf("failed to create test temp dir: %s", err) } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + }() certBundle := CertificateBundle{} cfg := base.Config{ @@ -68,12 +223,6 @@ func TestDummyInitializeNodeFromBundle(t *testing.T) { if err != nil { t.Fatalf("expected err=nil, got: %s", err) } - - // Remove temp directory now that we are done with it. - err = os.RemoveAll(tempDir) - if err != nil { - t.Fatalf("failed to remove test temp dir: %s", err) - } } // TestDummyCertLoader is a placeholder for actual testing functions. @@ -84,8 +233,189 @@ func TestDummyCertLoader(t *testing.T) { scb := ServiceCertificateBundle{} _ = scb.loadServiceCertAndKey("", "") _ = scb.loadCACertAndKey("", "") + _ = scb.rotateServiceCert("", "", time.Minute, "", []string{""}) +} + +// TestNodeCertRotation tests that the rotation function will overwrite the +// expected certificates and fail if they are not there. +func TestRotationOnUnintializedNode(t *testing.T) { + defer leaktest.AfterTest(t)() + + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + }() + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + + // Check the empty case. + // Check to see that the only file in dir is the EOF. + dir, err := os.Open(cfg.SSLCertsDir) + if err != nil { + t.Fatalf( + "failed to open cfg.SSLCertsDir: %q with err: %v", + cfg.SSLCertsDir, + err) + } + defer dir.Close() + _, err = dir.Readdir(1) + if err != io.EOF { + // Directory is not empty to start with, this is an error. + t.Fatal("files added to cfg.SSLCertsDir when they shouldn't have been") + } + + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("expected nil error generating no certs, got: %q", err) + } + +} + +func TestRotationOnIntializedNode(t *testing.T) { + defer leaktest.AfterTest(t)() - cb := CertificateBundle{} - cb.InterNode.copyOnlyCAs(&scb) - _ = cb.ToPeerInitBundle() + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + }() + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + + // Test in the fully provisioned case. + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("rotation failed; expected err=nil, got: %q", err) + } + + // Verify that any existing certs have changed on disk for services + diskBundle, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) + } + compareBundleServiceCerts(t, certBundle, diskBundle, false) +} + +func TestRotationOnPartialIntializedNode(t *testing.T) { + defer leaktest.AfterTest(t)() + + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + }() + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + // Test in the partially provisioned case (remove the Client and UI CAs). + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + if err = os.Remove(cl.ClientCACertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.ClientCAKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.UICACertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.UICAKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + + // This should rotate all service certs except client and UI. + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("rotation failed; expected err=nil, got: %q", err) + } + + // Verify that client and UI service host certs are unchanged. + diskBundle, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) + } + cmp := certCompareHelper(t, true) + cmp( + certBundle.UserAuth.HostCertificate, + diskBundle.UserAuth.HostCertificate, "UserAuth host cert") + cmp( + certBundle.UserAuth.HostKey, + diskBundle.UserAuth.HostKey, "UserAuth host key") + cmp( + certBundle.AdminUIService.HostCertificate, + diskBundle.AdminUIService.HostCertificate, "AdminUIService host cert") + cmp( + certBundle.AdminUIService.HostKey, + diskBundle.AdminUIService.HostKey, "AdminUIService host key") +} + +// TestRotationOnBrokenIntializedNode in the partially provisioned case (remove the Client and UI CAs). +func TestRotationOnBrokenIntializedNode(t *testing.T) { + defer leaktest.AfterTest(t)() + + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatal(err) + } + }() + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + // Test in the case where a leaf certificate has been removed but a CA is + // still present with key. This should fail. Removing SQL certificate. + if err = os.Remove(cl.SQLServiceCertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.SQLServiceKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + + err = rotateGeneratedCerts(cfg) + if err == nil { + t.Fatalf("rotation succeeded but should have failed with missing leaf certs for SQLService") + } } diff --git a/pkg/server/init_handshake.go b/pkg/server/init_handshake.go index 7b7e2f5f8b6d..fdccf9681ecf 100644 --- a/pkg/server/init_handshake.go +++ b/pkg/server/init_handshake.go @@ -128,13 +128,13 @@ func pemToSignature(caCertPEM []byte) ([]byte, error) { } func createNodeInitTempCertificates( - hostname string, lifespan time.Duration, + hostnames []string, lifespan time.Duration, ) (certs ServiceCertificateBundle, err error) { caCert, caKey, err := security.CreateCACertAndKey(lifespan, initServiceName) if err != nil { return certs, err } - serviceCert, serviceKey, err := security.CreateServiceCertAndKey(lifespan, initServiceName, hostname, caCert, caKey) + serviceCert, serviceKey, err := security.CreateServiceCertAndKey(lifespan, initServiceName, hostnames, caCert, caKey) if err != nil { return certs, err } @@ -424,7 +424,7 @@ func initHandshakeHelper( default: return errors.New("unsupported listener protocol: only TCP listeners supported") } - tempCerts, err := createNodeInitTempCertificates(listenHost, defaultInitLifespan) + tempCerts, err := createNodeInitTempCertificates([]string{listenHost}, defaultInitLifespan) if err != nil { return errors.Wrap(err, "failed to create certificates") } @@ -500,7 +500,12 @@ func initHandshakeHelper( if err := b.InitializeFromConfig(*cfg); err != nil { return errors.Wrap(err, "error when creating initialization bundle") } - peerInit := b.ToPeerInitBundle() + + peerInit, err := collectLocalCABundle(*cfg) + if err != nil { + return errors.Wrap(err, "error when loading initialization bundle") + } + trustBundle := nodeTrustBundle{Bundle: peerInit} trustBundle.signHMAC(handshaker.token) // For each peer, use its CA to establish a secure connection and deliver the trust bundle. diff --git a/vendor b/vendor index 7848358afbe7..797f916a0d9c 160000 --- a/vendor +++ b/vendor @@ -1 +1 @@ -Subproject commit 7848358afbe7fb0351801eb4c86f2cf6d048231d +Subproject commit 797f916a0d9c283c8f05e4d7af9790d4b165b834