Skip to content

Commit

Permalink
os/signal: add NotifyContext to cancel context using system signals
Browse files Browse the repository at this point in the history
Fixes #37255

Change-Id: Ic0fde3498afefed6e4447f8476e4da7c1faa7145
Reviewed-on: https://go-review.googlesource.com/c/go/+/219640
Run-TryBot: Ian Lance Taylor <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Trust: Giovanni Bajo <[email protected]>
Reviewed-by: Ian Lance Taylor <[email protected]>
  • Loading branch information
henvic authored and ianlancetaylor committed Sep 15, 2020
1 parent 8248b57 commit b6dbaef
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 0 deletions.
47 changes: 47 additions & 0 deletions src/os/signal/example_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris

package signal_test

import (
"context"
"fmt"
"log"
"os"
"os/signal"
"time"
)

// This example passes a context with a signal to tell a blocking function that
// it should abandon its work after a signal is received.
func ExampleNotifyContext() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()

p, err := os.FindProcess(os.Getpid())
if err != nil {
log.Fatal(err)
}

// On a Unix-like system, pressing Ctrl+C on a keyboard sends a
// SIGINT signal to the process of the program in execution.
//
// This example simulates that by sending a SIGINT signal to itself.
if err := p.Signal(os.Interrupt); err != nil {
log.Fatal(err)
}

select {
case <-time.After(time.Second):
fmt.Println("missed signal")
case <-ctx.Done():
fmt.Println(ctx.Err()) // prints "context canceled"
stop() // stop receiving signal notifications as soon as possible.
}

// Output:
// context canceled
}
75 changes: 75 additions & 0 deletions src/os/signal/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package signal

import (
"context"
"os"
"sync"
)
Expand Down Expand Up @@ -257,3 +258,77 @@ func process(sig os.Signal) {
}
}
}

// NotifyContext returns a copy of the parent context that is marked done
// (its Done channel is closed) when one of the listed signals arrives,
// when the returned stop function is called, or when the parent context's
// Done channel is closed, whichever happens first.
//
// The stop function unregisters the signal behavior, which, like signal.Reset,
// may restore the default behavior for a given signal. For example, the default
// behavior of a Go program receiving os.Interrupt is to exit. Calling
// NotifyContext(parent, os.Interrupt) will change the behavior to cancel
// the returned context. Future interrupts received will not trigger the default
// (exit) behavior until the returned stop function is called.
//
// The stop function releases resources associated with it, so code should
// call stop as soon as the operations running in this Context complete and
// signals no longer need to be diverted to the context.
func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
ctx, cancel := context.WithCancel(parent)
c := &signalCtx{
Context: ctx,
cancel: cancel,
signals: signals,
}
c.ch = make(chan os.Signal, 1)
Notify(c.ch, c.signals...)
if ctx.Err() == nil {
go func() {
select {
case <-c.ch:
c.cancel()
case <-c.Done():
}
}()
}
return c, c.stop
}

type signalCtx struct {
context.Context

cancel context.CancelFunc
signals []os.Signal
ch chan os.Signal
}

func (c *signalCtx) stop() {
c.cancel()
Stop(c.ch)
}

type stringer interface {
String() string
}

func (c *signalCtx) String() string {
var buf []byte
// We know that the type of c.Context is context.cancelCtx, and we know that the
// String method of cancelCtx returns a string that ends with ".WithCancel".
name := c.Context.(stringer).String()
name = name[:len(name)-len(".WithCancel")]
buf = append(buf, "signal.NotifyContext("+name...)
if len(c.signals) != 0 {
buf = append(buf, ", ["...)
for i, s := range c.signals {
buf = append(buf, s.String()...)
if i != len(c.signals)-1 {
buf = append(buf, ' ')
}
}
buf = append(buf, ']')
}
buf = append(buf, ')')
return string(buf)
}
162 changes: 162 additions & 0 deletions src/os/signal/signal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package signal

import (
"bytes"
"context"
"flag"
"fmt"
"internal/testenv"
Expand Down Expand Up @@ -674,3 +675,164 @@ func TestTime(t *testing.T) {
close(stop)
<-done
}

func TestNotifyContext(t *testing.T) {
c, stop := NotifyContext(context.Background(), syscall.SIGINT)
defer stop()

if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, want %q", got, want)
}

syscall.Kill(syscall.Getpid(), syscall.SIGINT)
select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("timed out waiting for context to be done after SIGINT")
}
}

func TestNotifyContextStop(t *testing.T) {
Ignore(syscall.SIGHUP)
if !Ignored(syscall.SIGHUP) {
t.Errorf("expected SIGHUP to be ignored when explicitly ignoring it.")
}

parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
c, stop := NotifyContext(parent, syscall.SIGHUP)
defer stop()

// If we're being notified, then the signal should not be ignored.
if Ignored(syscall.SIGHUP) {
t.Errorf("expected SIGHUP to not be ignored.")
}

if want, got := "signal.NotifyContext(context.Background.WithCancel, [hangup])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, wanted %q", got, want)
}

stop()
select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("timed out waiting for context to be done after calling stop")
}
}

func TestNotifyContextCancelParent(t *testing.T) {
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()

if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, want %q", got, want)
}

cancelParent()
select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("timed out waiting for parent context to be canceled")
}
}

func TestNotifyContextPrematureCancelParent(t *testing.T) {
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()

cancelParent() // Prematurely cancel context before calling NotifyContext.
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()

if want, got := "signal.NotifyContext(context.Background.WithCancel, [interrupt])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, want %q", got, want)
}

select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("timed out waiting for parent context to be canceled")
}
}

func TestNotifyContextSimultaneousNotifications(t *testing.T) {
c, stop := NotifyContext(context.Background(), syscall.SIGINT)
defer stop()

if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, want %q", got, want)
}

var wg sync.WaitGroup
n := 10
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
wg.Done()
}()
}
wg.Wait()
select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("expected context to be canceled")
}
}

func TestNotifyContextSimultaneousStop(t *testing.T) {
c, stop := NotifyContext(context.Background(), syscall.SIGINT)
defer stop()

if want, got := "signal.NotifyContext(context.Background, [interrupt])", fmt.Sprint(c); want != got {
t.Errorf("c.String() = %q, want %q", got, want)
}

var wg sync.WaitGroup
n := 10
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
stop()
wg.Done()
}()
}
wg.Wait()
select {
case <-c.Done():
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
case <-time.After(time.Second):
t.Errorf("expected context to be canceled")
}
}

func TestNotifyContextStringer(t *testing.T) {
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
c, stop := NotifyContext(parent, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
defer stop()

want := `signal.NotifyContext(context.Background.WithCancel, [hangup interrupt terminated])`
if got := fmt.Sprint(c); got != want {
t.Errorf("c.String() = %q, want %q", got, want)
}
}

0 comments on commit b6dbaef

Please sign in to comment.