Skip to content

Commit

Permalink
embed: gracefully close peer handlers on shutdown
Browse files Browse the repository at this point in the history
Signed-off-by: Gyu-Ho Lee <[email protected]>
  • Loading branch information
gyuho committed May 5, 2017
1 parent db6f45e commit 823f9b3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 44 deletions.
97 changes: 65 additions & 32 deletions embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
package embed

import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
defaultLog "log"
"net"
"net/http"
"path/filepath"
"sync"
"time"

"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/api/v2http"
Expand Down Expand Up @@ -51,7 +55,7 @@ const (

// Etcd contains a running etcd server and its listeners.
type Etcd struct {
Peers []net.Listener
Peers []*peerListener
Clients []net.Listener
Server *etcdserver.EtcdServer

Expand All @@ -63,6 +67,12 @@ type Etcd struct {
closeOnce sync.Once
}

// peerListener pairs a peer-facing network listener with the callbacks used
// to run and to stop the server attached to it.
type peerListener struct {
	// Listener is embedded, so a peerListener can be used wherever a
	// net.Listener is expected.
	net.Listener

	// serve runs the server for this listener and reports the error it
	// terminates with.
	serve func() error

	// close stops the server for this listener; the supplied context bounds
	// how long the stop may take. It may be nil, so check before calling.
	close func(context.Context) error
}

// StartEtcd launches the etcd server and HTTP handlers for client/server communication.
// The returned Etcd.Server is not guaranteed to have joined the cluster. Wait
// on the Etcd.Server.ReadyNotify() channel to know when it completes and is ready for use.
Expand All @@ -87,21 +97,10 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
e = nil
}()

if e.Peers, err = startPeerListeners(cfg); err != nil {
return
}
if e.sctxs, err = startClientListeners(cfg); err != nil {
return
}
for _, sctx := range e.sctxs {
e.Clients = append(e.Clients, sctx.l)
}

var (
urlsmap types.URLsMap
token string
)

if !isMemberInitialized(cfg) {
urlsmap, token, err = cfg.PeerURLsMapAndToken("etcd")
if err != nil {
Expand Down Expand Up @@ -138,6 +137,16 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
return
}

if e.Peers, err = startPeerListeners(cfg, e.Server); err != nil {
return
}
if e.sctxs, err = startClientListeners(cfg); err != nil {
return
}
for _, sctx := range e.sctxs {
e.Clients = append(e.Clients, sctx.l)
}

// buffer channel so goroutines on closed connections won't wait forever
e.errc = make(chan error, len(e.Peers)+len(e.Clients)+2*len(e.sctxs))

Expand All @@ -164,28 +173,33 @@ func (e *Etcd) Close() {
gs.GracefulStop()
}
}

for _, sctx := range e.sctxs {
sctx.cancel()
}
for i := range e.Peers {
if e.Peers[i] != nil {
e.Peers[i].Close()
}
}
for i := range e.Clients {
if e.Clients[i] != nil {
e.Clients[i].Close()
}
}

// close rafthttp transports
if e.Server != nil {
e.Server.Stop()
}

// close all idle connections in peer handler (wait up to 1-second)
for i := range e.Peers {
if e.Peers[i] != nil && e.Peers[i].close != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
e.Peers[i].close(ctx)
cancel()
}
}
}

// Err exposes the Etcd's error channel as a receive-only channel.
func (e *Etcd) Err() <-chan error {
	return e.errc
}

func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
func startPeerListeners(cfg *Config, srv *etcdserver.EtcdServer) (peers []*peerListener, err error) {
if cfg.PeerAutoTLS && cfg.PeerTLSInfo.Empty() {
phosts := make([]string, len(cfg.LPUrls))
for i, u := range cfg.LPUrls {
Expand All @@ -203,20 +217,24 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
plog.Infof("peerTLS: %s", cfg.PeerTLSInfo)
}

plns = make([]net.Listener, len(cfg.LPUrls))
peers = make([]*peerListener, len(cfg.LPUrls))
defer func() {
if err == nil {
return
}
for i := range plns {
if plns[i] == nil {
continue
for i := range peers {
if peers[i] != nil && peers[i].close != nil {
plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
// pass canceled context to close idle connections only once
canceled, cancel := context.WithCancel(context.Background())
cancel()
peers[i].close(canceled)
}
plns[i].Close()
plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
}
}()

ph := v2http.NewPeerHandler(srv)

for i, u := range cfg.LPUrls {
if u.Scheme == "http" {
if !cfg.PeerTLSInfo.Empty() {
Expand All @@ -226,12 +244,28 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
}
}
if plns[i], err = rafthttp.NewListener(u, &cfg.PeerTLSInfo); err != nil {
peers[i] = &peerListener{}
peers[i].Listener, err = rafthttp.NewListener(u, &cfg.PeerTLSInfo)
if err != nil {
return nil, err
}
plog.Info("listening for peers on ", u.String())

srv := &http.Server{
Handler: ph,
ReadTimeout: 5 * time.Minute,
ErrorLog: defaultLog.New(ioutil.Discard, "", 0), // do not log user error
}
peers[i].serve = func() error { return srv.Serve(peers[i].Listener) }
peers[i].close = func(ctx context.Context) error {
// gracefully shutdown http.Server
// close open listeners, idle connections
// until context cancel or time-out
return srv.Shutdown(ctx)
}
}
return plns, nil

return peers, nil
}

func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) {
Expand Down Expand Up @@ -336,11 +370,10 @@ func (e *Etcd) serve() (err error) {
}

// Start the peer server in a goroutine
ph := v2http.NewPeerHandler(e.Server)
for _, l := range e.Peers {
go func(l net.Listener) {
e.errHandler(servePeerHTTP(l, ph))
}(l)
for _, pl := range e.Peers {
go func(l *peerListener) {
e.errHandler(l.serve())
}(pl)
}

// Start a client server goroutine for each listen address
Expand Down
12 changes: 0 additions & 12 deletions embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"net"
"net/http"
"strings"
"time"

"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/api/v3client"
Expand Down Expand Up @@ -161,17 +160,6 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha
})
}

// servePeerHTTP serves handler on listener l with an http.Server that uses a
// five-minute read timeout and discards the server's own error log output.
// It blocks for as long as Serve runs and returns Serve's error.
func servePeerHTTP(l net.Listener, handler http.Handler) error {
	// TODO: add debug flag; enable logging when debug flag is set
	peerSrv := &http.Server{
		Handler:     handler,
		ReadTimeout: 5 * time.Minute,
		// do not log user error: everything written to this logger is dropped
		ErrorLog: defaultLog.New(ioutil.Discard, "etcdhttp", 0),
	}
	return peerSrv.Serve(l)
}

type registerHandlerFunc func(context.Context, *gw.ServeMux, string, []grpc.DialOption) error

func (sctx *serveCtx) registerGateway(opts []grpc.DialOption) (*gw.ServeMux, error) {
Expand Down

0 comments on commit 823f9b3

Please sign in to comment.