diff --git a/common/util.go b/common/util.go index e8075a3c3b..92ca0a2504 100644 --- a/common/util.go +++ b/common/util.go @@ -213,44 +213,65 @@ func GetTLSConfig(config *apicommon.TLSConfig) (*tls.Config, error) { var caCertPath, clientCertPath, clientKeyPath string var err error - switch { - case config.CACertSecret != nil && config.ClientCertSecret != nil && config.ClientKeySecret != nil: + if config.CACertSecret != nil { caCertPath, err = GetSecretVolumePath(config.CACertSecret) if err != nil { return nil, err } + } else if config.DeprecatedCACertPath != "" { + // DEPRECATED. + caCertPath = config.DeprecatedCACertPath + } + + if config.ClientCertSecret != nil { clientCertPath, err = GetSecretVolumePath(config.ClientCertSecret) if err != nil { return nil, err } + } else if config.DeprecatedClientCertPath != "" { + // DEPRECATED. + clientCertPath = config.DeprecatedClientCertPath + } + + if config.ClientKeySecret != nil { clientKeyPath, err = GetSecretVolumePath(config.ClientKeySecret) if err != nil { return nil, err } - case config.DeprecatedCACertPath != "" && config.DeprecatedClientCertPath != "" && config.DeprecatedClientKeyPath != "": + } else if config.DeprecatedClientKeyPath != "" { // DEPRECATED. - caCertPath = config.DeprecatedCACertPath - clientCertPath = config.DeprecatedClientCertPath clientKeyPath = config.DeprecatedClientKeyPath - default: - return nil, errors.New("invalid tls config, please configure caCertSecret, clientCertSecret and clientKeySecret") } - caCert, err := ioutil.ReadFile(caCertPath) - if err != nil { - return nil, errors.Wrapf(err, "failed to read ca cert file %s", caCertPath) + if len(caCertPath)+len(clientCertPath)+len(clientKeyPath) == 0 { + // None of 3 is configured + return nil, errors.New("invalid tls config, neither of caCertSecret, clientCertSecret and clientKeySecret is configured") } - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(caCert) - clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) - if err != nil { - return nil, errors.Wrapf(err, "failed to load client cert key pair %s", caCertPath) + if len(clientCertPath)+len(clientKeyPath) > 0 && len(clientCertPath)*len(clientKeyPath) == 0 { + // Only one of clientCertSecret and clientKeySecret is configured + return nil, errors.New("invalid tls config, both of clientCertSecret and clientKeySecret need to be configured") + } + + c := &tls.Config{} + if len(caCertPath) > 0 { + caCert, err := ioutil.ReadFile(caCertPath) + if err != nil { + return nil, errors.Wrapf(err, "failed to read ca cert file %s", caCertPath) + } + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(caCert) + c.RootCAs = pool + } + + if len(clientCertPath) > 0 && len(clientKeyPath) > 0 { + clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + return nil, errors.Wrapf(err, "failed to load client cert key pair %s", caCertPath) + } + c.Certificates = []tls.Certificate{clientCert} } - return &tls.Config{ - RootCAs: pool, - Certificates: []tls.Certificate{clientCert}, - }, nil + return c, nil } // VolumesFromSecretsOrConfigMaps builds volumes and volumeMounts spec based on diff --git a/common/util_test.go b/common/util_test.go index 34aa3e1be1..b599525101 100644 --- a/common/util_test.go +++ b/common/util_test.go @@ -18,8 +18,10 @@ package common import ( "net/http" + "strings" "testing" + apicommon "github.com/argoproj/argo-events/pkg/apis/common" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" ) @@ -171,3 +173,69 @@ func TestVolumesFromSecretsOrConfigMaps(t *testing.T) { assert.Equal(t, len(mounts), 6) }) } + +func fakeTLSConfig(t *testing.T) *apicommon.TLSConfig { + t.Helper() + return &apicommon.TLSConfig{ + CACertSecret: &corev1.SecretKeySelector{ + Key: "fake-key1", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "fake-name1", + }, + }, + ClientCertSecret: &corev1.SecretKeySelector{ + Key: "fake-key2", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "fake-name2", + }, + }, + ClientKeySecret: &corev1.SecretKeySelector{ + Key: "fake-key3", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "fake-name3", + }, + }, + } +} + +func TestGetTLSConfig(t *testing.T) { + t.Run("test empty", func(t *testing.T) { + c := &apicommon.TLSConfig{} + _, err := GetTLSConfig(c) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "neither of caCertSecret, clientCertSecret and clientKeySecret is configured")) + }) + + t.Run("test clientKeySecret is set, clientCertSecret is empty", func(t *testing.T) { + c := fakeTLSConfig(t) + c.CACertSecret = nil + c.ClientCertSecret = nil + _, err := GetTLSConfig(c) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "both of clientCertSecret and clientKeySecret need to be configured")) + }) + + t.Run("test only caCertSecret is set", func(t *testing.T) { + c := fakeTLSConfig(t) + c.ClientCertSecret = nil + c.ClientKeySecret = nil + _, err := GetTLSConfig(c) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "failed to read ca cert file")) + }) + + t.Run("test clientCertSecret and clientKeySecret are set", func(t *testing.T) { + c := fakeTLSConfig(t) + c.CACertSecret = nil + _, err := GetTLSConfig(c) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "failed to load client cert key pair")) + }) + + t.Run("test all of 3 are set", func(t *testing.T) { + c := fakeTLSConfig(t) + _, err := GetTLSConfig(c) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "failed to read ca cert file")) + }) +}