Skip to content

Commit

Permalink
fix: add NOT_FOUND error check in __exit__ method of SessionCheckout. (
Browse files Browse the repository at this point in the history
…#718)

* fix: Inside SnapshotCheckout __exit__ block check if NotFound exception was raised for the session and create new session if needed

* test: add test for SnapshotCheckout __exit__ checks

* refactor: lint fixes

* test: add test case for NotFound Error in SessionCheckout context but unrelated to Sessions
  • Loading branch information
vi3k6i5 authored Apr 20, 2022
1 parent 7642eba commit 265e207
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,12 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
self._database._pool.put(self._session)


Expand Down
61 changes: 60 additions & 1 deletion tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import mock
from google.api_core import gapic_v1

from google.cloud.spanner_v1.param_types import INT64
from google.api_core.retry import Retry

Expand Down Expand Up @@ -1792,6 +1791,66 @@ class Testing(Exception):

self.assertIs(pool._session, session)

def test_context_mgr_session_not_found_error(self):
from google.cloud.exceptions import NotFound

database = _Database(self.DATABASE_NAME)
session = _Session(database, name="session-1")
session.exists = mock.MagicMock(return_value=False)
pool = database._pool = _Pool()
new_session = _Session(database, name="session-2")
new_session.create = mock.MagicMock(return_value=[])
pool._new_session = mock.MagicMock(return_value=new_session)

pool.put(session)
checkout = self._make_one(database)

self.assertEqual(pool._session, session)
with self.assertRaises(NotFound):
with checkout as _:
raise NotFound("Session not found")
# Assert that session-1 was removed from pool and new session was added.
self.assertEqual(pool._session, new_session)

def test_context_mgr_table_not_found_error(self):
from google.cloud.exceptions import NotFound

database = _Database(self.DATABASE_NAME)
session = _Session(database, name="session-1")
session.exists = mock.MagicMock(return_value=True)
pool = database._pool = _Pool()
pool._new_session = mock.MagicMock(return_value=[])

pool.put(session)
checkout = self._make_one(database)

self.assertEqual(pool._session, session)
with self.assertRaises(NotFound):
with checkout as _:
raise NotFound("Table not found")
# Assert that session-1 was not removed from pool.
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()

def test_context_mgr_unknown_error(self):
database = _Database(self.DATABASE_NAME)
session = _Session(database)
pool = database._pool = _Pool()
pool._new_session = mock.MagicMock(return_value=[])
pool.put(session)
checkout = self._make_one(database)

class Testing(Exception):
pass

self.assertEqual(pool._session, session)
with self.assertRaises(Testing):
with checkout as _:
raise Testing("Unknown error.")
# Assert that session-1 was not removed from pool.
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()


class TestBatchSnapshot(_BaseTest):
TABLE = "table_name"
Expand Down

0 comments on commit 265e207

Please sign in to comment.