diff --git a/mongo/errors.go b/mongo/errors.go index 5267621274..6123635f67 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -106,8 +106,8 @@ func IsTimeout(err error) bool { return ne.Timeout() } //timeout error labels - if se, ok := err.(ServerError); ok { - if se.HasErrorLabel("NetworkTimeoutError") || se.HasErrorLabel("ExceededTimeLimitError") { + if le, ok := err.(labeledError); ok { + if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { return true } } @@ -130,8 +130,8 @@ func unwrap(err error) error { // errorHasLabel returns true if err contains the specified label func errorHasLabel(err error, label string) bool { for ; err != nil; err = unwrap(err) { - if e, ok := err.(ServerError); ok { - return e.HasErrorLabel(label) + if le, ok := err.(labeledError); ok && le.HasErrorLabel(label) { + return true } } return false @@ -184,6 +184,12 @@ func (e MongocryptdError) Unwrap() error { return e.Wrapped } +type labeledError interface { + error + // HasErrorLabel returns true if the error contains the specified label. + HasErrorLabel(string) bool +} + // ServerError is the interface implemented by errors returned from the server. Custom implementations of this // interface should not be used in production. type ServerError interface { diff --git a/mongo/integration/sdam_error_handling_test.go b/mongo/integration/sdam_error_handling_test.go index d282be6baf..3ae4ebf6f2 100644 --- a/mongo/integration/sdam_error_handling_test.go +++ b/mongo/integration/sdam_error_handling_test.go @@ -85,6 +85,7 @@ func TestSDAMErrorHandling(t *testing.T) { _, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}}) assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) + assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not") }) mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {