diff --git a/broadcast.go b/broadcast.go new file mode 100644 index 000000000..da2a207ae --- /dev/null +++ b/broadcast.go @@ -0,0 +1,179 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx + +import ( + "fmt" + "os" + "sync" +) + +// broadcaster broadcasts signals to registered signal listeners. +// All methods on the broadcaster are concurrency-safe. +type broadcaster struct { + // This lock is used to protect all fields of broadcaster. + // Methods on broadcaster should protect all concurrent access + // by taking this lock when accessing its fields. + // Conversely, this lock should NOT be taken outside of broadcaster.
+ m sync.Mutex + + // last points to the most recent ShutdownSignal passed to Broadcast, or + // is nil if there has been none since the last reset. A channel created + // later by Wait or Done is immediately sent this signal, which allows + // Wait or Done state to be read after application stop. + last *ShutdownSignal + + // done holds the channels created by Done. + done []chan os.Signal + + // wait holds the channels created by Wait. + wait []chan ShutdownSignal +} + +func (b *broadcaster) reset() { + b.m.Lock() + defer b.m.Unlock() + b.last = nil +} + +// Done creates a new channel that will receive signals being broadcast +// via the broadcaster. +// +// If a signal has been received prior to the call of Done, +// the signal will be sent to the new channel. +func (b *broadcaster) Done() <-chan os.Signal { + b.m.Lock() + defer b.m.Unlock() + + ch := make(chan os.Signal, 1) + // If a signal was received before this call to Done, send its + // os.Signal to the new channel right away. + // The channel is still registered below, so later broadcasts are + // attempted on it as well. + if b.last != nil { + ch <- b.last.Signal + } + b.done = append(b.done, ch) + return ch +} + +// Wait creates a new channel that will receive signals being broadcast +// via the broadcaster. +// +// If a signal has been received prior to the call of Wait, +// the signal will be sent to the new channel. +func (b *broadcaster) Wait() <-chan ShutdownSignal { + b.m.Lock() + defer b.m.Unlock() + + ch := make(chan ShutdownSignal, 1) + + if b.last != nil { + ch <- *b.last + } + + b.wait = append(b.wait, ch) + return ch +} + +// Broadcast records the given signal as the last one received and sends it +// to all channels created via Done or Wait. It does not block on sending, and +// returns an unsentSignalError if any send did not go through.
+func (b *broadcaster) Broadcast(signal ShutdownSignal) error { + b.m.Lock() + defer b.m.Unlock() + + b.last = &signal + + channels, unsent := b.broadcast( + signal, + b.broadcastDone, + b.broadcastWait, + ) + + if unsent != 0 { + return &unsentSignalError{ + Signal: signal, + Total: channels, + Unsent: unsent, + } + } + + return nil +} + +func (b *broadcaster) broadcast( + signal ShutdownSignal, + anchors ...func(ShutdownSignal) (int, int), +) (int, int) { + var channels, unsent int + + for _, anchor := range anchors { + c, u := anchor(signal) + channels += c + unsent += u + } + + return channels, unsent +} + +func (b *broadcaster) broadcastDone(signal ShutdownSignal) (int, int) { + var unsent int + + for _, reader := range b.done { + select { + case reader <- signal.Signal: + default: + unsent++ + } + } + + return len(b.done), unsent +} + +func (b *broadcaster) broadcastWait(signal ShutdownSignal) (int, int) { + var unsent int + + for _, reader := range b.wait { + select { + case reader <- signal: + default: + unsent++ + } + } + + return len(b.wait), unsent +} + +type unsentSignalError struct { + Signal ShutdownSignal + Unsent int + Total int +} + +func (err *unsentSignalError) Error() string { + return fmt.Sprintf( + "send %v signal: %v/%v channels are blocked", + err.Signal, + err.Unsent, + err.Total, + ) +} diff --git a/shutdown.go b/shutdown.go index b5fda5cbd..525d2c78b 100644 --- a/shutdown.go +++ b/shutdown.go @@ -83,7 +83,7 @@ func (s *shutdowner) Shutdown(opts ...ShutdownOption) error { opt.apply(s) } - return s.app.receivers.Broadcast(ShutdownSignal{ + return s.app.receivers.b.Broadcast(ShutdownSignal{ Signal: _sigTERM, ExitCode: s.exitCode, }) diff --git a/signal.go b/signal.go index 595a847bc..249a35810 100644 --- a/signal.go +++ b/signal.go @@ -49,9 +49,12 @@ func newSignalReceivers() signalReceivers { notify: signal.Notify, stopNotify: signal.Stop, signals: make(chan os.Signal, 1), + b: &broadcaster{}, } } +// signalReceivers listens to OS signals
and shutdown signals, +// and relays them to registered listeners when started. type signalReceivers struct { // this mutex protects writes and reads of this struct to prevent // race conditions in a parallel execution pattern @@ -68,17 +71,9 @@ type signalReceivers struct { notify func(c chan<- os.Signal, sig ...os.Signal) stopNotify func(c chan<- os.Signal) - // last will contain a pointer to the last ShutdownSignal received, or - // nil if none, if a new channel is created by Wait or Done, this last - // signal will be immediately written to, this allows Wait or Done state - // to be read after application stop - last *ShutdownSignal - - // contains channels created by Done - done []chan os.Signal - - // contains channels created by Wait - wait []chan ShutdownSignal + // b registers the listeners created via Done and Wait + // and broadcasts received signals to them + b *broadcaster } func (recv *signalReceivers) relayer() { @@ -90,7 +85,7 @@ func (recv *signalReceivers) relayer() { case <-recv.shutdown: return case signal := <-recv.signals: - recv.Broadcast(ShutdownSignal{ + recv.b.Broadcast(ShutdownSignal{ Signal: signal, }) } @@ -137,120 +132,15 @@ func (recv *signalReceivers) Stop(ctx context.Context) error { close(recv.finished) recv.shutdown = nil recv.finished = nil - recv.last = nil + recv.b.reset() return nil } } func (recv *signalReceivers) Done() <-chan os.Signal { - recv.m.Lock() - defer recv.m.Unlock() - - ch := make(chan os.Signal, 1) - - // If we had received a signal prior to the call of done, send it's - // os.Signal to the new channel. - // However we still want to have the operating system notify signals to this - // channel should the application receive another.
- if recv.last != nil { - ch <- recv.last.Signal - } - - recv.done = append(recv.done, ch) - return ch + return recv.b.Done() } func (recv *signalReceivers) Wait() <-chan ShutdownSignal { - recv.m.Lock() - defer recv.m.Unlock() - - ch := make(chan ShutdownSignal, 1) - - if recv.last != nil { - ch <- *recv.last - } - - recv.wait = append(recv.wait, ch) - return ch -} - -func (recv *signalReceivers) Broadcast(signal ShutdownSignal) error { - recv.m.Lock() - defer recv.m.Unlock() - - recv.last = &signal - - channels, unsent := recv.broadcast( - signal, - recv.broadcastDone, - recv.broadcastWait, - ) - - if unsent != 0 { - return &unsentSignalError{ - Signal: signal, - Total: channels, - Unsent: unsent, - } - } - - return nil -} - -func (recv *signalReceivers) broadcast( - signal ShutdownSignal, - anchors ...func(ShutdownSignal) (int, int), -) (int, int) { - var channels, unsent int - - for _, anchor := range anchors { - c, u := anchor(signal) - channels += c - unsent += u - } - - return channels, unsent -} - -func (recv *signalReceivers) broadcastDone(signal ShutdownSignal) (int, int) { - var unsent int - - for _, reader := range recv.done { - select { - case reader <- signal.Signal: - default: - unsent++ - } - } - - return len(recv.done), unsent -} - -func (recv *signalReceivers) broadcastWait(signal ShutdownSignal) (int, int) { - var unsent int - - for _, reader := range recv.wait { - select { - case reader <- signal: - default: - unsent++ - } - } - - return len(recv.wait), unsent -} - -type unsentSignalError struct { - Signal ShutdownSignal - Unsent int - Total int -} - -func (err *unsentSignalError) Error() string { - return fmt.Sprintf( - "send %v signal: %v/%v channels are blocked", - err.Signal, - err.Unsent, - err.Total, - ) + return recv.b.Wait() } diff --git a/signal_test.go b/signal_test.go index 18d96f479..cd85b54e0 100644 --- a/signal_test.go +++ b/signal_test.go @@ -25,6 +25,7 @@ import ( "os" "syscall" "testing" + "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,9 +57,9 @@ func TestSignal(t *testing.T) { Signal: syscall.SIGTERM, } - require.NoError(t, recv.Broadcast(expected), "first broadcast should succeed") + require.NoError(t, recv.b.Broadcast(expected), "first broadcast should succeed") - assertUnsentSignalError(t, recv.Broadcast(expected), &unsentSignalError{ + assertUnsentSignalError(t, recv.b.Broadcast(expected), &unsentSignalError{ Signal: expected, Total: 2, Unsent: 2, @@ -117,4 +118,26 @@ func TestSignal(t *testing.T) { }) }) }) + + t.Run("stop deadlock", func(t *testing.T) { + recv := newSignalReceivers() + + var notify chan<- os.Signal + recv.notify = func(ch chan<- os.Signal, _ ...os.Signal) { + notify = ch + } + recv.Start() + + // Artificially create a race where the relayer receives an OS signal + // while Stop() holds the lock. If this leads to deadlock, + // we will receive a context timeout error. + gotErr := make(chan error, 1) + notify <- syscall.SIGTERM + go func() { + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gotErr <- recv.Stop(stopCtx) + }() + assert.NoError(t, <-gotErr) + }) }