diff --git a/embed/etcd.go b/embed/etcd.go index 7ebbcf201cb..f77dee0e065 100644 --- a/embed/etcd.go +++ b/embed/etcd.go @@ -15,12 +15,16 @@ package embed import ( + "context" "crypto/tls" "fmt" + "io/ioutil" + defaultLog "log" "net" "net/http" "path/filepath" "sync" + "time" "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver/api/v2http" @@ -51,7 +55,7 @@ const ( // Etcd contains a running etcd server and its listeners. type Etcd struct { - Peers []net.Listener + Peers []*peerListener Clients []net.Listener Server *etcdserver.EtcdServer @@ -63,6 +67,12 @@ type Etcd struct { closeOnce sync.Once } +type peerListener struct { + net.Listener + serve func() error + close func(context.Context) error +} + // StartEtcd launches the etcd server and HTTP handlers for client/server communication. // The returned Etcd.Server is not guaranteed to have joined the cluster. Wait // on the Etcd.Server.ReadyNotify() channel to know when it completes and is ready for use. @@ -138,6 +148,25 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) { return } + // configure peer handlers after rafthttp.Transport started + ph := v2http.NewPeerHandler(e.Server) + for i := range e.Peers { + srv := &http.Server{ + Handler: ph, + ReadTimeout: 5 * time.Minute, + ErrorLog: defaultLog.New(ioutil.Discard, "", 0), // do not log user error + } + e.Peers[i].serve = func() error { + return srv.Serve(e.Peers[i].Listener) + } + e.Peers[i].close = func(ctx context.Context) error { + // gracefully shutdown http.Server + // close open listeners, idle connections + // until context cancel or time-out + return srv.Shutdown(ctx) + } + } + // buffer channel so goroutines on closed connections won't wait forever e.errc = make(chan error, len(e.Peers)+len(e.Clients)+2*len(e.sctxs)) @@ -168,24 +197,30 @@ func (e *Etcd) Close() { for _, sctx := range e.sctxs { sctx.cancel() } - for i := range e.Peers { - if e.Peers[i] != nil { - e.Peers[i].Close() - } - } for i := range e.Clients { if e.Clients[i] != nil { e.Clients[i].Close() } } + + // close rafthttp transports if e.Server != nil { e.Server.Stop() } + + // close all idle connections in peer handler (wait up to 1-second) + for i := range e.Peers { + if e.Peers[i] != nil && e.Peers[i].close != nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + e.Peers[i].close(ctx) + cancel() + } + } } func (e *Etcd) Err() <-chan error { return e.errc } -func startPeerListeners(cfg *Config) (plns []net.Listener, err error) { +func startPeerListeners(cfg *Config) (peers []*peerListener, err error) { if cfg.PeerAutoTLS && cfg.PeerTLSInfo.Empty() { phosts := make([]string, len(cfg.LPUrls)) for i, u := range cfg.LPUrls { @@ -203,17 +238,16 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) { plog.Infof("peerTLS: %s", cfg.PeerTLSInfo) } - plns = make([]net.Listener, len(cfg.LPUrls)) + peers = make([]*peerListener, len(cfg.LPUrls)) defer func() { if err == nil { return } - for i := range plns { - if plns[i] == nil { - continue + for i := range peers { + if peers[i] != nil && peers[i].close != nil { + plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String()) + peers[i].close(context.Background()) } - plns[i].Close() - plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String()) } }() @@ -226,12 +260,18 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) { plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String()) } } - if plns[i], err = rafthttp.NewListener(u, &cfg.PeerTLSInfo); err != nil { + peers[i] = &peerListener{close: func(context.Context) error { return nil }} + peers[i].Listener, err = rafthttp.NewListener(u, &cfg.PeerTLSInfo) + if err != nil { return nil, err } + // once serve, overwrite with 'http.Server.Shutdown' + peers[i].close = func(context.Context) error { + return peers[i].Listener.Close() + } plog.Info("listening for peers on ", u.String()) } - return plns, nil + return peers, nil } func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) { @@ -336,11 +376,10 @@ func (e *Etcd) serve() (err error) { } // Start the peer server in a goroutine - ph := v2http.NewPeerHandler(e.Server) - for _, l := range e.Peers { - go func(l net.Listener) { - e.errHandler(servePeerHTTP(l, ph)) - }(l) + for _, pl := range e.Peers { + go func(l *peerListener) { + e.errHandler(l.serve()) + }(pl) } // Start a client server goroutine for each listen address diff --git a/embed/serve.go b/embed/serve.go index e862c4897a4..e43611f77a3 100644 --- a/embed/serve.go +++ b/embed/serve.go @@ -21,7 +21,6 @@ import ( "net" "net/http" "strings" - "time" "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver/api/v3client" @@ -161,17 +160,6 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha }) } -func servePeerHTTP(l net.Listener, handler http.Handler) error { - logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0) - // TODO: add debug flag; enable logging when debug flag is set - srv := &http.Server{ - Handler: handler, - ReadTimeout: 5 * time.Minute, - ErrorLog: logger, // do not log user error - } - return srv.Serve(l) -} - type registerHandlerFunc func(context.Context, *gw.ServeMux, string, []grpc.DialOption) error func (sctx *serveCtx) registerGateway(opts []grpc.DialOption) (*gw.ServeMux, error) {