diff --git a/pkg/bundle/bundle_test.go b/pkg/bundle/bundle_test.go index b301faff..1000aadd 100644 --- a/pkg/bundle/bundle_test.go +++ b/pkg/bundle/bundle_test.go @@ -41,6 +41,7 @@ import ( trustapi "github.com/cert-manager/trust-manager/pkg/apis/trust/v1alpha1" "github.com/cert-manager/trust-manager/pkg/fspkg" + "github.com/cert-manager/trust-manager/pkg/util" "github.com/cert-manager/trust-manager/test/dummy" "github.com/cert-manager/trust-manager/test/gen" ) @@ -48,7 +49,12 @@ import ( func testEncodeJKS(t *testing.T, data string) []byte { t.Helper() - encoded, err := jksEncoder{password: trustapi.DefaultJKSPassword}.encode(data) + certPool := util.NewCertPool() + if err := certPool.AddCertsFromPEM([]byte(data)); err != nil { + t.Fatal(err) + } + + encoded, err := jksEncoder{password: trustapi.DefaultJKSPassword}.encode(certPool) if err != nil { t.Error(err) } diff --git a/pkg/bundle/source.go b/pkg/bundle/source.go index 6e53e274..3db5130d 100644 --- a/pkg/bundle/source.go +++ b/pkg/bundle/source.go @@ -21,9 +21,7 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/pem" "fmt" - "slices" "strings" jks "github.com/pavlo-v-chernykh/keystore-go/v4" @@ -54,7 +52,7 @@ type bundleData struct { // is each bundle is concatenated together with a new line character. func (b *bundle) buildSourceBundle(ctx context.Context, sources []trustapi.BundleSource, formats *trustapi.AdditionalFormats) (bundleData, error) { var resolvedBundle bundleData - var bundles []string + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(b.FilterExpiredCerts)) for _, source := range sources { var ( @@ -89,27 +87,19 @@ func (b *bundle) buildSourceBundle(ctx context.Context, sources []trustapi.Bundl return bundleData{}, fmt.Errorf("failed to retrieve bundle from source: %w", err) } - opts := util.ValidateAndSanitizeOptions{FilterExpired: b.Options.FilterExpiredCerts} - sanitizedBundle, err := util.ValidateAndSanitizePEMBundleWithOptions([]byte(sourceData), opts) + err = certPool.AddCertsFromPEM([]byte(sourceData)) if err != nil { return bundleData{}, fmt.Errorf("invalid PEM data in source: %w", err) } - bundles = append(bundles, string(sanitizedBundle)) } // NB: empty bundles are not valid so check and return an error if one somehow snuck through. - - if len(bundles) == 0 { + if certPool.Size() == 0 { return bundleData{}, fmt.Errorf("couldn't find any valid certificates in bundle") } - deduplicatedBundles, err := deduplicateAndSortBundles(bundles) - if err != nil { - return bundleData{}, err - } - - if err := resolvedBundle.populateData(deduplicatedBundles, formats); err != nil { + if err := resolvedBundle.populateData(certPool, formats); err != nil { return bundleData{}, err } @@ -217,17 +207,12 @@ type jksEncoder struct { // encodeJKS creates a binary JKS file from the given PEM-encoded trust bundle and password. // Note that the password is not treated securely; JKS files generally seem to expect a password // to exist and so we have the option for one. -func (e jksEncoder) encode(trustBundle string) ([]byte, error) { - cas, err := util.DecodeX509CertificateChainBytes([]byte(trustBundle)) - if err != nil { - return nil, fmt.Errorf("failed to decode trust bundle: %w", err) - } - +func (e jksEncoder) encode(trustBundle *util.CertPool) ([]byte, error) { // WithOrderedAliases ensures that trusted certs are added to the JKS file in order, // which makes the files appear to be reliably deterministic. ks := jks.New(jks.WithOrderedAliases()) - for _, c := range cas { + for _, c := range trustBundle.Certificates() { alias := certAlias(c.Raw, c.Subject.String()) // Note on CreationTime: @@ -239,15 +224,13 @@ func (e jksEncoder) encode(trustBundle string) ([]byte, error) { // - Using a fixed time (i.e. unix epoch) // We use NotBefore here, arbitrarily. - err = ks.SetTrustedCertificateEntry(alias, jks.TrustedCertificateEntry{ + if err := ks.SetTrustedCertificateEntry(alias, jks.TrustedCertificateEntry{ CreationTime: c.NotBefore, Certificate: jks.Certificate{ Type: "X509", Content: c.Raw, }, - }) - - if err != nil { + }); err != nil { // this error should never happen if we set jks.Certificate correctly return nil, fmt.Errorf("failed to add cert with alias %q to trust store: %w", alias, err) } @@ -255,8 +238,7 @@ func (e jksEncoder) encode(trustBundle string) ([]byte, error) { buf := &bytes.Buffer{} - err = ks.Store(buf, []byte(e.password)) - if err != nil { + if err := ks.Store(buf, []byte(e.password)); err != nil { return nil, fmt.Errorf("failed to create JKS file: %w", err) } @@ -285,14 +267,9 @@ type pkcs12Encoder struct { password string } -func (e pkcs12Encoder) encode(trustBundle string) ([]byte, error) { - cas, err := util.DecodeX509CertificateChainBytes([]byte(trustBundle)) - if err != nil { - return nil, fmt.Errorf("failed to decode trust bundle: %w", err) - } - +func (e pkcs12Encoder) encode(trustBundle *util.CertPool) ([]byte, error) { var entries []pkcs12.TrustStoreEntry - for _, c := range cas { + for _, c := range trustBundle.Certificates() { entries = append(entries, pkcs12.TrustStoreEntry{ Cert: c, FriendlyName: certAlias(c.Raw, c.Subject.String()), @@ -308,14 +285,14 @@ func (e pkcs12Encoder) encode(trustBundle string) ([]byte, error) { return encoder.EncodeTrustStoreEntries(entries, e.password) } -func (b *bundleData) populateData(bundles []string, formats *trustapi.AdditionalFormats) error { - b.data = strings.Join(bundles, "\n") + "\n" +func (b *bundleData) populateData(pool *util.CertPool, formats *trustapi.AdditionalFormats) error { + b.data = pool.PEM() if formats != nil { b.binaryData = make(map[string][]byte) if formats.JKS != nil { - encoded, err := jksEncoder{password: *formats.JKS.Password}.encode(b.data) + encoded, err := jksEncoder{password: *formats.JKS.Password}.encode(pool) if err != nil { return fmt.Errorf("failed to encode JKS: %w", err) } @@ -323,7 +300,7 @@ func (b *bundleData) populateData(bundles []string, formats *trustapi.Additional } if formats.PKCS12 != nil { - encoded, err := pkcs12Encoder{password: *formats.PKCS12.Password}.encode(b.data) + encoded, err := pkcs12Encoder{password: *formats.PKCS12.Password}.encode(pool) if err != nil { return fmt.Errorf("failed to encode PKCS12: %w", err) } @@ -332,48 +309,3 @@ func (b *bundleData) populateData(bundles []string, formats *trustapi.Additional } return nil } - -// remove duplicate certificates from bundles and sort certificates by hash -func deduplicateAndSortBundles(bundles []string) ([]string, error) { - var block *pem.Block - - var certificatesHashes = make(map[[32]byte]string) - - for _, cert := range bundles { - certBytes := []byte(cert) - - for { - block, certBytes = pem.Decode(certBytes) - if block == nil { - break - } - - if block.Type != "CERTIFICATE" { - return nil, fmt.Errorf("couldn't decode PEM block containing certificate") - } - - // calculate hash sum of the given certificate - hash := sha256.Sum256(block.Bytes) - // check existence of the hash - if _, ok := certificatesHashes[hash]; !ok { - // neew to trim a newline which is added by Encoder - certificatesHashes[hash] = string(bytes.Trim(pem.EncodeToMemory(block), "\n")) - } - } - } - - var orderedKeys [][32]byte - for key := range certificatesHashes { - orderedKeys = append(orderedKeys, key) - } - slices.SortFunc(orderedKeys, func(a, b [32]byte) int { - return bytes.Compare(a[:], b[:]) - }) - - var sortedDeduplicatedCerts []string - for _, key := range orderedKeys { - sortedDeduplicatedCerts = append(sortedDeduplicatedCerts, certificatesHashes[key]) - } - - return sortedDeduplicatedCerts, nil -} diff --git a/pkg/bundle/source_test.go b/pkg/bundle/source_test.go index 7e6755a0..07057e13 100644 --- a/pkg/bundle/source_test.go +++ b/pkg/bundle/source_test.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "encoding/pem" "errors" + "strings" "testing" jks "github.com/pavlo-v-chernykh/keystore-go/v4" @@ -35,6 +36,7 @@ import ( trustapi "github.com/cert-manager/trust-manager/pkg/apis/trust/v1alpha1" "github.com/cert-manager/trust-manager/pkg/fspkg" + "github.com/cert-manager/trust-manager/pkg/util" "github.com/cert-manager/trust-manager/test/dummy" ) @@ -398,7 +400,12 @@ func Test_encodeJKSAliases(t *testing.T) { // Using different dummy certs would allow this test to pass but wouldn't actually test anything useful! bundle := dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate2) - jksFile, err := jksEncoder{password: trustapi.DefaultJKSPassword}.encode(bundle) + certPool := util.NewCertPool() + if err := certPool.AddCertsFromPEM([]byte(bundle)); err != nil { + t.Fatalf("failed to add certs to pool: %s", err) + } + + jksFile, err := jksEncoder{password: trustapi.DefaultJKSPassword}.encode(certPool) if err != nil { t.Fatalf("didn't expect an error but got: %s", err) } @@ -449,6 +456,7 @@ func TestBundlesDeduplication(t *testing.T) { tests := map[string]struct { name string bundle []string + expError string testBundle []string }{ "single, different cert per source": { @@ -464,6 +472,7 @@ func TestBundlesDeduplication(t *testing.T) { "no certs in sources": { bundle: []string{}, testBundle: nil, + expError: "no non-expired certificates found in input bundle", }, "single cert in the first source, joined certs in the second source": { bundle: []string{ @@ -512,9 +521,17 @@ func TestBundlesDeduplication(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - resultBundle, err := deduplicateAndSortBundles(test.bundle) + certPool := util.NewCertPool() + err := certPool.AddCertsFromPEM([]byte(strings.Join(test.bundle, "\n"))) + if test.expError != "" { + assert.NotNil(t, err) + assert.Equal(t, err.Error(), test.expError) + return + } else { + assert.Nil(t, err) + } - assert.Nil(t, err) + resultBundle := certPool.PEMSplit() // check certificates bundle for duplicated certificates assert.Equal(t, test.testBundle, resultBundle) diff --git a/pkg/fspkg/package.go b/pkg/fspkg/package.go index 253998e7..47fa02fa 100644 --- a/pkg/fspkg/package.go +++ b/pkg/fspkg/package.go @@ -64,7 +64,10 @@ func (p *Package) Clone() *Package { func (p *Package) Validate() error { // Ignore the sanitized bundle here and preserve the bundle as-is. // We'll sanitize later, when building a bundle on a reconcile. - _, err := util.ValidateAndSanitizePEMBundle([]byte(p.Bundle)) + + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(false)) + + err := certPool.AddCertsFromPEM([]byte(p.Bundle)) if err != nil { return fmt.Errorf("package bundle failed validation: %w", err) } diff --git a/pkg/util/cert_pool.go b/pkg/util/cert_pool.go index 5d3394d7..971ca5b2 100644 --- a/pkg/util/cert_pool.go +++ b/pkg/util/cert_pool.go @@ -17,32 +17,72 @@ limitations under the License. package util import ( + "bytes" + "crypto/sha256" "crypto/x509" "encoding/pem" "fmt" + "slices" "time" ) // CertPool is a set of certificates. -type certPool struct { - certificates []*x509.Certificate +type CertPool struct { + certificates map[[32]byte]*x509.Certificate + filterExpired bool } -// newCertPool returns a new, empty CertPool. -func newCertPool(filterExpired bool) *certPool { - return &certPool{ - certificates: make([]*x509.Certificate, 0), - filterExpired: filterExpired, +type Option func(*CertPool) + +func WithFilteredExpiredCerts(filterExpired bool) Option { + return func(cp *CertPool) { + cp.filterExpired = filterExpired + } +} + +// NewCertPool returns a new, empty CertPool. +// It will deduplicate certificates based on their SHA256 hash. +// Optionally, it can filter out expired certificates. +func NewCertPool(options ...Option) *CertPool { + certPool := &CertPool{ + certificates: make(map[[32]byte]*x509.Certificate), + } + + for _, option := range options { + option(certPool) } + + return certPool } -// Append certificate to a pool -func (cp *certPool) appendCertFromPEM(pemData []byte) error { +// AddCertsFromPEM strictly validates a given input PEM bundle to confirm it contains +// only valid CERTIFICATE PEM blocks. If successful, returns the validated PEM blocks with any +// comments or extra data stripped. +// +// This validation is broadly similar to the standard library function +// crypto/x509.CertPool.AppendCertsFromPEM - that is, we decode each PEM block at a time and parse +// it as a certificate. +// +// The difference here is that we want to ensure that the bundle _only_ contains certificates, and +// not just skip over things which aren't certificates. +// +// If, for example, someone accidentally used a combined cert + private key as an input to a trust +// bundle, we wouldn't want to then distribute the private key in the target. +// +// In addition, the standard library AppendCertsFromPEM also silently skips PEM blocks with +// non-empty Headers. We error on such PEM blocks, for the same reason as above; headers could +// contain (accidental) private information. They're also non-standard according to +// https://www.rfc-editor.org/rfc/rfc7468 +// +// Additionally, if the input PEM bundle contains no non-expired certificates, an error is returned. +// TODO: Reconsider what should happen if the input only contains expired certificates. +func (cp *CertPool) AddCertsFromPEM(pemData []byte) error { if pemData == nil { return fmt.Errorf("certificate data can't be nil") } + ok := false for { var block *pem.Block block, pemData = pem.Decode(pemData) @@ -75,19 +115,68 @@ func (cp *certPool) appendCertFromPEM(pemData []byte) error { continue } - cp.certificates = append(cp.certificates, certificate) + ok = true // at least one non-expired certificate was found in the input + + hash := sha256.Sum256(certificate.Raw) + cp.certificates[hash] = certificate + } + + if !ok { + return fmt.Errorf("no non-expired certificates found in input bundle") } return nil } -// Get PEM certificates from pool -func (cp *certPool) getCertsPEM() [][]byte { - var certsData [][]byte = make([][]byte, len(cp.certificates)) +// Get certificates quantity in the certificates pool +func (cp *CertPool) Size() int { + return len(cp.certificates) +} + +func (certPool *CertPool) PEM() string { + if certPool == nil || len(certPool.certificates) == 0 { + return "" + } + + buffer := bytes.Buffer{} + + for _, cert := range certPool.Certificates() { + if err := pem.Encode(&buffer, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}); err != nil { + return "" + } + } + + return string(bytes.TrimSpace(buffer.Bytes())) +} + +func (certPool *CertPool) PEMSplit() []string { + if certPool == nil || len(certPool.certificates) == 0 { + return nil + } + + pems := make([]string, 0, len(certPool.certificates)) + for _, cert := range certPool.Certificates() { + pems = append(pems, string(bytes.TrimSpace(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})))) + } + + return pems +} + +// Get the list of all x509 Certificates in the certificates pool +func (certPool *CertPool) Certificates() []*x509.Certificate { + hashes := make([][32]byte, 0, len(certPool.certificates)) + for hash := range certPool.certificates { + hashes = append(hashes, hash) + } + + slices.SortFunc(hashes, func(i, j [32]byte) int { + return bytes.Compare(i[:], j[:]) + }) - for i, cert := range cp.certificates { - certsData[i] = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + orderedCertificates := make([]*x509.Certificate, 0, len(certPool.certificates)) + for _, hash := range hashes { + orderedCertificates = append(orderedCertificates, certPool.certificates[hash]) } - return certsData + return orderedCertificates } diff --git a/pkg/util/cert_pool_test.go b/pkg/util/cert_pool_test.go index 21e376a1..fca44b73 100644 --- a/pkg/util/cert_pool_test.go +++ b/pkg/util/cert_pool_test.go @@ -20,10 +20,11 @@ import ( "testing" "github.com/cert-manager/trust-manager/test/dummy" + "github.com/stretchr/testify/require" ) func TestNewCertPool(t *testing.T) { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) if certPool == nil { t.Fatal("pool is nil") @@ -35,22 +36,23 @@ func TestAppendCertFromPEM(t *testing.T) { certificateList := [...]struct { certificateName string certificate string + expectError string expectNil bool }{ { - "TestCertificate5", - dummy.TestCertificate5, - false, + certificateName: "TestCertificate5", + certificate: dummy.TestCertificate5, + expectNil: false, }, { - "TestCertificateChain6", - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate2, dummy.TestCertificate3), - false, + certificateName: "TestCertificateChain6", + certificate: dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate2, dummy.TestCertificate3), + expectNil: false, }, { // invalid certificate - "TestCertificateInvalid7", - `-----BEGIN CERTIFICATE----- + certificateName: "TestCertificateInvalid7", + certificate: `-----BEGIN CERTIFICATE----- MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 @@ -60,29 +62,36 @@ func TestAppendCertFromPEM(t *testing.T) { h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW`, - true, + expectError: "no non-expired certificates found in input bundle", + expectNil: true, }, { - "TestCertificateInvalid8", - "qwerty", - true, + certificateName: "TestCertificateInvalid8", + certificate: "qwerty", + expectError: "no non-expired certificates found in input bundle", + expectNil: true, }, { - "TestExpiredCertificate", - dummy.TestExpiredCertificate, - false, + certificateName: "TestExpiredCertificate", + certificate: dummy.TestExpiredCertificate, + expectNil: false, }, } // populate certificates bundle for _, crt := range certificateList { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) - if err := certPool.appendCertFromPEM([]byte(crt.certificate)); err != nil { - t.Fatalf("error adding PEM certificate into pool %s", err) + err := certPool.AddCertsFromPEM([]byte(crt.certificate)) + if crt.expectError != "" { + require.Error(t, err) + require.Equal(t, crt.expectError, err.Error()) + continue + } else { + require.NoError(t, err) } - certPEM := certPool.getCertsPEM() + certPEM := certPool.PEM() if len(certPEM) != 0 == (crt.expectNil) { t.Fatalf("error getting PEM certificates from pool: certificate data is nil") } diff --git a/pkg/util/pem.go b/pkg/util/pem.go deleted file mode 100644 index 4a0fde0a..00000000 --- a/pkg/util/pem.go +++ /dev/null @@ -1,123 +0,0 @@ -/* -Copyright 2022 The cert-manager Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package util - -import ( - "bytes" - "crypto/x509" - "encoding/pem" - "fmt" -) - -// ValidateAndSanitizePEMBundle strictly validates a given input PEM bundle to confirm it contains -// only valid CERTIFICATE PEM blocks. If successful, returns the validated PEM blocks with any -// comments or extra data stripped. - -// This validation is broadly similar to the standard library function -// crypto/x509.CertPool.AppendCertsFromPEM - that is, we decode each PEM block at a time and parse -// it as a certificate. - -// The difference here is that we want to ensure that the bundle _only_ contains certificates, and -// not just skip over things which aren't certificates. - -// If, for example, someone accidentally used a combined cert + private key as an input to a trust -// bundle, we wouldn't want to then distribute the private key in the target. - -// In addition, the standard library AppendCertsFromPEM also silently skips PEM blocks with -// non-empty Headers. We error on such PEM blocks, for the same reason as above; headers could -// contain (accidental) private information. They're also non-standard according to -// https://www.rfc-editor.org/rfc/rfc7468 - -type ValidateAndSanitizeOptions struct { - FilterExpired bool // If true, expired certificates will be filtered out -} - -// ValidateAndSanitizePEMBundle keeps the original function signature for backward compatibility -func ValidateAndSanitizePEMBundle(data []byte) ([]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSanitizePEMBundleWithOptions(data, opts) -} - -// ValidateAndSplitPEMBundle keeps the original function signature for backward compatibility -func ValidateAndSplitPEMBundle(data []byte) ([][]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSplitPEMBundleWithOptions(data, opts) -} - -// See also https://github.com/golang/go/blob/5d5ed57b134b7a02259ff070864f753c9e601a18/src/crypto/x509/cert_pool.go#L201-L239 -// An option to enable filtering of expired certificates is available. -func ValidateAndSanitizePEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([]byte, error) { - certificates, err := ValidateAndSplitPEMBundleWithOptions(data, opts) - if err != nil { - return nil, err - } - - if len(certificates) == 0 { - return nil, fmt.Errorf("bundle contains no PEM certificates") - } - - return bytes.TrimSpace(bytes.Join(certificates, nil)), nil -} - -// ValidateAndSplitPEMBundleWithOptions takes a PEM bundle as input, validates it and -// returns the list of certificates as a slice, allowing them to be iterated over. -// This process involves performs deduplication of certificates to ensure -// no duplicated certificates in the bundle. -// For details of the validation performed, see the comment for ValidateAndSanitizePEMBundle -// An option to enable filtering of expired certificates is available. -func ValidateAndSplitPEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([][]byte, error) { - var certPool *certPool = newCertPool(opts.FilterExpired) // put PEM encoded certificate into a pool - - err := certPool.appendCertFromPEM(data) - if err != nil { - return nil, fmt.Errorf("invalid PEM block in bundle; invalid PEM certificate: %w", err) - } - - return certPool.getCertsPEM(), nil -} - -// DecodeX509CertificateChainBytes will decode a PEM encoded x509 Certificate chain. -func DecodeX509CertificateChainBytes(certBytes []byte) ([]*x509.Certificate, error) { - var certs []*x509.Certificate - - var block *pem.Block - - for { - // decode the tls certificate pem - block, certBytes = pem.Decode(certBytes) - if block == nil { - break - } - - // parse the tls certificate - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, fmt.Errorf("error parsing TLS certificate: %s", err.Error()) - } - certs = append(certs, cert) - } - - if len(certs) == 0 { - return nil, fmt.Errorf("error decoding certificate PEM block") - } - - return certs, nil -} diff --git a/pkg/util/pem_test.go b/pkg/util/pem_test.go index 34fda82d..8732bd74 100644 --- a/pkg/util/pem_test.go +++ b/pkg/util/pem_test.go @@ -18,6 +18,7 @@ package util import ( "bytes" + "crypto/sha256" "crypto/x509" "strings" "testing" @@ -26,7 +27,7 @@ import ( "github.com/cert-manager/trust-manager/test/dummy" ) -func TestValidateAndSanitizePEMBundle(t *testing.T) { +func TestAddCertsFromPEM(t *testing.T) { poisonComment := []byte{0xFF} // strippableComments is a list of things which should not be present in the output strippableText := [][]byte{ @@ -35,10 +36,12 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } cases := map[string]struct { - parts []string - filterExpiredCerts bool - expectExpiredCerts bool - expectErr bool + parts []string + filterDuplicateCerts bool + filterExpiredCerts bool + expectExpiredCerts bool + expectErr bool + expectDuplicatesCerts bool }{ "valid bundle with all types of cert and no comments succeeds": { parts: []string{dummy.TestCertificate1, dummy.TestCertificate2, dummy.TestCertificate3}, @@ -90,40 +93,47 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { filterExpiredCerts: true, expectErr: true, }, + "duplicate certificate should be removed": { + parts: []string{dummy.TestCertificate1, dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate1), dummy.TestCertificate2, dummy.TestCertificate2}, + filterExpiredCerts: true, + expectErr: false, + expectDuplicatesCerts: true, + }, } for name, test := range cases { t.Run(name, func(t *testing.T) { - validateOpts := ValidateAndSanitizeOptions{FilterExpired: test.filterExpiredCerts} + _ = name + certPool := NewCertPool(WithFilteredExpiredCerts(test.filterExpiredCerts)) inputBundle := []byte(strings.Join(test.parts, "\n")) - sanitizedBundleBytes, err := ValidateAndSanitizePEMBundleWithOptions(inputBundle, validateOpts) + err := certPool.AddCertsFromPEM(inputBundle) if test.expectErr != (err != nil) { - t.Fatalf("ValidateAndSanitizePEMBundle: expectErr: %v | err: %v", test.expectErr, err) + t.Fatalf("AddCertsFromPEM: expectErr: %v | err: %v", test.expectErr, err) } if test.expectErr { return } - if sanitizedBundleBytes == nil { - t.Fatalf("got no error from ValidateAndSanitizePEMBundle but sanitizedBundle was nil") + if certPool.Size() == 0 { + t.Fatalf("got no error from AddCertsFromPEM but sanitizedBundle was nil") } for _, strippable := range strippableText { - if bytes.Contains(sanitizedBundleBytes, strippable) { + if bytes.Contains([]byte(certPool.PEM()), strippable) { // can't print the comment since it could be an invalid string t.Errorf("expected sanitizedBundle to not contain a comment but it did") } } - if !utf8.Valid(sanitizedBundleBytes) { + if !utf8.Valid([]byte(certPool.PEM())) { t.Error("expected sanitizedBundle to be valid UTF-8 but it wasn't") } - sanitizedBundle := string(sanitizedBundleBytes) + sanitizedBundle := certPool.PEM() if strings.HasSuffix(sanitizedBundle, "\n") { t.Errorf("expected sanitizedBundle not to end with a newline") @@ -141,7 +151,7 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } } - certs, err := ValidateAndSplitPEMBundleWithOptions(sanitizedBundleBytes, validateOpts) + certs := certPool.Certificates() if err != nil { t.Errorf("failed to split already-validated bundle: %s", err) return @@ -150,28 +160,24 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { var expiredCerts []*x509.Certificate for _, cert := range certs { - parsedCerts, err := DecodeX509CertificateChainBytes(cert) - if err != nil { - t.Errorf("failed to decode split PEM cert: %s", err) - continue - } - - if len(parsedCerts) != 1 { - // shouldn't ever happen since we're decoding a single PEM cert - t.Errorf("got more than one parsed cert after splitting a PEM bundle") - continue - } - - parsedCert := parsedCerts[0] - - if parsedCert.NotAfter.Before(dummy.DummyInstant()) { - expiredCerts = append(expiredCerts, parsedCert) + if cert.NotAfter.Before(dummy.DummyInstant()) { + expiredCerts = append(expiredCerts, cert) } } if test.expectExpiredCerts != (len(expiredCerts) > 0) { t.Errorf("expectExpiredCerts=%v but got %d expired certs", test.expectExpiredCerts, len(expiredCerts)) } + + if test.expectDuplicatesCerts { + var hashes = make(map[[32]byte]struct{}) + for _, cert := range certs { + hash := sha256.Sum256(cert.Raw) + if _, ok := hashes[hash]; ok { + t.Errorf("expectDuplicatesCerts=%v but got duplicate certs", test.expectDuplicatesCerts) + } + } + } }) } } diff --git a/test/dummy/certificates.go b/test/dummy/certificates.go index bff48af1..7cd6d01b 100644 --- a/test/dummy/certificates.go +++ b/test/dummy/certificates.go @@ -489,5 +489,5 @@ func DefaultJoinedCerts() string { } func JoinCerts(certs ...string) string { - return strings.Join(certs, "\n") + "\n" + return strings.Join(certs, "\n") } diff --git a/test/env/data.go b/test/env/data.go index 46f669cd..afe0590c 100644 --- a/test/env/data.go +++ b/test/env/data.go @@ -346,6 +346,7 @@ func EventuallyBundleHasSyncedAllNamespacesContains(ctx context.Context, cl clie // CheckJKSFileSynced ensures that the given JKS data func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEMData string) error { reader := bytes.NewReader(jksData) + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(false)) ks := jks.New() @@ -354,7 +355,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM return err } - expectedCertList, err := util.ValidateAndSplitPEMBundle([]byte(expectedCertPEMData)) + err = certPool.AddCertsFromPEM([]byte(expectedCertPEMData)) if err != nil { return fmt.Errorf("invalid PEM data passed to CheckJKSFileSynced: %s", err) } @@ -363,7 +364,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM // that the count is the same aliasCount := len(ks.Aliases()) - expectedPEMCount := len(expectedCertList) + expectedPEMCount := certPool.Size() if aliasCount != expectedPEMCount { return fmt.Errorf("expected %d certificates in JKS but found %d", expectedPEMCount, aliasCount)