diff --git a/skipper.go b/skipper.go index a083a1fe22..ebecb8d243 100644 --- a/skipper.go +++ b/skipper.go @@ -944,8 +944,35 @@ func initLog(o Options) error { return nil } -func (o *Options) isHTTPS() bool { - return (o.ProxyTLS != nil) || (o.CertPathTLS != "" && o.KeyPathTLS != "") +func (o *Options) tlsConfig() (*tls.Config, error) { + if o.ProxyTLS != nil { + return o.ProxyTLS, nil + } + + if o.CertPathTLS == "" && o.KeyPathTLS == "" { + return nil, nil + } + + crts := strings.Split(o.CertPathTLS, ",") + keys := strings.Split(o.KeyPathTLS, ",") + + if len(crts) != len(keys) { + return nil, fmt.Errorf("number of certificates does not match number of keys") + } + + config := &tls.Config{ + MinVersion: o.TLSMinVersion, + } + + for i := 0; i < len(crts); i++ { + crt, key := crts[i], keys[i] + keypair, err := tls.LoadX509KeyPair(crt, key) + if err != nil { + return nil, fmt.Errorf("Failed to load X509 keypair from %s and %s: %w", crt, key, err) + } + config.Certificates = append(config.Certificates, keypair) + } + return config, nil } func listen(o *Options, mtr metrics.Metrics) (net.Listener, error) { @@ -1005,11 +1032,14 @@ func listenAndServeQuit( idleConnsCH chan struct{}, mtr metrics.Metrics, ) error { - // create the access log handler - log.Infof("proxy listener on %v", o.Address) + tlsConfig, err := o.tlsConfig() + if err != nil { + return err + } srv := &http.Server{ Addr: o.Address, + TLSConfig: tlsConfig, Handler: proxy, ReadTimeout: o.ReadTimeoutServer, ReadHeaderTimeout: o.ReadHeaderTimeoutServer, @@ -1025,35 +1055,6 @@ func listenAndServeQuit( } } - if o.isHTTPS() { - if o.ProxyTLS != nil { - srv.TLSConfig = o.ProxyTLS - o.CertPathTLS = "" - o.KeyPathTLS = "" - } else if strings.Index(o.CertPathTLS, ",") > 0 && strings.Index(o.KeyPathTLS, ",") > 0 { - tlsCfg := &tls.Config{ - MinVersion: o.TLSMinVersion, - } - crts := strings.Split(o.CertPathTLS, ",") - keys := strings.Split(o.KeyPathTLS, ",") - if len(crts) != len(keys) { - log.Fatalf("number of certs does not match number of keys") - } - for i, crt := range crts { - kp, err := tls.LoadX509KeyPair(crt, keys[i]) - if err != nil { - log.Fatalf("Failed to load X509 keypair from %s/%s: %v", crt, keys[i], err) - } - tlsCfg.Certificates = append(tlsCfg.Certificates, kp) - } - o.CertPathTLS = "" - o.KeyPathTLS = "" - srv.TLSConfig = tlsCfg - } - return srv.ListenAndServeTLS(o.CertPathTLS, o.KeyPathTLS) - } - log.Infof("TLS settings not found, defaulting to HTTP") - // making idleConnsCH and sigs optional parameters is required to be able to tear down a server // from the tests if idleConnsCH == nil { @@ -1079,14 +1080,25 @@ func listenAndServeQuit( close(idleConnsCH) }() - l, err := listen(o, mtr) - if err != nil { - return err - } + log.Infof("proxy listener on %v", o.Address) - if err := srv.Serve(l); err != nil && err != http.ErrServerClosed { - log.Errorf("Failed to start to ListenAndServe: %v", err) - return err + if srv.TLSConfig != nil { + if err := srv.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + log.Errorf("ListenAndServeTLS failed: %v", err) + return err + } + } else { + log.Infof("TLS settings not found, defaulting to HTTP") + + l, err := listen(o, mtr) + if err != nil { + return err + } + + if err := srv.Serve(l); err != http.ErrServerClosed { + log.Errorf("Serve failed: %v", err) + return err + } } <-idleConnsCH diff --git a/skipper_test.go b/skipper_test.go index a815e7c861..3574b68137 100644 --- a/skipper_test.go +++ b/skipper_test.go @@ -20,6 +20,8 @@ import ( "github.com/zalando/skipper/proxy" "github.com/zalando/skipper/ratelimit" "github.com/zalando/skipper/routing" + + "github.com/stretchr/testify/require" ) const ( @@ -73,66 +75,52 @@ func findAddress() (string, error) { return l.Addr().String(), nil } -func TestOptionsDefaultsToHTTP(t *testing.T) { - o := Options{} - if o.isHTTPS() { - t.FailNow() - } -} - -func TestOptionsWithCertUsesHTTPS(t *testing.T) { - o := Options{CertPathTLS: "foo", KeyPathTLS: "bar"} - if !o.isHTTPS() { - t.FailNow() - } +func TestOptionsTLSConfig(t *testing.T) { + cert, err := tls.LoadX509KeyPair("fixtures/test.crt", "fixtures/test.key") + require.NoError(t, err) + + // empty + o := &Options{} + c, err := o.tlsConfig() + require.NoError(t, err) + require.Nil(t, c) + + // proxy prio + o = &Options{ProxyTLS: &tls.Config{}, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, &tls.Config{}, c) + + // cert key path + o = &Options{TLSMinVersion: tls.VersionTLS12, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, uint16(tls.VersionTLS12), c.MinVersion) + require.Equal(t, []tls.Certificate{cert}, c.Certificates) + + // multiple cert key paths + o = &Options{TLSMinVersion: tls.VersionTLS13, CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key,fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, uint16(tls.VersionTLS13), c.MinVersion) + require.Equal(t, []tls.Certificate{cert, cert}, c.Certificates) } -func TestWithWrongCertPathFails(t *testing.T) { - a, err := findAddress() - if err != nil { - t.Fatal(err) - } - - o := Options{Address: a, - CertPathTLS: "fixtures/notFound.crt", - KeyPathTLS: "fixtures/test.key", - } - - rt := routing.New(routing.Options{ - FilterRegistry: builtin.MakeRegistry(), - DataClients: []routing.DataClient{}}) - defer rt.Close() - - proxy := proxy.New(rt, proxy.OptionsNone) - defer proxy.Close() - - err = listenAndServe(proxy, &o) - if err == nil { - t.Fatal(err) - } -} - -func TestWithWrongKeyPathFails(t *testing.T) { - a, err := findAddress() - if err != nil { - t.Fatal(err) - } - - o := Options{Address: a, - CertPathTLS: "fixtures/test.crt", - KeyPathTLS: "fixtures/notFound.key", - } - - rt := routing.New(routing.Options{ - FilterRegistry: builtin.MakeRegistry(), - DataClients: []routing.DataClient{}}) - defer rt.Close() - - proxy := proxy.New(rt, proxy.OptionsNone) - defer proxy.Close() - err = listenAndServe(proxy, &o) - if err == nil { - t.Fatal(err) +func TestOptionsTLSConfigInvalidPaths(t *testing.T) { + for _, tt := range []struct { + name string + options *Options + }{ + {"missing cert path", &Options{KeyPathTLS: "fixtures/test.key"}}, + {"missing key path", &Options{CertPathTLS: "fixtures/test.crt"}}, + {"wrong cert path", &Options{CertPathTLS: "fixtures/notFound.crt", KeyPathTLS: "fixtures/test.key"}}, + {"wrong key path", &Options{CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/notFound.key"}}, + {"multiple cert key path mismatch", &Options{CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key"}}, + } { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.options.tlsConfig() + require.Error(t, err) + }) } }