diff --git a/src/context/context.go b/src/context/context.go index b561968f31c764..b3fdb8277afc37 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -230,6 +230,9 @@ type CancelFunc func() // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this Context complete. func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + if parent == nil { + panic("cannot create context from nil parent") + } c := newCancelCtx(parent) propagateCancel(parent, &c) return &c, func() { c.cancel(true, Canceled) } @@ -425,6 +428,9 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) { // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this Context complete. func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { + if parent == nil { + panic("cannot create context from nil parent") + } if cur, ok := parent.Deadline(); ok && cur.Before(d) { // The current deadline is already sooner than the new one. return WithCancel(parent) @@ -511,6 +517,9 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { // struct{}. Alternatively, exported context key variables' static // type should be a pointer or interface. func WithValue(parent Context, key, val interface{}) Context { + if parent == nil { + panic("cannot create context from nil parent") + } if key == nil { panic("nil key") } diff --git a/src/context/context_test.go b/src/context/context_test.go index da29ed0c2b903e..98c6683335370f 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -667,6 +667,21 @@ func XTestWithValueChecksKey(t testingT) { } } +func XTestInvalidDerivedFail(t testingT) { + panicVal := recoveredValue(func() { WithCancel(nil) }) + if panicVal == nil { + t.Error("expected panic") + } + panicVal = recoveredValue(func() { WithDeadline(nil, time.Now().Add(shortDuration)) }) + if panicVal == nil { + t.Error("expected panic") + } + panicVal = recoveredValue(func() { WithValue(nil, "foo", "bar") }) + if panicVal == nil { + t.Error("expected panic") + } +} + func recoveredValue(fn func()) (v interface{}) { defer func() { v = recover() }() fn() diff --git a/src/context/x_test.go b/src/context/x_test.go index e85ef2d50e5fd8..00eca72d5aff0a 100644 --- a/src/context/x_test.go +++ b/src/context/x_test.go @@ -26,5 +26,6 @@ func TestLayersTimeout(t *testing.T) { XTestLayersTimeout(t) } func TestCancelRemoves(t *testing.T) { XTestCancelRemoves(t) } func TestWithCancelCanceledParent(t *testing.T) { XTestWithCancelCanceledParent(t) } func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey(t) } +func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail(t) } func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) } func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) }