From a9466faf953d7c61bbb49a7764331cc5915f3ca3 Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Fri, 24 Sep 2021 01:29:14 +0200 Subject: [PATCH] Implements HTTPS graceful shutdown Fixes #1865 Signed-off-by: Alexander Yastrebov --- skipper.go | 92 ++++++++++++--------- skipper_test.go | 215 ++++++++++++++++++++---------------------------- 2 files changed, 143 insertions(+), 164 deletions(-) diff --git a/skipper.go b/skipper.go index a083a1fe22..e032898dd1 100644 --- a/skipper.go +++ b/skipper.go @@ -944,8 +944,35 @@ func initLog(o Options) error { return nil } -func (o *Options) isHTTPS() bool { - return (o.ProxyTLS != nil) || (o.CertPathTLS != "" && o.KeyPathTLS != "") +func (o *Options) tlsConfig() (*tls.Config, error) { + if o.ProxyTLS != nil { + return o.ProxyTLS, nil + } + + if o.CertPathTLS == "" && o.KeyPathTLS == "" { + return nil, nil + } + + crts := strings.Split(o.CertPathTLS, ",") + keys := strings.Split(o.KeyPathTLS, ",") + + if len(crts) != len(keys) { + return nil, fmt.Errorf("number of certificates does not match number of keys") + } + + config := &tls.Config{ + MinVersion: o.TLSMinVersion, + } + + for i := 0; i < len(crts); i++ { + crt, key := crts[i], keys[i] + keypair, err := tls.LoadX509KeyPair(crt, key) + if err != nil { + return nil, fmt.Errorf("failed to load X509 keypair from %s and %s: %w", crt, key, err) + } + config.Certificates = append(config.Certificates, keypair) + } + return config, nil } func listen(o *Options, mtr metrics.Metrics) (net.Listener, error) { @@ -1005,11 +1032,14 @@ func listenAndServeQuit( idleConnsCH chan struct{}, mtr metrics.Metrics, ) error { - // create the access log handler - log.Infof("proxy listener on %v", o.Address) + tlsConfig, err := o.tlsConfig() + if err != nil { + return err + } srv := &http.Server{ Addr: o.Address, + TLSConfig: tlsConfig, Handler: proxy, ReadTimeout: o.ReadTimeoutServer, ReadHeaderTimeout: o.ReadHeaderTimeoutServer, @@ -1025,35 +1055,6 @@ func listenAndServeQuit( } } - if o.isHTTPS() { - if o.ProxyTLS != nil { - srv.TLSConfig = o.ProxyTLS - o.CertPathTLS = "" - o.KeyPathTLS = "" - } else if strings.Index(o.CertPathTLS, ",") > 0 && strings.Index(o.KeyPathTLS, ",") > 0 { - tlsCfg := &tls.Config{ - MinVersion: o.TLSMinVersion, - } - crts := strings.Split(o.CertPathTLS, ",") - keys := strings.Split(o.KeyPathTLS, ",") - if len(crts) != len(keys) { - log.Fatalf("number of certs does not match number of keys") - } - for i, crt := range crts { - kp, err := tls.LoadX509KeyPair(crt, keys[i]) - if err != nil { - log.Fatalf("Failed to load X509 keypair from %s/%s: %v", crt, keys[i], err) - } - tlsCfg.Certificates = append(tlsCfg.Certificates, kp) - } - o.CertPathTLS = "" - o.KeyPathTLS = "" - srv.TLSConfig = tlsCfg - } - return srv.ListenAndServeTLS(o.CertPathTLS, o.KeyPathTLS) - } - log.Infof("TLS settings not found, defaulting to HTTP") - // making idleConnsCH and sigs optional parameters is required to be able to tear down a server // from the tests if idleConnsCH == nil { @@ -1079,14 +1080,25 @@ func listenAndServeQuit( close(idleConnsCH) }() - l, err := listen(o, mtr) - if err != nil { - return err - } + log.Infof("proxy listener on %v", o.Address) - if err := srv.Serve(l); err != nil && err != http.ErrServerClosed { - log.Errorf("Failed to start to ListenAndServe: %v", err) - return err + if srv.TLSConfig != nil { + if err := srv.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + log.Errorf("ListenAndServeTLS failed: %v", err) + return err + } + } else { + log.Infof("TLS settings not found, defaulting to HTTP") + + l, err := listen(o, mtr) + if err != nil { + return err + } + + if err := srv.Serve(l); err != http.ErrServerClosed { + log.Errorf("Serve failed: %v", err) + return err + } } <-idleConnsCH diff --git a/skipper_test.go b/skipper_test.go index a815e7c861..cc7409aa6b 100644 --- a/skipper_test.go +++ b/skipper_test.go @@ -6,8 +6,6 @@ import ( "net" "net/http" "os" - "os/signal" - "sync" "syscall" "testing" "time" @@ -20,6 +18,8 @@ import ( "github.com/zalando/skipper/proxy" "github.com/zalando/skipper/ratelimit" "github.com/zalando/skipper/routing" + + "github.com/stretchr/testify/require" ) const ( @@ -73,66 +73,58 @@ func findAddress() (string, error) { return l.Addr().String(), nil } -func TestOptionsDefaultsToHTTP(t *testing.T) { - o := Options{} - if o.isHTTPS() { - t.FailNow() - } -} - -func TestOptionsWithCertUsesHTTPS(t *testing.T) { - o := Options{CertPathTLS: "foo", KeyPathTLS: "bar"} - if !o.isHTTPS() { - t.FailNow() - } +func TestOptionsTLSConfig(t *testing.T) { + cert, err := tls.LoadX509KeyPair("fixtures/test.crt", "fixtures/test.key") + require.NoError(t, err) + + // empty + o := &Options{} + c, err := o.tlsConfig() + require.NoError(t, err) + require.Nil(t, c) + + // proxy tls + o = &Options{ProxyTLS: &tls.Config{}} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, &tls.Config{}, c) + + // proxy tls prio + o = &Options{ProxyTLS: &tls.Config{}, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, &tls.Config{}, c) + + // cert key path + o = &Options{TLSMinVersion: tls.VersionTLS12, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, uint16(tls.VersionTLS12), c.MinVersion) + require.Equal(t, []tls.Certificate{cert}, c.Certificates) + + // multiple cert key paths + o = &Options{TLSMinVersion: tls.VersionTLS13, CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key,fixtures/test.key"} + c, err = o.tlsConfig() + require.NoError(t, err) + require.Equal(t, uint16(tls.VersionTLS13), c.MinVersion) + require.Equal(t, []tls.Certificate{cert, cert}, c.Certificates) } -func TestWithWrongCertPathFails(t *testing.T) { - a, err := findAddress() - if err != nil { - t.Fatal(err) - } - - o := Options{Address: a, - CertPathTLS: "fixtures/notFound.crt", - KeyPathTLS: "fixtures/test.key", - } - - rt := routing.New(routing.Options{ - FilterRegistry: builtin.MakeRegistry(), - DataClients: []routing.DataClient{}}) - defer rt.Close() - - proxy := proxy.New(rt, proxy.OptionsNone) - defer proxy.Close() - - err = listenAndServe(proxy, &o) - if err == nil { - t.Fatal(err) - } -} - -func TestWithWrongKeyPathFails(t *testing.T) { - a, err := findAddress() - if err != nil { - t.Fatal(err) - } - - o := Options{Address: a, - CertPathTLS: "fixtures/test.crt", - KeyPathTLS: "fixtures/notFound.key", - } - - rt := routing.New(routing.Options{ - FilterRegistry: builtin.MakeRegistry(), - DataClients: []routing.DataClient{}}) - defer rt.Close() - - proxy := proxy.New(rt, proxy.OptionsNone) - defer proxy.Close() - err = listenAndServe(proxy, &o) - if err == nil { - t.Fatal(err) +func TestOptionsTLSConfigInvalidPaths(t *testing.T) { + for _, tt := range []struct { + name string + options *Options + }{ + {"missing cert path", &Options{KeyPathTLS: "fixtures/test.key"}}, + {"missing key path", &Options{CertPathTLS: "fixtures/test.crt"}}, + {"wrong cert path", &Options{CertPathTLS: "fixtures/notFound.crt", KeyPathTLS: "fixtures/test.key"}}, + {"wrong key path", &Options{CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/notFound.key"}}, + {"multiple cert key path mismatch", &Options{CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key"}}, + } { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.options.tlsConfig() + require.Error(t, err) + }) } } @@ -218,92 +210,67 @@ func TestHTTPServer(t *testing.T) { } func TestHTTPServerShutdown(t *testing.T) { - d := 1 * time.Second + o := &Options{} + testServerShutdown(t, o, "http") +} - o := Options{ - Address: ":19999", - WaitForHealthcheckInterval: d, +func TestHTTPSServerShutdown(t *testing.T) { + o := &Options{ + CertPathTLS: "fixtures/test.crt", + KeyPathTLS: "fixtures/test.key", } + testServerShutdown(t, o, "https") +} + +func testServerShutdown(t *testing.T, o *Options, scheme string) { + const shutdownDelay = 1 * time.Second + + address, err := findAddress() + require.NoError(t, err) + + o.Address, o.WaitForHealthcheckInterval = address, shutdownDelay + testUrl := scheme + "://" + address // simulate a backend that got a request and should be handled correctly dc, err := routestring.New(`r0: * -> latency("3s") -> inlineContent("OK") -> status(200) -> `) - if err != nil { - t.Errorf("Failed to create dataclient: %v", err) - } + require.NoError(t, err) rt := routing.New(routing.Options{ FilterRegistry: builtin.MakeRegistry(), - DataClients: []routing.DataClient{ - dc, - }, + DataClients: []routing.DataClient{dc}, }) defer rt.Close() proxy := proxy.New(rt, proxy.OptionsNone) defer proxy.Close() + + sigs := make(chan os.Signal, 1) go func() { - if errLas := listenAndServe(proxy, &o); errLas != nil { - t.Logf("Failed to liste and serve: %v", errLas) - } + err := listenAndServeQuit(proxy, o, sigs, nil, nil) + require.NoError(t, err) }() - pid := syscall.Getpid() - p, err := os.FindProcess(pid) - if err != nil { - t.Errorf("Failed to find current process: %v", err) - } - - var wg sync.WaitGroup - installSigHandler := make(chan struct{}, 1) - wg.Add(1) - go func() { - defer wg.Done() - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) + // initiate shutdown + sigs <- syscall.SIGTERM - installSigHandler <- struct{}{} + time.Sleep(shutdownDelay / 2) - <-sigs + t.Logf("ongoing request passing in before shutdown") + r, err := waitConnGet(testUrl) + require.NoError(t, err) + require.Equal(t, 200, r.StatusCode) - // ongoing requests passing in before shutdown - time.Sleep(d / 2) - r, err2 := waitConnGet("http://" + o.Address) - if r != nil { - defer r.Body.Close() - } - if err2 != nil { - t.Errorf("Cannot connect to the local server for testing: %v ", err2) - } - if r.StatusCode != 200 { - t.Errorf("Status code should be 200, instead got: %d\n", r.StatusCode) - } - body, err2 := io.ReadAll(r.Body) - if err2 != nil { - t.Errorf("Failed to stream response body: %v", err2) - } - if s := string(body); s != "OK" { - t.Errorf("Failed to get the right content: %s", s) - } + defer r.Body.Close() - // requests on closed listener should fail - time.Sleep(d / 2) - r2, err2 := waitConnGet("http://" + o.Address) - if r2 != nil { - defer r2.Body.Close() - } - if err2 == nil { - t.Error("Can connect to a closed server for testing") - } - }() + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "OK", string(body)) - <-installSigHandler - time.Sleep(d / 2) + time.Sleep(shutdownDelay / 2) - if err = p.Signal(syscall.SIGTERM); err != nil { - t.Errorf("Failed to signal process: %v", err) - } - wg.Wait() - time.Sleep(d) + t.Logf("request after shutdown should fail") + r, err = waitConnGet(testUrl) + require.Error(t, err) } type (