diff --git a/spanner/google/cloud/spanner/database.py b/spanner/google/cloud/spanner/database.py index b098f7684b7c..40dcc471d1c4 100644 --- a/spanner/google/cloud/spanner/database.py +++ b/spanner/google/cloud/spanner/database.py @@ -15,6 +15,7 @@ """User friendly container for Cloud Spanner Database.""" import re +import threading import google.auth.credentials from google.gax.errors import GaxError @@ -79,6 +80,7 @@ def __init__(self, database_id, instance, ddl_statements=(), pool=None): self.database_id = database_id self._instance = instance self._ddl_statements = _check_ddl_statements(ddl_statements) + self._local = threading.local() if pool is None: pool = BurstyPool() @@ -332,8 +334,20 @@ def run_in_transaction(self, func, *args, **kw): :rtype: :class:`datetime.datetime` :returns: timestamp of committed transaction """ - with SessionCheckout(self._pool) as session: - return session.run_in_transaction(func, *args, **kw) + # Sanity check: Is there a transaction already running? + # If there is, then raise a red flag. Otherwise, mark that this one + # is running. + if getattr(self._local, 'transaction_running', False): + raise RuntimeError('Spanner does not support nested transactions.') + self._local.transaction_running = True + + # Check out a session and run the function in a transaction; once + # done, flip the sanity check bit back. + try: + with SessionCheckout(self._pool) as session: + return session.run_in_transaction(func, *args, **kw) + finally: + self._local.transaction_running = False def batch(self): """Return an object which wraps a batch. diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index c1218599b3b3..c812176499dd 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -223,7 +223,7 @@ def __init__(self, scopes=(), source=None): self._scopes = scopes self._source = source - def requires_scopes(self): # pragma: NO COVER + def requires_scopes(self): # pragma: NO COVER return True def with_scopes(self, scopes): @@ -663,6 +663,29 @@ def test_run_in_transaction_w_args(self): self.assertEqual(session._retried, (_unit_of_work, (SINCE,), {'until': UNTIL})) + def test_run_in_transaction_nested(self): + from datetime import datetime + + # Perform the various setup tasks. + instance = _Instance(self.INSTANCE_NAME, client=_Client()) + pool = _Pool() + session = _Session(run_transaction_function=True) + session._committed = datetime.now() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + # Define the inner function. + inner = mock.Mock(spec=()) + + # Define the nested transaction. + def nested_unit_of_work(): + return database.run_in_transaction(inner) + + # Attempting to run this transaction should raise RuntimeError. + with self.assertRaises(RuntimeError): + database.run_in_transaction(nested_unit_of_work) + self.assertEqual(inner.call_count, 0) + def test_batch(self): from google.cloud.spanner.database import BatchCheckout @@ -900,11 +923,15 @@ class _Session(object): _rows = () - def __init__(self, database=None, name=_BaseTest.SESSION_NAME): + def __init__(self, database=None, name=_BaseTest.SESSION_NAME, + run_transaction_function=False): self._database = database self.name = name + self._run_transaction_function = run_transaction_function def run_in_transaction(self, func, *args, **kw): + if self._run_transaction_function: + func(*args, **kw) self._retried = (func, args, kw) return self._committed