From 0961c0f599ef1c13c7609a98ce559a03a80d1d59 Mon Sep 17 00:00:00 2001 From: Aaron Blum Date: Fri, 26 Feb 2021 20:11:46 -0500 Subject: [PATCH] server: added test cases for auto cert rotation code Release justification: low, added/improved tests to existing function Release note: None --- pkg/server/auto_tls_init.go | 18 +- pkg/server/auto_tls_init_test.go | 306 ++++++++++++++++++++++++++++++- 2 files changed, 304 insertions(+), 20 deletions(-) diff --git a/pkg/server/auto_tls_init.go b/pkg/server/auto_tls_init.go index f4967dfdc15b..c4e781def4f6 100644 --- a/pkg/server/auto_tls_init.go +++ b/pkg/server/auto_tls_init.go @@ -20,7 +20,6 @@ import ( "encoding/pem" "io/ioutil" "os" - "strings" "time" "github.com/cockroachdb/cockroach/pkg/base" @@ -186,8 +185,8 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( } // CA cert and key should now be loaded, create service cert and key. - var hostCert, hostKey []byte - hostCert, hostKey, err = security.CreateServiceCertAndKey( + //var hostCert, hostKey []byte + sb.HostCertificate, sb.HostKey, err = security.CreateServiceCertAndKey( initLifespan, serviceName, hostnames, @@ -200,12 +199,12 @@ func (sb *ServiceCertificateBundle) loadOrCreateServiceCertificates( ) } - err = writeCertificateFile(serviceCertPath, hostCert, false) + err = writeCertificateFile(serviceCertPath, sb.HostCertificate, false) if err != nil { return err } - err = writeKeyFile(serviceKeyPath, hostKey, false) + err = writeKeyFile(serviceKeyPath, sb.HostKey, false) if err != nil { return err } @@ -286,7 +285,7 @@ func writeKeyFile(keyFilePath string, keyPEMBytes []byte, overwrite bool) error } // TODO(aaron-crl): Add logging here. - return security.WritePEMToFile(keyFilePath, 600, overwrite, keyBlock) + return security.WritePEMToFile(keyFilePath, 0600, overwrite, keyBlock) } // InitializeFromConfig is called by the node creating certificates for the @@ -537,7 +536,6 @@ func collectLocalCABundle(c base.Config) (CertificateBundle, error) { // certificate/key pair. func rotateGeneratedCerts(c base.Config) error { cl := security.MakeCertsLocator(c.SSLCertsDir) - var errStrings []string // Fail fast if we can't load the CAs. b, err := collectLocalCABundle(c) @@ -604,7 +602,7 @@ func rotateGeneratedCerts(c base.Config) error { } } - return errors.Errorf(strings.Join(errStrings, "\n")) + return nil } // rotateServiceCert will generate a new service certificate for the provided @@ -631,7 +629,7 @@ func (sb *ServiceCertificateBundle) rotateServiceCert( } // Check to make sure we're about to overwrite a file. - if _, err := os.Stat(certPath); err != nil { + if _, err := os.Stat(certPath); err == nil { err = writeCertificateFile(certPath, certPEM, true) if err != nil { return errors.Wrapf( @@ -643,7 +641,7 @@ func (sb *ServiceCertificateBundle) rotateServiceCert( } // Check to make sure we're about to overwrite a file. - if _, err := os.Stat(certPath); err != nil { + if _, err := os.Stat(certPath); err == nil { err = writeKeyFile(keyPath, keyPEM, true) if err != nil { return errors.Wrapf( diff --git a/pkg/server/auto_tls_init_test.go b/pkg/server/auto_tls_init_test.go index a7b9a334a065..c493be4414ec 100644 --- a/pkg/server/auto_tls_init_test.go +++ b/pkg/server/auto_tls_init_test.go @@ -11,18 +11,20 @@ package server import ( + "bytes" + "io" "io/ioutil" "os" "testing" "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/util/leaktest" ) -// TestDummyInitializeFromConfig is a placeholder for actual testing functions. -// TODO(aaron-crl): [tests] write unit tests. -func TestDummyInitializeFromConfig(t *testing.T) { +// TestInitializeFromConfig is a placeholder for actual testing functions. +func TestInitializeFromConfig(t *testing.T) { defer leaktest.AfterTest(t)() // Create a temp dir for all certificate tests. @@ -38,9 +40,19 @@ func TestDummyInitializeFromConfig(t *testing.T) { err = certBundle.InitializeFromConfig(cfg) if err != nil { - t.Fatalf("expected err=nil, got: %s", err) + t.Fatalf("expected err=nil, got: %q", err) + } + + // Verify certs written to disk match certs in bundles. + bundleFromDisk, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) } + // Compare each set of certs and keys to those loaded from disk. + compareBundleCaCerts(t, bundleFromDisk, certBundle, true) + compareBundleServiceCerts(t, bundleFromDisk, certBundle, true) + // Remove temp directory now that we are done with it. err = os.RemoveAll(tempDir) if err != nil { @@ -49,6 +61,127 @@ func TestDummyInitializeFromConfig(t *testing.T) { } +func loadAllCertsFromDisk(cfg base.Config) (CertificateBundle, error) { + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + bundleFromDisk, err := collectLocalCABundle(cfg) + if err != nil { + return bundleFromDisk, err + } + bundleFromDisk.InterNode.loadOrCreateServiceCertificates( + cl.NodeCertPath(), cl.NodeKeyPath(), "", "", 0, "", []string{}, + ) + // TODO(aaron-crl): Figure out how to handle client auth case. + //bundleFromDisk.UserAuth.loadOrCreateServiceCertificates( + // cl.ClientCertPath(), cl.ClientKeyPath(), "", "", 0, "", []string{}, + //) + bundleFromDisk.SQLService.loadOrCreateServiceCertificates( + cl.SQLServiceCertPath(), cl.SQLServiceKeyPath(), "", "", 0, "", []string{}, + ) + bundleFromDisk.RPCService.loadOrCreateServiceCertificates( + cl.RPCServiceCertPath(), cl.RPCServiceKeyPath(), "", "", 0, "", []string{}, + ) + bundleFromDisk.AdminUIService.loadOrCreateServiceCertificates( + cl.UICertPath(), cl.UIKeyPath(), "", "", 0, "", []string{}, + ) + + return bundleFromDisk, nil +} + +func certCompareHelper(t *testing.T, desireEqual bool) func([]byte, []byte, string) { + if desireEqual { + return func(b1 []byte, b2 []byte, certName string) { + if bytes.Compare(b1, b2) != 0 { + t.Fatalf("bytes for %s not equal", certName) + } + } + } + return func(b1 []byte, b2 []byte, certName string) { + if bytes.Compare(b1, b2) == 0 && b1 != nil { + t.Fatalf("bytes for %s were equal", certName) + } + } +} + +func compareBundleCaCerts( + t *testing.T, cb1 CertificateBundle, cb2 CertificateBundle, desireEqual bool, +) { + cmp := certCompareHelper(t, desireEqual) + // Compare InterNode CA cert and key. + cmp( + cb1.InterNode.CACertificate, + cb2.InterNode.CACertificate, "InterNodeCA cert") + cmp( + cb1.InterNode.CAKey, + cb2.InterNode.CAKey, "InterNodeCA key") + + // Compare UserAuth CA cert and key. + cmp( + cb1.UserAuth.CACertificate, + cb2.UserAuth.CACertificate, "UserAuth CA cert") + cmp( + cb1.UserAuth.CAKey, + cb2.UserAuth.CAKey, "UserAuth CA key") + + // Compare SQL CA cert and key. + cmp( + cb1.SQLService.CACertificate, + cb2.SQLService.CACertificate, "SQLService CA cert") + cmp( + cb1.SQLService.CAKey, + cb2.SQLService.CAKey, "SQLService CA key") + + // Compare RPC CA cert and key. + cmp( + cb1.RPCService.CACertificate, + cb2.RPCService.CACertificate, "RPCService CA cert") + cmp( + cb1.RPCService.CAKey, + cb2.RPCService.CAKey, "RPCService CA key") + + // Compare UI CA cert and key. + cmp( + cb1.AdminUIService.CACertificate, + cb2.AdminUIService.CACertificate, "AdminUIService CA cert") + cmp( + cb1.AdminUIService.CAKey, + cb2.AdminUIService.CAKey, "AdminUIService CA key") + +} + +func compareBundleServiceCerts( + t *testing.T, cb1 CertificateBundle, cb2 CertificateBundle, desireEqual bool, +) { + cmp := certCompareHelper(t, desireEqual) + + cmp( + cb1.InterNode.HostCertificate, + cb2.InterNode.HostCertificate, "InterNode Host cert") + cmp( + cb1.InterNode.HostKey, + cb2.InterNode.HostKey, "InterNode Host key") + + cmp( + cb1.SQLService.HostCertificate, + cb2.SQLService.HostCertificate, "SQLService Host cert") + cmp( + cb1.SQLService.HostKey, + cb2.SQLService.HostKey, "SQLService Host key") + + cmp( + cb1.RPCService.HostCertificate, + cb2.RPCService.HostCertificate, "RPCService Host cert") + cmp( + cb1.RPCService.HostKey, + cb2.RPCService.HostKey, "RPCService Host key") + + cmp( + cb1.AdminUIService.HostCertificate, + cb2.AdminUIService.HostCertificate, "AdminUIService Host cert") + cmp( + cb1.AdminUIService.HostKey, + cb2.AdminUIService.HostKey, "AdminUIService Host key") +} + // TestDummyInitializeNodeFromBundle is a placeholder for actual testing functions. // TODO(aaron-crl): [tests] write unit tests. func TestDummyInitializeNodeFromBundle(t *testing.T) { @@ -59,6 +192,7 @@ func TestDummyInitializeNodeFromBundle(t *testing.T) { if err != nil { t.Fatalf("failed to create test temp dir: %s", err) } + defer os.RemoveAll(tempDir) certBundle := CertificateBundle{} cfg := base.Config{ @@ -69,12 +203,6 @@ func TestDummyInitializeNodeFromBundle(t *testing.T) { if err != nil { t.Fatalf("expected err=nil, got: %s", err) } - - // Remove temp directory now that we are done with it. - err = os.RemoveAll(tempDir) - if err != nil { - t.Fatalf("failed to remove test temp dir: %s", err) - } } // TestDummyCertLoader is a placeholder for actual testing functions. @@ -87,3 +215,161 @@ func TestDummyCertLoader(t *testing.T) { _ = scb.loadCACertAndKey("", "") _ = scb.rotateServiceCert("", "", time.Minute, "", []string{""}) } + +// TestNodeCertRotation tests that the rotation function will overwrite the +// expected certificates and fail if they are not there. +// TODO(aaron-crl): correct this +func TestRotationOnUnintializedNode(t *testing.T) { + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer os.RemoveAll(tempDir) + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + + // Check the empty case. + // Check to see that the only file in dir is the EOF. + dir, err := os.Open(cfg.SSLCertsDir) + if err != nil { + t.Fatalf("failed to open cfg.SSLCertsDir: %q", cfg.SSLCertsDir) + } + _, err = dir.Readdir(1) + if err != io.EOF { + // Directory is not empty to start with, this is an error. + t.Fatal("files added to cfg.SSLCertsDir when they shouldn't have been") + } + dir.Close() + + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("expected nil error generating no certs, got: %q", err) + } + +} + +func TestRotationOnIntializedNode(t *testing.T) { + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer os.RemoveAll(tempDir) + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + + // Test in the fully provisioned case. + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("rotation failed; expected err=nil, got: %q", err) + } + + // Verify that any existing certs have changed on disk for services + diskBundle, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) + } + compareBundleServiceCerts(t, certBundle, diskBundle, false) +} + +func TestRotationOnPartialIntializedNode(t *testing.T) { + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer os.RemoveAll(tempDir) + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + // Test in the partially provisioned case (remove the Client and UI CAs). + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + if err = os.Remove(cl.ClientCACertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.ClientCAKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.UICACertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.UICAKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + + // This should rotate all service certs except client and UI. + err = rotateGeneratedCerts(cfg) + if err != nil { + t.Fatalf("rotation failed; expected err=nil, got: %q", err) + } + + // Verify that client and UI service host certs are unchanged. + diskBundle, err := loadAllCertsFromDisk(cfg) + if err != nil { + t.Fatalf("failed loading certs from disk, got: %q", err) + } + cmp := certCompareHelper(t, true) + cmp( + certBundle.UserAuth.HostCertificate, + diskBundle.UserAuth.HostCertificate, "UserAuth host cert") + cmp( + certBundle.UserAuth.HostKey, + diskBundle.UserAuth.HostKey, "UserAuth host key") + cmp( + certBundle.AdminUIService.HostCertificate, + diskBundle.AdminUIService.HostCertificate, "AdminUIService host cert") + cmp( + certBundle.AdminUIService.HostKey, + diskBundle.AdminUIService.HostKey, "AdminUIService host key") +} + +// TestRotationOnBrokenIntializedNode in the partially provisioned case (remove the Client and UI CAs). +func TestRotationOnBrokenIntializedNode(t *testing.T) { + // Create a temp dir for all certificate tests. + tempDir, err := ioutil.TempDir("", "auto_tls_init_test") + if err != nil { + t.Fatalf("failed to create test temp dir: %s", err) + } + defer os.RemoveAll(tempDir) + + cfg := base.Config{ + SSLCertsDir: tempDir, + } + cl := security.MakeCertsLocator(cfg.SSLCertsDir) + certBundle := CertificateBundle{} + err = certBundle.InitializeFromConfig(cfg) + if err != nil { + t.Fatalf("expected err=nil, got: %q", err) + } + // Test in the case where a leaf certificate has been removed but a CA is + // still present with key. This should fail. Removing SQL certificate. + if err = os.Remove(cl.SQLServiceCertPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + if err = os.Remove(cl.SQLServiceKeyPath()); err != nil { + t.Fatalf("failed to remove test cert: %q", err) + } + + err = rotateGeneratedCerts(cfg) + if err == nil { + t.Fatalf("rotation succeeded but should have failed with missing leaf certs for SQLService") + } +}