diff --git a/provider_errors.go b/provider_errors.go index 79a2cda2b..718bcbec8 100644 --- a/provider_errors.go +++ b/provider_errors.go @@ -42,3 +42,7 @@ func (e *PartialError) Error() string { e.Failed.Source.Type, e.Failed.Source.Version, e.Err, ) } + +func (e *PartialError) Unwrap() error { + return e.Err +} diff --git a/provider_test.go b/provider_test.go index 1b201b413..82676043e 100644 --- a/provider_test.go +++ b/provider_test.go @@ -35,7 +35,6 @@ func TestProvider(t *testing.T) { check.Equal(t, len(sources), 2) check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1)) check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2)) - } var ( @@ -76,3 +75,10 @@ ALTER TABLE my_foo DROP COLUMN timestamp; ALTER TABLE my_foo RENAME TO foo; ` ) + +func TestPartialErrorUnwrap(t *testing.T) { + err := &goose.PartialError{Err: goose.ErrNoCurrentVersion} + + got := errors.Is(err, goose.ErrNoCurrentVersion) + check.Bool(t, got, true) +}