diff --git a/memoize.go b/memoize.go index 57b9591..87692ee 100644 --- a/memoize.go +++ b/memoize.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "runtime" "runtime/debug" "sync" "time" @@ -125,6 +126,12 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context select { case ret := <-ch: + if e, ok := ret.Err.(*panicError); ok { + panic(e) + } + if ret.Err == errGoexit { + runtime.Goexit() + } return ret.Val, ret.ExpiresAt, ret.Err case <-ctx.Done(): e.mu.Lock() @@ -223,16 +230,10 @@ func do[K comparable, V any](g *Group[K, V], e *entry[V], c *call[V], key K, fn chans := c.chans e.mu.Unlock() - if e, ok := ret.Err.(*panicError); ok { - panic(e) - } else if ret.Err == errGoexit { - // Already in the process of goexit, no need to call again - } else { - // Normal return - // notify the result to the callers - for _, ch := range chans { - ch <- ret - } + // Normal return + // notify the result to the callers + for _, ch := range chans { + ch <- ret } }() diff --git a/memoize_test.go b/memoize_test.go index 7a33caa..74b5e7d 100644 --- a/memoize_test.go +++ b/memoize_test.go @@ -1,13 +1,10 @@ package memoize import ( - "bytes" "context" "errors" - "os" - "os/exec" "runtime" - "strings" + "runtime/debug" "sync" "sync/atomic" "testing" @@ -250,40 +247,69 @@ func TestDoContext(t *testing.T) { } func TestPanicDo(t *testing.T) { - if runtime.GOOS == "js" { - t.Skipf("js does not support exec") + var g Group[string, int] + fn := func(ctx context.Context, _ string) (int, time.Time, error) { + panic("something wrong!!") } - if os.Getenv("TEST_PANIC_DO") != "" { - var g Group[string, int] - fn := func(ctx context.Context, _ string) (int, time.Time, error) { - panic("Panicking in Do") - } - g.Do(context.Background(), "key", fn) - t.Fatalf("Do unexpectedly returned") + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + g.Do(context.Background(), "key", fn) + }() } - t.Parallel() - - cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") - cmd.Env = append(os.Environ(), "TEST_PANIC_DO=1") - out := new(bytes.Buffer) - cmd.Stdout = out - cmd.Stderr = out - if err := cmd.Start(); err != nil { - t.Fatal(err) + select { + case <-done: + if panicCount != n { + t.Errorf("panic count = %d; want %d", panicCount, n) + } + case <-time.After(time.Second): + t.Errorf("Do hangs") } +} - err := cmd.Wait() - t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) - if err == nil { - t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") +func TestGoexitDo(t *testing.T) { + var g Group[string, int] + fn := func(ctx context.Context, _ string) (int, time.Time, error) { + runtime.Goexit() + return 0, time.Time{}, nil } - if bytes.Contains(out.Bytes(), []byte("Do unexpectedly returned")) { - t.Errorf("Test subprocess failed with an unexpected failure mode.") + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, _, err = g.Do(context.Background(), "key", fn) + }() } - if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { - t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + + select { + case <-done: + case <-time.After(time.Second): + t.Errorf("Do hangs") } }