diff --git a/errors/error.go b/errors/error.go index 62f092f..fb646de 100644 --- a/errors/error.go +++ b/errors/error.go @@ -2,9 +2,8 @@ package errors import ( + "errors" "fmt" - - "github.com/pkg/errors" ) // A generic error code type. @@ -25,12 +24,12 @@ func (e *err) Code() ErrorCode { // Overrides Is to check for error code only. This enables the default package's errors.Is(). func (e *err) Is(target error) bool { - t, ok := target.(*err) - if !ok { + eCode, found := GetErrorCode(target) + if !found { return false } - return e.Code() == t.Code() + return e.Code() == eCode } type errorWithCause struct { @@ -39,7 +38,7 @@ type errorWithCause struct { } func (e *errorWithCause) Error() string { - return fmt.Sprintf("%v, caused by: %v", e.err.Error(), errors.Cause(e)) + return fmt.Sprintf("%v, caused by: %v", e.err.Error(), e.Cause()) } func (e *errorWithCause) Cause() error { @@ -90,17 +89,26 @@ func IsCausedBy(e error, errCode ErrorCode) bool { Cause() error } + type wrapped interface { + Unwrap() error + } + for e != nil { if code, found := GetErrorCode(e); found && code == errCode { return true } cause, ok := e.(causer) - if !ok { - break + if ok { + e = cause.Cause() + } else { + cause, ok := e.(wrapped) + if !ok { + break + } + + e = cause.Unwrap() } - - e = cause.Cause() } return false @@ -112,7 +120,7 @@ func IsCausedByError(e, e2 error) bool { } for e != nil { - if e == e2 { + if errors.Is(e, e2) { return true } diff --git a/errors/error_test.go b/errors/error_test.go index c7b8bd9..e22ce2f 100644 --- a/errors/error_test.go +++ b/errors/error_test.go @@ -35,16 +35,28 @@ func TestIsCausedBy(t *testing.T) { e = Wrapf("Code2", e, "msg") assert.True(t, IsCausedBy(e, "Code1")) assert.True(t, IsCausedBy(e, "Code2")) + + e = fmt.Errorf("new err caused by: %w", e) + assert.True(t, IsCausedBy(e, "Code1")) + + e = fmt.Errorf("not sharing code err") + assert.False(t, IsCausedBy(e, "Code1")) } func TestIsCausedByError(t *testing.T) { eRoot := Errorf("Code1", "msg") assert.NotNil(t, eRoot) + e1 := Wrapf("Code2", eRoot, "msg") assert.True(t, IsCausedByError(e1, eRoot)) + e2 := Wrapf("Code3", e1, "msg") assert.True(t, IsCausedByError(e2, eRoot)) assert.True(t, IsCausedByError(e2, e1)) + + e3 := fmt.Errorf("default errors. caused by: %w", e2) + assert.True(t, IsCausedByError(e3, eRoot)) + assert.True(t, IsCausedByError(e3, e1)) } func TestErrorsIs(t *testing.T) { diff --git a/storage/stow_store.go b/storage/stow_store.go index 998fc8a..2d7dc27 100644 --- a/storage/stow_store.go +++ b/storage/stow_store.go @@ -117,8 +117,8 @@ func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.Re return nil, err } - if sizeBytes/MiB > GetConfig().Limits.GetLimitMegabytes { - return nil, errors.Wrapf(ErrExceedsLimit, err, "limit exceeded") + if sizeMbs := sizeBytes / MiB; sizeMbs > GetConfig().Limits.GetLimitMegabytes { + return nil, errors.Errorf(ErrExceedsLimit, "limit exceeded. %vmb > %vmb.", sizeMbs, GetConfig().Limits.GetLimitMegabytes) } return item.Open() diff --git a/storage/stow_store_test.go b/storage/stow_store_test.go index bb36827..710e410 100644 --- a/storage/stow_store_test.go +++ b/storage/stow_store_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/pkg/errors" + "github.com/lyft/flytestdlib/promutils" "github.com/graymeta/stow" @@ -131,5 +133,6 @@ func TestStowStore_ReadRaw(t *testing.T) { _, err = s.ReadRaw(context.TODO(), DataReference("s3://container/path")) assert.Error(t, err) assert.True(t, IsExceedsLimit(err)) + assert.NotNil(t, errors.Cause(err)) }) }