Skip to content

Commit

Permalink
PYTHON-4660 Fix AttributeError when MongoClient.bulk_write batch fail…
Browse files Browse the repository at this point in the history
…s with InvalidBSON (#1792)
  • Loading branch information
shruti-sridhar authored Aug 15, 2024
1 parent adf8817 commit 297dfe6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ async def _execute_command(
if result.get("error"):
error = result["error"]
retryable_top_level_error = (
isinstance(error.details, dict)
hasattr(error, "details")
and isinstance(error.details, dict)
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
)
retryable_network_error = isinstance(
Expand Down
4 changes: 3 additions & 1 deletion pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,7 +2562,9 @@ async def run(self) -> T:
if not self._retryable:
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError")
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ def _execute_command(
if result.get("error"):
error = result["error"]
retryable_top_level_error = (
isinstance(error.details, dict)
hasattr(error, "details")
and isinstance(error.details, dict)
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
)
retryable_network_error = isinstance(
Expand Down
4 changes: 3 additions & 1 deletion pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,7 +2549,9 @@ def run(self) -> T:
if not self._retryable:
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = exc.error.has_error_label("RetryableWriteError")
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
Expand Down
22 changes: 21 additions & 1 deletion test/asynchronous/test_client_bulk_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
unittest,
)
from test.utils import (
OvertCommandListener,
async_rs_or_single_client,
)
from unittest.mock import patch

from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
ClientBulkWriteException,
Expand Down Expand Up @@ -53,6 +59,20 @@ async def test_returns_error_if_no_namespace_provided(self):
context.exception._message,
)

@async_client_context.require_version_min(8, 0, 0, -24)
async def test_handles_non_pymongo_error(self):
with patch.object(
_AsyncClientBulk, "write_command", return_value={"error": TypeError("mock type error")}
):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.close)

models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))


# https://github.com/mongodb/specifications/tree/master/source/crud/tests
class TestClientBulkWriteCRUD(AsyncIntegrationTest):
Expand Down
22 changes: 21 additions & 1 deletion test/test_client_bulk_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@

sys.path[0:0] = [""]

from test import IntegrationTest, client_context, unittest
from test import (
IntegrationTest,
client_context,
unittest,
)
from test.utils import (
OvertCommandListener,
rs_or_single_client,
)
from unittest.mock import patch

from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
Expand All @@ -34,6 +39,7 @@
)
from pymongo.monitoring import *
from pymongo.operations import *
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.write_concern import WriteConcern

_IS_SYNC = True
Expand All @@ -53,6 +59,20 @@ def test_returns_error_if_no_namespace_provided(self):
context.exception._message,
)

@client_context.require_version_min(8, 0, 0, -24)
def test_handles_non_pymongo_error(self):
with patch.object(
_ClientBulk, "write_command", return_value={"error": TypeError("mock type error")}
):
client = rs_or_single_client()
self.addCleanup(client.close)

models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(ClientBulkWriteException) as context:
client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, TypeError)
self.assertFalse(hasattr(context.exception.error, "details"))


# https://github.com/mongodb/specifications/tree/master/source/crud/tests
class TestClientBulkWriteCRUD(IntegrationTest):
Expand Down

0 comments on commit 297dfe6

Please sign in to comment.