From 297dfe6aa3d31aa23e081a79d56e11ffd45dd952 Mon Sep 17 00:00:00 2001 From: Shruti Sridhar <77828382+shruti-sridhar@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:13:00 -0700 Subject: [PATCH] PYTHON-4660 Fix AttributeError when MongoClient.bulk_write batch fails with InvalidBSON (#1792) --- pymongo/asynchronous/client_bulk.py | 3 ++- pymongo/asynchronous/mongo_client.py | 4 +++- pymongo/synchronous/client_bulk.py | 3 ++- pymongo/synchronous/mongo_client.py | 4 +++- test/asynchronous/test_client_bulk_write.py | 22 ++++++++++++++++++++- test/test_client_bulk_write.py | 22 ++++++++++++++++++++- 6 files changed, 52 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 671d989c25..1f3cca2f6c 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -550,7 +550,8 @@ async def _execute_command( if result.get("error"): error = result["error"] retryable_top_level_error = ( - isinstance(error.details, dict) + hasattr(error, "details") + and isinstance(error.details, dict) and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES ) retryable_network_error = isinstance( diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index fbbd9a4eed..8848fa4fd5 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2562,7 +2562,9 @@ async def run(self) -> T: if not self._retryable: raise if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError") + retryable_write_error_exc = isinstance( + exc.error, PyMongoError + ) and exc.error.has_error_label("RetryableWriteError") else: retryable_write_error_exc = exc.has_error_label("RetryableWriteError") if retryable_write_error_exc: diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 229abd4330..5f969804d5 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -548,7 +548,8 @@ def _execute_command( if result.get("error"): error = result["error"] retryable_top_level_error = ( - isinstance(error.details, dict) + hasattr(error, "details") + and isinstance(error.details, dict) and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES ) retryable_network_error = isinstance( diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1863165625..4aff3b5eed 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2549,7 +2549,9 @@ def run(self) -> T: if not self._retryable: raise if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError") + retryable_write_error_exc = isinstance( + exc.error, PyMongoError + ) and exc.error.has_error_label("RetryableWriteError") else: retryable_write_error_exc = exc.has_error_label("RetryableWriteError") if retryable_write_error_exc: diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index eea0b4e8e0..20e6ab7c95 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -19,12 +19,18 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + unittest, +) from test.utils import ( OvertCommandListener, async_rs_or_single_client, ) +from unittest.mock import patch +from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( ClientBulkWriteException, @@ -53,6 +59,20 @@ async def test_returns_error_if_no_namespace_provided(self): context.exception._message, ) + @async_client_context.require_version_min(8, 0, 0, -24) + async def test_handles_non_pymongo_error(self): + with patch.object( + _AsyncClientBulk, "write_command", return_value={"error": TypeError("mock type error")} + ): + client = await async_rs_or_single_client() + self.addAsyncCleanup(client.close) + + models = [InsertOne(namespace="db.coll", document={"a": "b"})] + with self.assertRaises(ClientBulkWriteException) as context: + await client.bulk_write(models=models) + self.assertIsInstance(context.exception.error, TypeError) + self.assertFalse(hasattr(context.exception.error, "details")) + # https://github.com/mongodb/specifications/tree/master/source/crud/tests class TestClientBulkWriteCRUD(AsyncIntegrationTest): diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index 8f6aad0cfa..686b60642a 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -19,11 +19,16 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + unittest, +) from test.utils import ( OvertCommandListener, rs_or_single_client, ) +from unittest.mock import patch from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( @@ -34,6 +39,7 @@ ) from pymongo.monitoring import * from pymongo.operations import * +from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -53,6 +59,20 @@ def test_returns_error_if_no_namespace_provided(self): context.exception._message, ) + @client_context.require_version_min(8, 0, 0, -24) + def test_handles_non_pymongo_error(self): + with patch.object( + _ClientBulk, "write_command", return_value={"error": TypeError("mock type error")} + ): + client = rs_or_single_client() + self.addCleanup(client.close) + + models = [InsertOne(namespace="db.coll", document={"a": "b"})] + with self.assertRaises(ClientBulkWriteException) as context: + client.bulk_write(models=models) + self.assertIsInstance(context.exception.error, TypeError) + self.assertFalse(hasattr(context.exception.error, "details")) + # https://github.com/mongodb/specifications/tree/master/source/crud/tests class TestClientBulkWriteCRUD(IntegrationTest):