diff --git a/wrap.go b/wrap.go index eee732b..2462d3d 100644 --- a/wrap.go +++ b/wrap.go @@ -68,9 +68,10 @@ func Is(err, target error) bool { } // As finds the first error in err's chain that matches the type to which target -// points, and if so, sets the target to its value and and returns true. An error -// matches a type if it is of the same type, or if it has a method As(interface{}) bool -// such that As(target) returns true. As will panic if target is nil or not a pointer. +// points, and if so, sets the target to its value and returns true. An error +// matches a type if it is assignable to the target type, or if it has a method +// As(interface{}) bool such that As(target) returns true. As will panic if target +// is nil or not a pointer. // // The As method should set the target to its value and return true if err // matches the type to which target points. @@ -84,7 +85,7 @@ func As(err error, target interface{}) bool { } targetType := typ.Elem() for { - if reflect.TypeOf(err) == targetType { + if reflect.TypeOf(err).AssignableTo(targetType) { reflect.ValueOf(target).Elem().Set(reflect.ValueOf(err)) return true } diff --git a/wrap_test.go b/wrap_test.go index 677b209..2380190 100644 --- a/wrap_test.go +++ b/wrap_test.go @@ -80,6 +80,7 @@ func (p *poser) As(err interface{}) bool { func TestAs(t *testing.T) { var errT errorT var errP *os.PathError + var timeout interface{ Timeout() bool } var p *poser _, errF := os.Open("non-existing") @@ -120,16 +121,24 @@ func TestAs(t *testing.T) { &p, true, }, { - &poser{"oo", nil}, - &errF, + xerrors.New("err"), + &timeout, false, + }, { + errF, + &timeout, + true, + }, { + xerrors.Errorf("path error: %w", errF), + &timeout, + true, }} - for _, tc := range testCases { - name := fmt.Sprintf("As(Errorf(..., %v), %v)", tc.err, tc.target) + for i, tc := range testCases { + name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target) t.Run(name, func(t *testing.T) { match := xerrors.As(tc.err, tc.target) if match != tc.match { - t.Fatalf("match: got %v; want %v", match, tc.match) + t.Fatalf("xerrors.As(%T, %T): got %v; want %v", tc.err, tc.target, match, tc.match) } if !match { return