Skip to content

Commit

Permalink
Stow copy error shouldn't wrap nil err (flyteorg#47)
Browse files Browse the repository at this point in the history
* Stow copy error shouldn't wrap nil err

* Add unit tests
  • Loading branch information
EngHabu authored Oct 23, 2019
1 parent f7e4d7a commit 51f5fd1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
30 changes: 19 additions & 11 deletions errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
package errors

import (
"errors"
"fmt"

"github.com/pkg/errors"
)

// A generic error code type.
Expand All @@ -25,12 +24,12 @@ func (e *err) Code() ErrorCode {

// Overrides Is to check for error code only. This enables the default package's errors.Is().
func (e *err) Is(target error) bool {
t, ok := target.(*err)
if !ok {
eCode, found := GetErrorCode(target)
if !found {
return false
}

return e.Code() == t.Code()
return e.Code() == eCode
}

type errorWithCause struct {
Expand All @@ -39,7 +38,7 @@ type errorWithCause struct {
}

func (e *errorWithCause) Error() string {
return fmt.Sprintf("%v, caused by: %v", e.err.Error(), errors.Cause(e))
return fmt.Sprintf("%v, caused by: %v", e.err.Error(), e.Cause())
}

func (e *errorWithCause) Cause() error {
Expand Down Expand Up @@ -90,17 +89,26 @@ func IsCausedBy(e error, errCode ErrorCode) bool {
Cause() error
}

type wrapped interface {
Unwrap() error
}

for e != nil {
if code, found := GetErrorCode(e); found && code == errCode {
return true
}

cause, ok := e.(causer)
if !ok {
break
if ok {
e = cause.Cause()
} else {
cause, ok := e.(wrapped)
if !ok {
break
}

e = cause.Unwrap()
}

e = cause.Cause()
}

return false
Expand All @@ -112,7 +120,7 @@ func IsCausedByError(e, e2 error) bool {
}

for e != nil {
if e == e2 {
if errors.Is(e, e2) {
return true
}

Expand Down
12 changes: 12 additions & 0 deletions errors/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,28 @@ func TestIsCausedBy(t *testing.T) {
e = Wrapf("Code2", e, "msg")
assert.True(t, IsCausedBy(e, "Code1"))
assert.True(t, IsCausedBy(e, "Code2"))

e = fmt.Errorf("new err caused by: %w", e)
assert.True(t, IsCausedBy(e, "Code1"))

e = fmt.Errorf("not sharing code err")
assert.False(t, IsCausedBy(e, "Code1"))
}

func TestIsCausedByError(t *testing.T) {
eRoot := Errorf("Code1", "msg")
assert.NotNil(t, eRoot)

e1 := Wrapf("Code2", eRoot, "msg")
assert.True(t, IsCausedByError(e1, eRoot))

e2 := Wrapf("Code3", e1, "msg")
assert.True(t, IsCausedByError(e2, eRoot))
assert.True(t, IsCausedByError(e2, e1))

e3 := fmt.Errorf("default errors. caused by: %w", e2)
assert.True(t, IsCausedByError(e3, eRoot))
assert.True(t, IsCausedByError(e3, e1))
}

func TestErrorsIs(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions storage/stow_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.Re
return nil, err
}

if sizeBytes/MiB > GetConfig().Limits.GetLimitMegabytes {
return nil, errors.Wrapf(ErrExceedsLimit, err, "limit exceeded")
if sizeMbs := sizeBytes / MiB; sizeMbs > GetConfig().Limits.GetLimitMegabytes {
return nil, errors.Errorf(ErrExceedsLimit, "limit exceeded. %vmb > %vmb.", sizeMbs, GetConfig().Limits.GetLimitMegabytes)
}

return item.Open()
Expand Down
3 changes: 3 additions & 0 deletions storage/stow_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/pkg/errors"

"github.com/lyft/flytestdlib/promutils"

"github.com/graymeta/stow"
Expand Down Expand Up @@ -131,5 +133,6 @@ func TestStowStore_ReadRaw(t *testing.T) {
_, err = s.ReadRaw(context.TODO(), DataReference("s3://container/path"))
assert.Error(t, err)
assert.True(t, IsExceedsLimit(err))
assert.NotNil(t, errors.Cause(err))
})
}

0 comments on commit 51f5fd1

Please sign in to comment.