diff --git a/api/errors.go b/api/errors.go index 93f69cf5d..bd4c27884 100644 --- a/api/errors.go +++ b/api/errors.go @@ -77,6 +77,13 @@ func NewErrorDeadlineExceeded(msg string) error { // 1. Failed to put the blob in the disperser's queue (disperser is down) // 2. Timed out before getting confirmed onchain (batcher is down) // 3. Insufficient signatures (eigenda network is down) +// +// One can check if an error is an ErrorFailover by using errors.Is: +// +// failoverErr := NewErrorFailover(someOtherErr) +// if !errors.Is(wrappedFailoverErr, &ErrorFailover{}) { +// // do something... +// } type ErrorFailover struct { Err error } @@ -90,9 +97,28 @@ func NewErrorFailover(err error) *ErrorFailover { } func (e *ErrorFailover) Error() string { + if e.Err == nil { + return "Failover" + } return fmt.Sprintf("Failover: %s", e.Err.Error()) } func (e *ErrorFailover) Unwrap() error { return e.Err } + +// Is only checks the type of the error, not the underlying error. +// This is because we want to be able to check that an error is an ErrorFailover, +// even when wrapped. This can now be done with errors.Is. +// +// baseErr := fmt.Errorf("some error") +// failoverErr := NewErrorFailover(baseErr) +// wrappedFailoverErr := fmt.Errorf("some extra context: %w", failoverErr) +// +// if !errors.Is(wrappedFailoverErr, &ErrorFailover{}) { +// // do something... +// } +func (e *ErrorFailover) Is(target error) bool { + _, ok := target.(*ErrorFailover) + return ok +} diff --git a/api/errors_test.go b/api/errors_test.go new file mode 100644 index 000000000..f630bdc39 --- /dev/null +++ b/api/errors_test.go @@ -0,0 +1,46 @@ +package api + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrorFailoverErrorsIs(t *testing.T) { + baseErr := fmt.Errorf("base error") + failoverErr := NewErrorFailover(baseErr) + otherFailoverErr := NewErrorFailover(fmt.Errorf("some other error")) + wrappedFailoverErr := fmt.Errorf("wrapped: %w", failoverErr) + + if !errors.Is(failoverErr, failoverErr) { + t.Error("should match itself") + } + + if !errors.Is(failoverErr, baseErr) { + t.Error("should match base error") + } + + if errors.Is(failoverErr, fmt.Errorf("some other error")) { + t.Error("should not match other errors") + } + + if !errors.Is(failoverErr, otherFailoverErr) { + t.Error("should match other failover error") + } + + if !errors.Is(failoverErr, &ErrorFailover{}) { + t.Error("should match ErrorFailover type") + } + + if !errors.Is(wrappedFailoverErr, &ErrorFailover{}) { + t.Error("should match ErrorFailover type even when wrapped") + } + +} + +func TestErrorFailoverZeroValue(t *testing.T) { + var failoverErr ErrorFailover + if failoverErr.Error() != "Failover" { + t.Error("should return 'Failover' for zero value") + } +}