From 90da9087cc5507d47857562d36767220f52594f7 Mon Sep 17 00:00:00 2001 From: Ruben Vargas Date: Wed, 24 Jul 2024 10:17:24 -0600 Subject: [PATCH] Use GetClientCertificate to allow client cert to be reloaded (#537) * Use GetClientCertificate to allow client cert to be reloaded --------- Signed-off-by: Ruben Vargas Co-authored-by: Arve Knudsen --- CHANGELOG.md | 1 + crypto/tls/test/tls_integration_test.go | 6 ++++- crypto/tls/tls.go | 31 +++++++++++++++++-------- crypto/tls/tls_test.go | 5 +++- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b8865269..f21aaf6fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -218,6 +218,7 @@ * [ENHANCEMENT] memberlist: use separate queue for broadcast messages that are result of local updates, and prioritize locally-generated messages when sending broadcasts. On stopping, only wait for queue with locally-generated messages to be empty. #539 * [ENHANCEMENT] memberlist: Added `-memberlist.broadcast-timeout-for-local-updates-on-shutdown` option to set timeout for sending locally-generated updates on shutdown, instead of previously hardcoded 10s (which is still the default). #539 * [ENHANCEMENT] tracing: add ExtractTraceSpanID function. +* [EHNANCEMENT] crypto/tls: Support reloading client certificates #537 * [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 diff --git a/crypto/tls/test/tls_integration_test.go b/crypto/tls/test/tls_integration_test.go index 57b08d8b2..9a8730b1a 100644 --- a/crypto/tls/test/tls_integration_test.go +++ b/crypto/tls/test/tls_integration_test.go @@ -32,6 +32,8 @@ import ( "github.com/grafana/dskit/crypto/tls" ) +const mismatchCAAndCerts = "remote error: tls: unknown certificate authority" + type tcIntegrationClientServer struct { name string tlsGrpcEnabled bool @@ -363,6 +365,8 @@ func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA // bad certificate from the server side and just see connection // closed/reset instead badCertErr := errorContainsString(badCertificateErrorMessage) + mismatchCAAndCertsErr := errorContainsString(mismatchCAAndCerts) + newIntegrationClientServer( t, cfg, @@ -411,7 +415,7 @@ func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA CertPath: certs.client2CertFile, KeyPath: certs.client2KeyFile, }, - httpExpectError: badCertErr, + httpExpectError: mismatchCAAndCertsErr, grpcExpectError: unavailableDescErr, }, }, diff --git a/crypto/tls/tls.go b/crypto/tls/tls.go index 7ed818f39..a5b3805b7 100644 --- a/crypto/tls/tls.go +++ b/crypto/tls/tls.go @@ -109,15 +109,7 @@ func (cfg *ClientConfig) GetTLSConfig() (*tls.Config, error) { config.RootCAs = caCertPool } - // Read Client Certificate - if cfg.CertPath != "" || cfg.KeyPath != "" { - if cfg.CertPath == "" { - return nil, errCertMissing - } - if cfg.KeyPath == "" { - return nil, errKeyMissing - } - + loadCert := func() (*tls.Certificate, error) { cert, err := reader.ReadSecret(cfg.CertPath) if err != nil { return nil, errors.Wrapf(err, "error loading client cert: %s", cfg.CertPath) @@ -131,7 +123,26 @@ func (cfg *ClientConfig) GetTLSConfig() (*tls.Config, error) { if err != nil { return nil, errors.Wrapf(err, "failed to load TLS certificate %s,%s", cfg.CertPath, cfg.KeyPath) } - config.Certificates = []tls.Certificate{clientCert} + return &clientCert, nil + + } + + // Read Client Certificate + if cfg.CertPath != "" || cfg.KeyPath != "" { + if cfg.CertPath == "" { + return nil, errCertMissing + } + if cfg.KeyPath == "" { + return nil, errKeyMissing + } + // Confirm that certificate and key paths are valid. + if _, err := loadCert(); err != nil { + return nil, err + } + + config.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return loadCert() + } } if cfg.MinVersion != "" { diff --git a/crypto/tls/tls_test.go b/crypto/tls/tls_test.go index 292307ddf..8d2a8c185 100644 --- a/crypto/tls/tls_test.go +++ b/crypto/tls/tls_test.go @@ -109,7 +109,10 @@ func TestGetTLSConfig_ClientCerts(t *testing.T) { tlsConfig, err := c.GetTLSConfig() assert.NoError(t, err) assert.Equal(t, false, tlsConfig.InsecureSkipVerify, "make sure we default to not skip verification") - assert.Equal(t, 1, len(tlsConfig.Certificates), "ensure a certificate is returned") + require.NotNil(t, tlsConfig.GetClientCertificate, "ensure GetClientCertificate is set") + cert, err := tlsConfig.GetClientCertificate(nil) + require.NoError(t, err) + assert.NotNil(t, cert, "ensure GetClientCertificate returns a certificate") // expect error with key and cert swapped passed along c = &ClientConfig{