Skip to content

Commit

Permalink
Unbind transaction from session on commit/rollback. (googleapis#3669)
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver authored and landrito committed Aug 21, 2017
1 parent 5db3a82 commit 0d2ea02
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 0 additions & 2 deletions spanner/google/cloud/spanner/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ def run_in_transaction(self, func, *args, **kw):
continue
except Exception:
txn.rollback()
del self._transaction
raise

try:
Expand All @@ -312,7 +311,6 @@ def run_in_transaction(self, func, *args, **kw):
del self._transaction
else:
committed = txn.committed
del self._transaction
return committed


Expand Down
2 changes: 2 additions & 0 deletions spanner/google/cloud/spanner/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def rollback(self):
options = _options_with_prefix(database.name)
api.rollback(self._session.name, self._id, options=options)
self._rolled_back = True
del self._session._transaction

def commit(self):
"""Commit mutations to the database.
Expand All @@ -114,6 +115,7 @@ def commit(self):
transaction_id=self._id, options=options)
self.committed = _pb_timestamp_to_datetime(
response.commit_timestamp)
del self._session._transaction
return self.committed

def __enter__(self):
Expand Down
10 changes: 8 additions & 2 deletions spanner/tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def _getTargetClass(self):

return Transaction

def _make_one(self, *args, **kwargs):
return self._getTargetClass()(*args, **kwargs)
def _make_one(self, session, *args, **kwargs):
transaction = self._getTargetClass()(session, *args, **kwargs)
session._transaction = transaction
return transaction

def test_ctor_defaults(self):
session = _Session()
Expand Down Expand Up @@ -208,6 +210,7 @@ def test_rollback_ok(self):
transaction.rollback()

self.assertTrue(transaction._rolled_back)
self.assertIsNone(session._transaction)

session_id, txn_id, options = api._rolled_back
self.assertEqual(session_id, session.name)
Expand Down Expand Up @@ -290,6 +293,7 @@ def test_commit_ok(self):
transaction.commit()

self.assertEqual(transaction.committed, now)
self.assertIsNone(session._transaction)

session_id, mutations, txn_id, options = api._committed
self.assertEqual(session_id, session.name)
Expand Down Expand Up @@ -368,6 +372,8 @@ class _Database(object):

class _Session(object):

_transaction = None

def __init__(self, database=None, name=TestTransaction.SESSION_NAME):
self._database = database
self.name = name
Expand Down

0 comments on commit 0d2ea02

Please sign in to comment.