diff --git a/pkg/bundle/source.go b/pkg/bundle/source.go index f6ea3a8f..1de112ca 100644 --- a/pkg/bundle/source.go +++ b/pkg/bundle/source.go @@ -21,7 +21,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/pem" "fmt" "strings" @@ -64,7 +63,7 @@ type bundleData struct { // is each bundle is concatenated together with a new line character. func (b *bundle) buildSourceBundle(ctx context.Context, sources []trustapi.BundleSource, formats *trustapi.AdditionalFormats) (bundleData, error) { var resolvedBundle bundleData - var bundles []string + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(b.FilterExpiredCerts)) for _, source := range sources { var ( @@ -99,27 +98,19 @@ func (b *bundle) buildSourceBundle(ctx context.Context, sources []trustapi.Bundl return bundleData{}, fmt.Errorf("failed to retrieve bundle from source: %w", err) } - opts := util.ValidateAndSanitizeOptions{FilterExpired: b.Options.FilterExpiredCerts} - sanitizedBundle, err := util.ValidateAndSanitizePEMBundleWithOptions([]byte(sourceData), opts) + err = util.ValidateAndSplitPEMBundle(certPool, []byte(sourceData)) if err != nil { return bundleData{}, fmt.Errorf("invalid PEM data in source: %w", err) } - bundles = append(bundles, string(sanitizedBundle)) } // NB: empty bundles are not valid so check and return an error if one somehow snuck through. - - if len(bundles) == 0 { + if util.GetCertificatesQuantity(certPool) == 0 { return bundleData{}, fmt.Errorf("couldn't find any valid certificates in bundle") } - deduplicatedBundles, err := deduplicateBundles(bundles) - if err != nil { - return bundleData{}, err - } - - if err := resolvedBundle.populateData(deduplicatedBundles, formats); err != nil { + if err := resolvedBundle.populateData(util.AsPEMBundleStrings(certPool), formats); err != nil { return bundleData{}, err } @@ -342,39 +333,3 @@ func (b *bundleData) populateData(bundles []string, formats *trustapi.Additional } return nil } - -// remove duplicate certificates from bundles -func deduplicateBundles(bundles []string) ([]string, error) { - var block *pem.Block - - var certificatesHashes = make(map[[32]byte]struct{}) - var dedupCerts []string - - for _, cert := range bundles { - certBytes := []byte(cert) - - LOOP: - for { - block, certBytes = pem.Decode(certBytes) - if block == nil { - break LOOP - } - - if block.Type != "CERTIFICATE" { - return nil, fmt.Errorf("couldn't decode PEM block containing certificate") - } - - // calculate hash sum of the given certificate - hash := sha256.Sum256(block.Bytes) - // check existence of the hash - if _, ok := certificatesHashes[hash]; !ok { - // neew to trim a newline which is added by Encoder - dedupCerts = append(dedupCerts, string(bytes.Trim(pem.EncodeToMemory(block), "\n"))) - certificatesHashes[hash] = struct{}{} - } - } - - } - - return dedupCerts, nil -} diff --git a/pkg/bundle/source_test.go b/pkg/bundle/source_test.go index 5a91ad29..3092f262 100644 --- a/pkg/bundle/source_test.go +++ b/pkg/bundle/source_test.go @@ -431,80 +431,3 @@ func Test_certAlias(t *testing.T) { t.Fatalf("expected alias to be %q but got %q", expectedAlias, alias) } } - -func TestBundlesDeduplication(t *testing.T) { - tests := map[string]struct { - name string - bundle []string - testBundle []string - }{ - "single, different cert per source": { - bundle: []string{ - dummy.TestCertificate1, - dummy.TestCertificate2, - }, - testBundle: []string{ - dummy.TestCertificate1, - dummy.TestCertificate2, - }, - }, - "no certs in sources": { - bundle: []string{}, - testBundle: []string{}, - }, - "single cert in the first source, joined certs in the second source": { - bundle: []string{ - dummy.TestCertificate1, - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate3), - }, - testBundle: []string{ - dummy.TestCertificate1, - dummy.TestCertificate3, - }, - }, - "joined certs in the first source, single cert in the second source": { - bundle: []string{ - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate3), - dummy.TestCertificate1, - }, - testBundle: []string{ - dummy.TestCertificate3, - dummy.TestCertificate1, - }, - }, - "joined, different certs in the first source; joined,different certs in the second source": { - bundle: []string{ - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate2), - dummy.JoinCerts(dummy.TestCertificate4, dummy.TestCertificate5), - }, - testBundle: []string{ - dummy.TestCertificate1, - dummy.TestCertificate2, - dummy.TestCertificate4, - dummy.TestCertificate5, - }, - }, - "all certs are joined ones and equal ones in all sources": { - bundle: []string{ - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate1), - dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate1), - }, - testBundle: []string{ - dummy.TestCertificate1, - }, - }, - } - for name, test := range tests { - test := test - t.Run(name, func(t *testing.T) { - t.Parallel() - - resultBundle, err := deduplicateBundles(test.bundle) - - assert.Nil(t, err) - - // check certificates bundle for duplicated certificates - assert.ElementsMatch(t, test.testBundle, resultBundle) - }) - } -} diff --git a/pkg/fspkg/package.go b/pkg/fspkg/package.go index 253998e7..3699b5aa 100644 --- a/pkg/fspkg/package.go +++ b/pkg/fspkg/package.go @@ -64,7 +64,10 @@ func (p *Package) Clone() *Package { func (p *Package) Validate() error { // Ignore the sanitized bundle here and preserve the bundle as-is. // We'll sanitize later, when building a bundle on a reconcile. - _, err := util.ValidateAndSanitizePEMBundle([]byte(p.Bundle)) + + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(false)) + + err := util.ValidateAndSplitPEMBundle(certPool, []byte(p.Bundle)) if err != nil { return fmt.Errorf("package bundle failed validation: %w", err) } diff --git a/pkg/util/cert_pool.go b/pkg/util/cert_pool.go index 5d3394d7..ec298b98 100644 --- a/pkg/util/cert_pool.go +++ b/pkg/util/cert_pool.go @@ -17,6 +17,7 @@ limitations under the License. package util import ( + "crypto/sha256" "crypto/x509" "encoding/pem" "fmt" @@ -24,21 +25,36 @@ import ( ) // CertPool is a set of certificates. -type certPool struct { - certificates []*x509.Certificate - filterExpired bool +type CertPool struct { + certificatesHashes map[[32]byte]struct{} + certificates []*x509.Certificate + filterExpired bool +} + +type Option func(*CertPool) + +func WithFilteredExpiredCerts(filterExpired bool) Option { + return func(cp *CertPool) { + cp.filterExpired = filterExpired + } } // newCertPool returns a new, empty CertPool. -func newCertPool(filterExpired bool) *certPool { - return &certPool{ - certificates: make([]*x509.Certificate, 0), - filterExpired: filterExpired, +func NewCertPool(options ...Option) *CertPool { + certPool := &CertPool{ + certificates: make([]*x509.Certificate, 0), + certificatesHashes: make(map[[32]byte]struct{}), } + + for _, option := range options { + option(certPool) + } + + return certPool } // Append certificate to a pool -func (cp *certPool) appendCertFromPEM(pemData []byte) error { +func (cp *CertPool) appendCertFromPEM(pemData []byte) error { if pemData == nil { return fmt.Errorf("certificate data can't be nil") } @@ -75,6 +91,10 @@ func (cp *certPool) appendCertFromPEM(pemData []byte) error { continue } + if cp.isDuplicate(certificate) { + continue + } + cp.certificates = append(cp.certificates, certificate) } @@ -82,7 +102,7 @@ func (cp *certPool) appendCertFromPEM(pemData []byte) error { } // Get PEM certificates from pool -func (cp *certPool) getCertsPEM() [][]byte { +func (cp *CertPool) getCertsPEM() [][]byte { var certsData [][]byte = make([][]byte, len(cp.certificates)) for i, cert := range cp.certificates { @@ -91,3 +111,25 @@ func (cp *certPool) getCertsPEM() [][]byte { return certsData } + +// Get certificates quantity in the certificates pool +func (cp *CertPool) size() int { + return len(cp.certificates) +} + +// Check deplicates of certificate in the certificates pool +func (cp *CertPool) isDuplicate(cert *x509.Certificate) bool { + hash := sha256.Sum256(cert.Raw) + // check existence of the hash + if _, ok := cp.certificatesHashes[hash]; !ok { + cp.certificatesHashes[hash] = struct{}{} + return false + } + + return true +} + +// Get the full list of x509 Certificates from the certificates pool +func (cp *CertPool) getCertsList() []*x509.Certificate { + return cp.certificates +} diff --git a/pkg/util/cert_pool_test.go b/pkg/util/cert_pool_test.go index 21e376a1..c77eb3aa 100644 --- a/pkg/util/cert_pool_test.go +++ b/pkg/util/cert_pool_test.go @@ -23,7 +23,7 @@ import ( ) func TestNewCertPool(t *testing.T) { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) if certPool == nil { t.Fatal("pool is nil") @@ -76,7 +76,7 @@ func TestAppendCertFromPEM(t *testing.T) { // populate certificates bundle for _, crt := range certificateList { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) if err := certPool.appendCertFromPEM([]byte(crt.certificate)); err != nil { t.Fatalf("error adding PEM certificate into pool %s", err) diff --git a/pkg/util/pem.go b/pkg/util/pem.go index 4a0fde0a..95558489 100644 --- a/pkg/util/pem.go +++ b/pkg/util/pem.go @@ -21,6 +21,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "strings" ) // ValidateAndSanitizePEMBundle strictly validates a given input PEM bundle to confirm it contains @@ -42,56 +43,19 @@ import ( // contain (accidental) private information. They're also non-standard according to // https://www.rfc-editor.org/rfc/rfc7468 -type ValidateAndSanitizeOptions struct { - FilterExpired bool // If true, expired certificates will be filtered out -} - -// ValidateAndSanitizePEMBundle keeps the original function signature for backward compatibility -func ValidateAndSanitizePEMBundle(data []byte) ([]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSanitizePEMBundleWithOptions(data, opts) -} - -// ValidateAndSplitPEMBundle keeps the original function signature for backward compatibility -func ValidateAndSplitPEMBundle(data []byte) ([][]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSplitPEMBundleWithOptions(data, opts) -} - // See also https://github.com/golang/go/blob/5d5ed57b134b7a02259ff070864f753c9e601a18/src/crypto/x509/cert_pool.go#L201-L239 // An option to enable filtering of expired certificates is available. -func ValidateAndSanitizePEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([]byte, error) { - certificates, err := ValidateAndSplitPEMBundleWithOptions(data, opts) +func ValidateAndSplitPEMBundle(certPool *CertPool, data []byte) error { + err := certPool.appendCertFromPEM(data) if err != nil { - return nil, err - } - - if len(certificates) == 0 { - return nil, fmt.Errorf("bundle contains no PEM certificates") + return err } - return bytes.TrimSpace(bytes.Join(certificates, nil)), nil -} - -// ValidateAndSplitPEMBundleWithOptions takes a PEM bundle as input, validates it and -// returns the list of certificates as a slice, allowing them to be iterated over. -// This process involves performs deduplication of certificates to ensure -// no duplicated certificates in the bundle. -// For details of the validation performed, see the comment for ValidateAndSanitizePEMBundle -// An option to enable filtering of expired certificates is available. -func ValidateAndSplitPEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([][]byte, error) { - var certPool *certPool = newCertPool(opts.FilterExpired) // put PEM encoded certificate into a pool - - err := certPool.appendCertFromPEM(data) - if err != nil { - return nil, fmt.Errorf("invalid PEM block in bundle; invalid PEM certificate: %w", err) + if certPool.size() == 0 { + return fmt.Errorf("bundle contains no PEM certificates") } - return certPool.getCertsPEM(), nil + return nil } // DecodeX509CertificateChainBytes will decode a PEM encoded x509 Certificate chain. @@ -121,3 +85,33 @@ func DecodeX509CertificateChainBytes(certBytes []byte) ([]*x509.Certificate, err return certs, nil } + +// Get the split bundle of all certificates in the certificates pool as representation of [][]byte +func AsSplitPEMBundle(certPool *CertPool) [][]byte { + return certPool.getCertsPEM() +} + +// Get the split bundle of all certificates in the certificates pool as representation of []byte +func AsPEMBundleBytes(certPool *CertPool) []byte { + return bytes.TrimSpace(bytes.Join(certPool.getCertsPEM(), nil)) +} + +// Get the split bundle of all certificates in the certificates pool as representation of []string +func AsPEMBundleStrings(certPool *CertPool) []string { + var certList = make([]string, 0) + + for _, cert := range certPool.getCertsPEM() { + certList = append(certList, strings.TrimSpace(string(cert))) + } + + return certList +} + +// Get the list of all x509 Certificates in the certificates pool +func AsCertificateList(certPool *CertPool) []*x509.Certificate { + return certPool.getCertsList() +} + +func GetCertificatesQuantity(certPool *CertPool) int { + return certPool.size() +} diff --git a/pkg/util/pem_test.go b/pkg/util/pem_test.go index 34fda82d..c8fafb26 100644 --- a/pkg/util/pem_test.go +++ b/pkg/util/pem_test.go @@ -18,6 +18,7 @@ package util import ( "bytes" + "crypto/sha256" "crypto/x509" "strings" "testing" @@ -35,10 +36,12 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } cases := map[string]struct { - parts []string - filterExpiredCerts bool - expectExpiredCerts bool - expectErr bool + parts []string + filterDuplicateCerts bool + filterExpiredCerts bool + expectExpiredCerts bool + expectErr bool + expectDuplicatesCerts bool }{ "valid bundle with all types of cert and no comments succeeds": { parts: []string{dummy.TestCertificate1, dummy.TestCertificate2, dummy.TestCertificate3}, @@ -90,15 +93,22 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { filterExpiredCerts: true, expectErr: true, }, + "duplicate certificate should be removed": { + parts: []string{dummy.TestCertificate1, dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate1), dummy.TestCertificate2, dummy.TestCertificate2}, + filterExpiredCerts: true, + expectErr: false, + expectDuplicatesCerts: true, + }, } for name, test := range cases { t.Run(name, func(t *testing.T) { - validateOpts := ValidateAndSanitizeOptions{FilterExpired: test.filterExpiredCerts} + _ = name + var certPool = NewCertPool(WithFilteredExpiredCerts(test.filterExpiredCerts)) inputBundle := []byte(strings.Join(test.parts, "\n")) - sanitizedBundleBytes, err := ValidateAndSanitizePEMBundleWithOptions(inputBundle, validateOpts) + err := ValidateAndSplitPEMBundle(certPool, inputBundle) if test.expectErr != (err != nil) { t.Fatalf("ValidateAndSanitizePEMBundle: expectErr: %v | err: %v", test.expectErr, err) @@ -108,22 +118,22 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { return } - if sanitizedBundleBytes == nil { + if GetCertificatesQuantity(certPool) == 0 { t.Fatalf("got no error from ValidateAndSanitizePEMBundle but sanitizedBundle was nil") } for _, strippable := range strippableText { - if bytes.Contains(sanitizedBundleBytes, strippable) { + if bytes.Contains(AsPEMBundleBytes(certPool), strippable) { // can't print the comment since it could be an invalid string t.Errorf("expected sanitizedBundle to not contain a comment but it did") } } - if !utf8.Valid(sanitizedBundleBytes) { + if !utf8.Valid(AsPEMBundleBytes(certPool)) { t.Error("expected sanitizedBundle to be valid UTF-8 but it wasn't") } - sanitizedBundle := string(sanitizedBundleBytes) + sanitizedBundle := string(AsPEMBundleBytes(certPool)) if strings.HasSuffix(sanitizedBundle, "\n") { t.Errorf("expected sanitizedBundle not to end with a newline") @@ -141,7 +151,7 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } } - certs, err := ValidateAndSplitPEMBundleWithOptions(sanitizedBundleBytes, validateOpts) + certs := AsCertificateList(certPool) if err != nil { t.Errorf("failed to split already-validated bundle: %s", err) return @@ -150,28 +160,24 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { var expiredCerts []*x509.Certificate for _, cert := range certs { - parsedCerts, err := DecodeX509CertificateChainBytes(cert) - if err != nil { - t.Errorf("failed to decode split PEM cert: %s", err) - continue - } - - if len(parsedCerts) != 1 { - // shouldn't ever happen since we're decoding a single PEM cert - t.Errorf("got more than one parsed cert after splitting a PEM bundle") - continue - } - - parsedCert := parsedCerts[0] - - if parsedCert.NotAfter.Before(dummy.DummyInstant()) { - expiredCerts = append(expiredCerts, parsedCert) + if cert.NotAfter.Before(dummy.DummyInstant()) { + expiredCerts = append(expiredCerts, cert) } } if test.expectExpiredCerts != (len(expiredCerts) > 0) { t.Errorf("expectExpiredCerts=%v but got %d expired certs", test.expectExpiredCerts, len(expiredCerts)) } + + if test.expectDuplicatesCerts { + var hashes = make(map[[32]byte]struct{}) + for _, cert := range certs { + hash := sha256.Sum256(cert.Raw) + if _, ok := hashes[hash]; ok { + t.Errorf("expectDuplicatesCerts=%v but got duplicate certs", test.expectDuplicatesCerts) + } + } + } }) } } diff --git a/test/env/data.go b/test/env/data.go index c83fe5a4..3dfa3252 100644 --- a/test/env/data.go +++ b/test/env/data.go @@ -219,9 +219,9 @@ func CheckBundleSyncedStartsWith(ctx context.Context, cl client.Client, name str remaining := strings.TrimPrefix(got, startingData) - // check that there are a nonzero number of valid certs remaining + certPool := util.NewCertPool(util.WithFilteredExpiredCerts(false)) - _, err := util.ValidateAndSanitizePEMBundle([]byte(remaining)) + err := util.ValidateAndSplitPEMBundle(certPool, []byte(remaining)) if err != nil { return fmt.Errorf("received data didn't have any valid certs after valid starting data: %w", err) } @@ -317,6 +317,7 @@ func EventuallyBundleHasSyncedAllNamespacesStartsWith(ctx context.Context, cl cl // CheckJKSFileSynced ensures that the given JKS data func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEMData string) error { reader := bytes.NewReader(jksData) + var certPool = util.NewCertPool(util.WithFilteredExpiredCerts(false)) ks := jks.New() @@ -325,7 +326,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM return err } - expectedCertList, err := util.ValidateAndSplitPEMBundle([]byte(expectedCertPEMData)) + err = util.ValidateAndSplitPEMBundle(certPool, []byte(expectedCertPEMData)) if err != nil { return fmt.Errorf("invalid PEM data passed to CheckJKSFileSynced: %s", err) } @@ -334,7 +335,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM // that the count is the same aliasCount := len(ks.Aliases()) - expectedPEMCount := len(expectedCertList) + expectedPEMCount := util.GetCertificatesQuantity(certPool) if aliasCount != expectedPEMCount { return fmt.Errorf("expected %d certificates in JKS but found %d", expectedPEMCount, aliasCount)