diff --git a/wrap.go b/wrap.go index 2462d3d..3e71058 100644 --- a/wrap.go +++ b/wrap.go @@ -79,14 +79,15 @@ func As(err error, target interface{}) bool { if target == nil { panic("errors: target cannot be nil") } - typ := reflect.TypeOf(target) - if typ.Kind() != reflect.Ptr { - panic("errors: target must be a pointer") + val := reflect.ValueOf(target) + typ := val.Type() + if typ.Kind() != reflect.Ptr || val.IsNil() { + panic("errors: target must be a non-nil pointer") } targetType := typ.Elem() for { if reflect.TypeOf(err).AssignableTo(targetType) { - reflect.ValueOf(target).Elem().Set(reflect.ValueOf(err)) + val.Elem().Set(reflect.ValueOf(err)) return true } if x, ok := err.(interface{ As(interface{}) bool }); ok && x.As(target) { diff --git a/wrap_test.go b/wrap_test.go index 2380190..ef6a9ed 100644 --- a/wrap_test.go +++ b/wrap_test.go @@ -150,6 +150,27 @@ func TestAs(t *testing.T) { } } +func TestAsValidation(t *testing.T) { + testCases := []interface{}{ + nil, + (*int)(nil), + "error", + } + err := xerrors.New("error") + for _, tc := range testCases { + t.Run(fmt.Sprintf("%T(%v)", tc, tc), func(t *testing.T) { + defer func() { + recover() + }() + if xerrors.As(err, tc) { + t.Errorf("As(err, %T(%v)) = true, want false", tc, tc) + return + } + t.Errorf("As(err, %T(%v)) did not panic", tc, tc) + }) + } +} + func TestUnwrap(t *testing.T) { err1 := xerrors.New("1") erra := xerrors.Errorf("wrap 2: %w", err1)