Skip to content

Commit

Permalink
Merge pull request #6 from shogo82148/add-panic-error-unwrap-method
Browse files Browse the repository at this point in the history
add panicError.Unwrap method
  • Loading branch information
shogo82148 authored Oct 6, 2023
2 parents 7590b42 + fe37c61 commit 1ac0642
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
9 changes: 9 additions & 0 deletions memoize.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ func (p *panicError) Error() string {
return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
}

func (p *panicError) Unwrap() error {
err, ok := p.value.(error)
if !ok {
return nil
}

return err
}

func newPanicError(v interface{}) error {
stack := debug.Stack()

Expand Down
58 changes: 58 additions & 0 deletions memoize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,64 @@ func TestGoexitDo(t *testing.T) {
}
}

type errValue struct{}

func (err *errValue) Error() string {
return "error value"
}

func TestPanicErrorUnwrap(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
panicValue interface{}
wrappedErrorType bool
}{
{
name: "panicError wraps non-error type",
panicValue: &panicError{value: "string value"},
wrappedErrorType: false,
},
{
name: "panicError wraps error type",
panicValue: &panicError{value: new(errValue)},
wrappedErrorType: true,
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
var recovered interface{}
group := new(Group[string, int])
func() {
defer func() {
recovered = recover()
t.Logf("after panic(%#v) in group.Do, recovered %#v", tc.panicValue, recovered)
}()
_, _, _ = group.Do(context.Background(), "key", func(ctx context.Context, _ string) (int, time.Time, error) {
panic(tc.panicValue)
})
}()

if recovered == nil {
t.Fatal("expected a non-nil panic value")
}

err, ok := recovered.(error)
if !ok {
t.Fatalf("expected panic value to be an error, got %T", recovered)
}

if !errors.Is(err, new(errValue)) && tc.wrappedErrorType {
t.Errorf("unexpected wrapped error type %T; want %T", err, new(errValue))
}
})
}
}

func benchmarkDo(parallelism int) func(b *testing.B) {
return func(b *testing.B) {
b.SetParallelism(parallelism)
Expand Down

0 comments on commit 1ac0642

Please sign in to comment.