Skip to content

Commit

Permalink
feat: add ShutdownWithContext (#1383)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou authored Nov 20, 2022
1 parent 7b3bf58 commit 4995135
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
28 changes: 24 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,17 @@ func (s *Server) Serve(ln net.Listener) error {
//
// Shutdown does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0.
func (s *Server) Shutdown() error {
return s.ShutdownWithContext(context.Background())
}

// ShutdownWithContext gracefully shuts down the server without interrupting any active connections.
// ShutdownWithContext works by first closing all open listeners and then waiting for all connections to return to idle or context timeout and then shut down.
//
// When ShutdownWithContext is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
// Make sure the program doesn't exit and waits instead for Shutdown to return.
//
// ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0.
func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -1840,7 +1851,7 @@ func (s *Server) Shutdown() error {
}

for _, ln := range s.ln {
if err := ln.Close(); err != nil {
if err = ln.Close(); err != nil {
return err
}
}
Expand All @@ -1851,7 +1862,10 @@ func (s *Server) Shutdown() error {

// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
// Now we just have to wait until all workers are done.
// Now we just have to wait until all workers are done or timeout.
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
END:
for {
s.closeIdleConns()

Expand All @@ -1861,12 +1875,18 @@ func (s *Server) Shutdown() error {
// This is not an optimal solution but using a sync.WaitGroup
// here causes data races as it's hard to prevent Add() to be called
// while Wait() is waiting.
time.Sleep(time.Millisecond * 100)
select {
case <-ctx.Done():
err = ctx.Err()
break END
case <-ticker.C:
continue
}
}

s.done = nil
s.ln = nil
return nil
return err
}

func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
Expand Down
52 changes: 52 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -3594,6 +3595,57 @@ func TestShutdownCloseIdleConns(t *testing.T) {
}
}

func TestShutdownWithContext(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
time.Sleep(5 * time.Second)
ctx.Success("aaa/bbb", []byte("real response"))
},
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexepcted error: %v", err)
}
}()
time.Sleep(1 * time.Second)
go func() {
conn, err := ln.Dial()
if err != nil {
t.Errorf("unexepcted error: %v", err)
}

if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}()

time.Sleep(1 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
shutdownErr := make(chan error)
go func() {
shutdownErr <- s.ShutdownWithContext(ctx)
}()

timer := time.NewTimer(time.Second)
select {
case <-timer.C:
t.Fatal("idle connections not closed on shutdown")
case err := <-shutdownErr:
if err == nil || err != context.DeadlineExceeded {
t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
}
}
if atomic.LoadInt32(&s.open) != 1 {
t.Fatalf("unexpected open connection num: %#v. Expecting %#v", atomic.LoadInt32(&s.open), 1)
}
}

func TestMultipleServe(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 4995135

Please sign in to comment.