Skip to content

Commit

Permalink
Implements HTTPS graceful shutdown
Browse files Browse the repository at this point in the history
Fixes #1865

Signed-off-by: Alexander Yastrebov <[email protected]>
  • Loading branch information
AlexanderYastrebov committed Sep 24, 2021
1 parent 27f4849 commit a9466fa
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 164 deletions.
92 changes: 52 additions & 40 deletions skipper.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,35 @@ func initLog(o Options) error {
return nil
}

func (o *Options) isHTTPS() bool {
return (o.ProxyTLS != nil) || (o.CertPathTLS != "" && o.KeyPathTLS != "")
func (o *Options) tlsConfig() (*tls.Config, error) {
if o.ProxyTLS != nil {
return o.ProxyTLS, nil
}

if o.CertPathTLS == "" && o.KeyPathTLS == "" {
return nil, nil
}

crts := strings.Split(o.CertPathTLS, ",")
keys := strings.Split(o.KeyPathTLS, ",")

if len(crts) != len(keys) {
return nil, fmt.Errorf("number of certificates does not match number of keys")
}

config := &tls.Config{
MinVersion: o.TLSMinVersion,
}

for i := 0; i < len(crts); i++ {
crt, key := crts[i], keys[i]
keypair, err := tls.LoadX509KeyPair(crt, key)
if err != nil {
return nil, fmt.Errorf("failed to load X509 keypair from %s and %s: %w", crt, key, err)
}
config.Certificates = append(config.Certificates, keypair)
}
return config, nil
}

func listen(o *Options, mtr metrics.Metrics) (net.Listener, error) {
Expand Down Expand Up @@ -1005,11 +1032,14 @@ func listenAndServeQuit(
idleConnsCH chan struct{},
mtr metrics.Metrics,
) error {
// create the access log handler
log.Infof("proxy listener on %v", o.Address)
tlsConfig, err := o.tlsConfig()
if err != nil {
return err
}

srv := &http.Server{
Addr: o.Address,
TLSConfig: tlsConfig,
Handler: proxy,
ReadTimeout: o.ReadTimeoutServer,
ReadHeaderTimeout: o.ReadHeaderTimeoutServer,
Expand All @@ -1025,35 +1055,6 @@ func listenAndServeQuit(
}
}

if o.isHTTPS() {
if o.ProxyTLS != nil {
srv.TLSConfig = o.ProxyTLS
o.CertPathTLS = ""
o.KeyPathTLS = ""
} else if strings.Index(o.CertPathTLS, ",") > 0 && strings.Index(o.KeyPathTLS, ",") > 0 {
tlsCfg := &tls.Config{
MinVersion: o.TLSMinVersion,
}
crts := strings.Split(o.CertPathTLS, ",")
keys := strings.Split(o.KeyPathTLS, ",")
if len(crts) != len(keys) {
log.Fatalf("number of certs does not match number of keys")
}
for i, crt := range crts {
kp, err := tls.LoadX509KeyPair(crt, keys[i])
if err != nil {
log.Fatalf("Failed to load X509 keypair from %s/%s: %v", crt, keys[i], err)
}
tlsCfg.Certificates = append(tlsCfg.Certificates, kp)
}
o.CertPathTLS = ""
o.KeyPathTLS = ""
srv.TLSConfig = tlsCfg
}
return srv.ListenAndServeTLS(o.CertPathTLS, o.KeyPathTLS)
}
log.Infof("TLS settings not found, defaulting to HTTP")

// making idleConnsCH and sigs optional parameters is required to be able to tear down a server
// from the tests
if idleConnsCH == nil {
Expand All @@ -1079,14 +1080,25 @@ func listenAndServeQuit(
close(idleConnsCH)
}()

l, err := listen(o, mtr)
if err != nil {
return err
}
log.Infof("proxy listener on %v", o.Address)

if err := srv.Serve(l); err != nil && err != http.ErrServerClosed {
log.Errorf("Failed to start to ListenAndServe: %v", err)
return err
if srv.TLSConfig != nil {
if err := srv.ListenAndServeTLS("", ""); err != http.ErrServerClosed {
log.Errorf("ListenAndServeTLS failed: %v", err)
return err
}
} else {
log.Infof("TLS settings not found, defaulting to HTTP")

l, err := listen(o, mtr)
if err != nil {
return err
}

if err := srv.Serve(l); err != http.ErrServerClosed {
log.Errorf("Serve failed: %v", err)
return err
}
}

<-idleConnsCH
Expand Down
215 changes: 91 additions & 124 deletions skipper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"net"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"testing"
"time"
Expand All @@ -20,6 +18,8 @@ import (
"github.com/zalando/skipper/proxy"
"github.com/zalando/skipper/ratelimit"
"github.com/zalando/skipper/routing"

"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -73,66 +73,58 @@ func findAddress() (string, error) {
return l.Addr().String(), nil
}

func TestOptionsDefaultsToHTTP(t *testing.T) {
o := Options{}
if o.isHTTPS() {
t.FailNow()
}
}

func TestOptionsWithCertUsesHTTPS(t *testing.T) {
o := Options{CertPathTLS: "foo", KeyPathTLS: "bar"}
if !o.isHTTPS() {
t.FailNow()
}
func TestOptionsTLSConfig(t *testing.T) {
cert, err := tls.LoadX509KeyPair("fixtures/test.crt", "fixtures/test.key")
require.NoError(t, err)

// empty
o := &Options{}
c, err := o.tlsConfig()
require.NoError(t, err)
require.Nil(t, c)

// proxy tls
o = &Options{ProxyTLS: &tls.Config{}}
c, err = o.tlsConfig()
require.NoError(t, err)
require.Equal(t, &tls.Config{}, c)

// proxy tls prio
o = &Options{ProxyTLS: &tls.Config{}, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"}
c, err = o.tlsConfig()
require.NoError(t, err)
require.Equal(t, &tls.Config{}, c)

// cert key path
o = &Options{TLSMinVersion: tls.VersionTLS12, CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key"}
c, err = o.tlsConfig()
require.NoError(t, err)
require.Equal(t, uint16(tls.VersionTLS12), c.MinVersion)
require.Equal(t, []tls.Certificate{cert}, c.Certificates)

// multiple cert key paths
o = &Options{TLSMinVersion: tls.VersionTLS13, CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key,fixtures/test.key"}
c, err = o.tlsConfig()
require.NoError(t, err)
require.Equal(t, uint16(tls.VersionTLS13), c.MinVersion)
require.Equal(t, []tls.Certificate{cert, cert}, c.Certificates)
}

func TestWithWrongCertPathFails(t *testing.T) {
a, err := findAddress()
if err != nil {
t.Fatal(err)
}

o := Options{Address: a,
CertPathTLS: "fixtures/notFound.crt",
KeyPathTLS: "fixtures/test.key",
}

rt := routing.New(routing.Options{
FilterRegistry: builtin.MakeRegistry(),
DataClients: []routing.DataClient{}})
defer rt.Close()

proxy := proxy.New(rt, proxy.OptionsNone)
defer proxy.Close()

err = listenAndServe(proxy, &o)
if err == nil {
t.Fatal(err)
}
}

func TestWithWrongKeyPathFails(t *testing.T) {
a, err := findAddress()
if err != nil {
t.Fatal(err)
}

o := Options{Address: a,
CertPathTLS: "fixtures/test.crt",
KeyPathTLS: "fixtures/notFound.key",
}

rt := routing.New(routing.Options{
FilterRegistry: builtin.MakeRegistry(),
DataClients: []routing.DataClient{}})
defer rt.Close()

proxy := proxy.New(rt, proxy.OptionsNone)
defer proxy.Close()
err = listenAndServe(proxy, &o)
if err == nil {
t.Fatal(err)
func TestOptionsTLSConfigInvalidPaths(t *testing.T) {
for _, tt := range []struct {
name string
options *Options
}{
{"missing cert path", &Options{KeyPathTLS: "fixtures/test.key"}},
{"missing key path", &Options{CertPathTLS: "fixtures/test.crt"}},
{"wrong cert path", &Options{CertPathTLS: "fixtures/notFound.crt", KeyPathTLS: "fixtures/test.key"}},
{"wrong key path", &Options{CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/notFound.key"}},
{"multiple cert key path mismatch", &Options{CertPathTLS: "fixtures/test.crt,fixtures/test.crt", KeyPathTLS: "fixtures/test.key"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := tt.options.tlsConfig()
require.Error(t, err)
})
}
}

Expand Down Expand Up @@ -218,92 +210,67 @@ func TestHTTPServer(t *testing.T) {
}

func TestHTTPServerShutdown(t *testing.T) {
d := 1 * time.Second
o := &Options{}
testServerShutdown(t, o, "http")
}

o := Options{
Address: ":19999",
WaitForHealthcheckInterval: d,
func TestHTTPSServerShutdown(t *testing.T) {
o := &Options{
CertPathTLS: "fixtures/test.crt",
KeyPathTLS: "fixtures/test.key",
}
testServerShutdown(t, o, "https")
}

func testServerShutdown(t *testing.T, o *Options, scheme string) {
const shutdownDelay = 1 * time.Second

address, err := findAddress()
require.NoError(t, err)

o.Address, o.WaitForHealthcheckInterval = address, shutdownDelay
testUrl := scheme + "://" + address

// simulate a backend that got a request and should be handled correctly
dc, err := routestring.New(`r0: * -> latency("3s") -> inlineContent("OK") -> status(200) -> <shunt>`)
if err != nil {
t.Errorf("Failed to create dataclient: %v", err)
}
require.NoError(t, err)

rt := routing.New(routing.Options{
FilterRegistry: builtin.MakeRegistry(),
DataClients: []routing.DataClient{
dc,
},
DataClients: []routing.DataClient{dc},
})
defer rt.Close()

proxy := proxy.New(rt, proxy.OptionsNone)
defer proxy.Close()

sigs := make(chan os.Signal, 1)
go func() {
if errLas := listenAndServe(proxy, &o); errLas != nil {
t.Logf("Failed to liste and serve: %v", errLas)
}
err := listenAndServeQuit(proxy, o, sigs, nil, nil)
require.NoError(t, err)
}()

pid := syscall.Getpid()
p, err := os.FindProcess(pid)
if err != nil {
t.Errorf("Failed to find current process: %v", err)
}

var wg sync.WaitGroup
installSigHandler := make(chan struct{}, 1)
wg.Add(1)
go func() {
defer wg.Done()
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM)
// initiate shutdown
sigs <- syscall.SIGTERM

installSigHandler <- struct{}{}
time.Sleep(shutdownDelay / 2)

<-sigs
t.Logf("ongoing request passing in before shutdown")
r, err := waitConnGet(testUrl)
require.NoError(t, err)
require.Equal(t, 200, r.StatusCode)

// ongoing requests passing in before shutdown
time.Sleep(d / 2)
r, err2 := waitConnGet("http://" + o.Address)
if r != nil {
defer r.Body.Close()
}
if err2 != nil {
t.Errorf("Cannot connect to the local server for testing: %v ", err2)
}
if r.StatusCode != 200 {
t.Errorf("Status code should be 200, instead got: %d\n", r.StatusCode)
}
body, err2 := io.ReadAll(r.Body)
if err2 != nil {
t.Errorf("Failed to stream response body: %v", err2)
}
if s := string(body); s != "OK" {
t.Errorf("Failed to get the right content: %s", s)
}
defer r.Body.Close()

// requests on closed listener should fail
time.Sleep(d / 2)
r2, err2 := waitConnGet("http://" + o.Address)
if r2 != nil {
defer r2.Body.Close()
}
if err2 == nil {
t.Error("Can connect to a closed server for testing")
}
}()
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "OK", string(body))

<-installSigHandler
time.Sleep(d / 2)
time.Sleep(shutdownDelay / 2)

if err = p.Signal(syscall.SIGTERM); err != nil {
t.Errorf("Failed to signal process: %v", err)
}
wg.Wait()
time.Sleep(d)
t.Logf("request after shutdown should fail")
r, err = waitConnGet(testUrl)
require.Error(t, err)
}

type (
Expand Down

0 comments on commit a9466fa

Please sign in to comment.