From 7d9fd08fa8452b056c52281eadec44cb7269dae2 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sat, 13 Aug 2022 21:30:54 +0900 Subject: [PATCH] handle panics --- memoize.go | 107 +++++++++++++++++++++++++++++++++++++++--------- memoize_test.go | 43 +++++++++++++++++++ 2 files changed, 131 insertions(+), 19 deletions(-) diff --git a/memoize.go b/memoize.go index b5dbd1d..a23a7b6 100644 --- a/memoize.go +++ b/memoize.go @@ -1,7 +1,11 @@ package memoize import ( + "bytes" "context" + "errors" + "fmt" + "runtime/debug" "sync" "time" ) @@ -9,6 +13,34 @@ import ( // for testing var nowFunc = time.Now +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + // Group memoizes the calls of Func with expiration. type Group[K comparable, V any] struct { mu sync.Mutex // protects m @@ -93,28 +125,65 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context } func do[K comparable, V any](g *Group[K, V], e *entry[V], c *call[V], key K, fn func(ctx context.Context, key K) (V, time.Time, error)) { - defer c.cancel() + var ret result[V] - v, expiresAt, err := fn(c.ctx, key) - ret := result[V]{ - val: v, - expiresAt: expiresAt, - err: err, - } + normalReturn := false + recovered := false - // save to the cache - e.mu.Lock() - e.call = nil // to avoid adding new channels to c.chans - chans := c.chans - if err == nil { - e.val = v - e.expiresAt = expiresAt - } - e.mu.Unlock() + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + ret.err = errGoexit + } + + // save to the cache + e.mu.Lock() + e.call = nil // to avoid adding new channels to c.chans + chans := c.chans + if ret.err == nil { + e.val = ret.val + e.expiresAt = ret.expiresAt + } + e.mu.Unlock() + + if e, ok := ret.err.(*panicError); ok { + panic(e) + } else if ret.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + // notify the result to the callers + for _, ch := range chans { + ch <- ret + } + } + }() + + func() { + defer func() { + c.cancel() + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + ret.err = newPanicError(r) + } + } + }() + + ret.val, ret.expiresAt, ret.err = fn(c.ctx, key) + normalReturn = true + }() - // notify the result to the callers - for _, ch := range chans { - ch <- ret + if !normalReturn { + recovered = true } } diff --git a/memoize_test.go b/memoize_test.go index 1450c84..4568862 100644 --- a/memoize_test.go +++ b/memoize_test.go @@ -1,8 +1,13 @@ package memoize import ( + "bytes" "context" "errors" + "os" + "os/exec" + "runtime" + "strings" "sync" "sync/atomic" "testing" @@ -244,6 +249,44 @@ func TestDoContext(t *testing.T) { } } +func TestPanicDo(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DO") != "" { + var g Group[string, int] + fn := func(ctx context.Context, _ string) (int, time.Time, error) { + panic("Panicking in Do") + } + g.Do(context.Background(), "key", fn) + t.Fatalf("Do unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DO=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") + } + if bytes.Contains(out.Bytes(), []byte("Do unexpectedly returned")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + } +} + func benchmarkDo(parallelism int) func(b *testing.B) { return func(b *testing.B) { b.SetParallelism(parallelism)