Skip to content

Commit

Permalink
Remove capturing the cancel callstack in the context package (#1595)
Browse files Browse the repository at this point in the history
* Fix race condition in context package

* Remove capturing the cancel callstack
  • Loading branch information
mcastorina authored Aug 2, 2023
1 parent 0ad4638 commit 160fd83
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 86 deletions.
46 changes: 4 additions & 42 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package context

import (
"context"
"fmt"
"os"
"runtime/debug"
"sync"
"time"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -37,23 +34,14 @@ type CancelFunc = context.CancelFunc
type logCtx struct {
// Embed context.Context to get all methods for free.
context.Context
log logr.Logger
err *error
errLock *sync.Mutex
log logr.Logger
}

// Logger returns a structured logger.
func (l logCtx) Logger() logr.Logger {
return l.log
}

func (l logCtx) Err() error {
if l.err != nil && *l.err != nil {
return *l.err
}
return l.Context.Err()
}

// Background returns context.Background with a default logger.
func Background() Context {
return logCtx{
Expand All @@ -77,7 +65,7 @@ func WithCancel(parent Context) (Context, context.CancelFunc) {
log: parent.Logger(),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}

// WithDeadline returns context.WithDeadline with the log object propagated and
Expand All @@ -88,7 +76,7 @@ func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
log: parent.Logger().WithValues("deadline", d),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}

// WithTimeout returns context.WithTimeout with the log object propagated and
Expand All @@ -99,7 +87,7 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, context.Cancel
log: parent.Logger().WithValues("timeout", timeout),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}

// WithValue returns context.WithValue with the log object propagated and
Expand Down Expand Up @@ -150,29 +138,3 @@ func AddLogger(parent context.Context) Context {
func SetDefaultLogger(l logr.Logger) {
defaultLogger = l
}

// captureCancelCallstack is a helper function to capture the callstack where
// the cancel function was first called.
func captureCancelCallstack(ctx logCtx, f context.CancelFunc) (Context, context.CancelFunc) {
if ctx.err == nil {
var err error
ctx.err = &err
ctx.errLock = &sync.Mutex{}
}
return ctx, func() {
ctx.errLock.Lock()
defer ctx.errLock.Unlock()
// We must check Err() before calling f() since f() sets the error.
// If there's already an error, do nothing special.
if ctx.Err() != nil {
f()
return
}
f()
// Set the error with the stacktrace if the err pointer is non-nil.
*ctx.err = fmt.Errorf(
"%w (canceled at %v\n%s)",
ctx.Err(), time.Now(), string(debug.Stack()),
)
}
}
47 changes: 3 additions & 44 deletions pkg/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,51 +170,10 @@ func TestDefaultLogger(t *testing.T) {
ctx.Logger().Info("this shouldn't panic")
}

func TestErrCallstack(t *testing.T) {
c, cancel := WithCancel(Background())
ctx := c.(logCtx)
cancel()
select {
case <-ctx.Done():
assert.Contains(t, ctx.Err().Error(), "TestErrCallstack")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}

func TestErrCallstackTimeout(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
defer cancel()

select {
case <-ctx.Done():
// Deadline exceeded errors will not have a callstack from the cancel
// function.
assert.NotContains(t, ctx.Err().Error(), "TestErrCallstackTimeout")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}

func TestErrCallstackTimeoutCancel(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)

var err error
select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}

// Calling cancel after deadline exceeded should not overwrite the original
// error.
cancel()
assert.Equal(t, err, ctx.Err())
}

func TestRace(t *testing.T) {
_, cancel := WithCancel(Background())
ctx, cancel := WithCancel(Background())
go cancel()
go func() { _ = ctx.Err() }()
cancel()
_ = ctx.Err()
}

0 comments on commit 160fd83

Please sign in to comment.