diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index b55fc9a24690..d3494eb63902 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -272,13 +272,16 @@ def drop(self): metadata = _metadata_with_prefix(self.name) api.drop_database(self.name, metadata=metadata) - def session(self): + def session(self, labels=None): """Factory to create a session for this database. + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for the session. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a session bound to this database. """ - return Session(self) + return Session(self, labels=labels) def snapshot(self, **kw): """Return an object which wraps a snapshot. diff --git a/spanner/google/cloud/spanner_v1/pool.py b/spanner/google/cloud/spanner_v1/pool.py index c11b295025a4..34ccd76ee8f0 100644 --- a/spanner/google/cloud/spanner_v1/pool.py +++ b/spanner/google/cloud/spanner_v1/pool.py @@ -26,10 +26,28 @@ class AbstractSessionPool(object): - """Specifies required API for concrete session pool implementations.""" + """Specifies required API for concrete session pool implementations. + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for sessions created + by the pool. + """ _database = None + def __init__(self, labels=None): + if labels is None: + labels = {} + self._labels = labels + + @property + def labels(self): + """User-assigned labels for sesions created by the pool. + + :rtype: dict (str -> str) + :returns: labels assigned by the user + """ + return self._labels + def bind(self, database): """Associate the pool with a database. @@ -80,6 +98,16 @@ def clear(self): """ raise NotImplementedError() + def _new_session(self): + """Helper for concrete methods creating session instances. + + :rtype: :class:`~google.cloud.spanner_v1.session.Session` + :returns: new session instance. + """ + if self.labels: + return self._database.session(labels=self.labels) + return self._database.session() + def session(self, **kwargs): """Check out a session from the pool. @@ -115,11 +143,17 @@ class FixedSizePool(AbstractSessionPool): :type default_timeout: int :param default_timeout: default timeout, in seconds, to wait for a returned session. + + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for sessions created + by the pool. """ DEFAULT_SIZE = 10 DEFAULT_TIMEOUT = 10 - def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT): + def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT, + labels=None): + super(FixedSizePool, self).__init__(labels=labels) self.size = size self.default_timeout = default_timeout self._sessions = queue.Queue(size) @@ -134,7 +168,7 @@ def bind(self, database): self._database = database while not self._sessions.full(): - session = database.session() + session = self._new_session() session.create() self._sessions.put(session) @@ -198,9 +232,14 @@ class BurstyPool(AbstractSessionPool): :type target_size: int :param target_size: max pool size + + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for sessions created + by the pool. """ - def __init__(self, target_size=10): + def __init__(self, target_size=10, labels=None): + super(BurstyPool, self).__init__(labels=labels) self.target_size = target_size self._database = None self._sessions = queue.Queue(target_size) @@ -224,11 +263,11 @@ def get(self): try: session = self._sessions.get_nowait() except queue.Empty: - session = self._database.session() + session = self._new_session() session.create() else: if not session.exists(): - session = self._database.session() + session = self._new_session() session.create() return session @@ -290,9 +329,15 @@ class PingingPool(AbstractSessionPool): :type ping_interval: int :param ping_interval: interval at which to ping sessions. + + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for sessions created + by the pool. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000): + def __init__(self, size=10, default_timeout=10, ping_interval=3000, + labels=None): + super(PingingPool, self).__init__(labels=labels) self.size = size self.default_timeout = default_timeout self._delta = datetime.timedelta(seconds=ping_interval) @@ -308,7 +353,7 @@ def bind(self, database): self._database = database for _ in xrange(self.size): - session = database.session() + session = self._new_session() session.create() self.put(session) @@ -330,7 +375,7 @@ def get(self, timeout=None): # pylint: disable=arguments-differ if _NOW() > ping_after: if not session.exists(): - session = self._database.session() + session = self._new_session() session.create() return session @@ -373,7 +418,7 @@ def ping(self): self._sessions.put((ping_after, session)) break if not session.exists(): # stale - session = self._database.session() + session = self._new_session() session.create() # Re-add to queue with new expiration self.put(session) @@ -400,13 +445,18 @@ class TransactionPingingPool(PingingPool): :type ping_interval: int :param ping_interval: interval at which to ping sessions. + + :type labels: dict (str -> str) or None + :param labels: (Optional) user-assigned labels for sessions created + by the pool. """ - def __init__(self, size=10, default_timeout=10, ping_interval=3000): + def __init__(self, size=10, default_timeout=10, ping_interval=3000, + labels=None): self._pending_sessions = queue.Queue() super(TransactionPingingPool, self).__init__( - size, default_timeout, ping_interval) + size, default_timeout, ping_interval, labels=labels) self.begin_pending_transactions() diff --git a/spanner/google/cloud/spanner_v1/session.py b/spanner/google/cloud/spanner_v1/session.py index 1f7a9dd16b56..60512f025496 100644 --- a/spanner/google/cloud/spanner_v1/session.py +++ b/spanner/google/cloud/spanner_v1/session.py @@ -44,13 +44,19 @@ class Session(object): :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to which the session is bound. + + :type labels: dict (str -> str) + :param labels: (Optional) User-assigned labels for the session. """ _session_id = None _transaction = None - def __init__(self, database): + def __init__(self, database, labels=None): self._database = database + if labels is None: + labels = {} + self._labels = labels def __lt__(self, other): return self._session_id < other._session_id @@ -60,6 +66,15 @@ def session_id(self): """Read-only ID, set by the back-end during :meth:`create`.""" return self._session_id + @property + def labels(self): + """User-assigned labels for the session. + + :rtype: dict (str -> str) + :returns: the labels dict (empty if no labels were assigned. + """ + return self._labels + @property def name(self): """Session name used in requests. @@ -93,7 +108,14 @@ def create(self): raise ValueError('Session ID already set by back-end') api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) - session_pb = api.create_session(self._database.name, metadata=metadata) + kw = {} + if self._labels: + kw = {'session': {'labels': self._labels}} + session_pb = api.create_session( + self._database.name, + metadata=metadata, + **kw + ) self._session_id = session_pb.name.split('/')[-1] def exists(self): diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index d2c756ec632b..2d85a99531b6 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -239,7 +239,7 @@ class TestDatabaseAPI(unittest.TestCase, _TestData): @classmethod def setUpClass(cls): - pool = BurstyPool() + pool = BurstyPool(labels={'testcase': 'database_api'}) cls._db = Config.INSTANCE.database( cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool) operation = cls._db.create() @@ -264,7 +264,7 @@ def test_list_databases(self): self.assertTrue(self._db.name in database_names) def test_create_database(self): - pool = BurstyPool() + pool = BurstyPool(labels={'testcase': 'create_database'}) temp_db_id = 'temp_db' + unique_resource_id('_') temp_db = Config.INSTANCE.database(temp_db_id, pool=pool) operation = temp_db.create() @@ -311,7 +311,7 @@ def test_table_not_found(self): 'https://github.com/GoogleCloudPlatform/google-cloud-python/issues/' '5629')) def test_update_database_ddl(self): - pool = BurstyPool() + pool = BurstyPool(labels={'testcase': 'update_database_ddl'}) temp_db_id = 'temp_db' + unique_resource_id('_') temp_db = Config.INSTANCE.database(temp_db_id, pool=pool) create_op = temp_db.create() @@ -434,7 +434,7 @@ class TestSessionAPI(unittest.TestCase, _TestData): @classmethod def setUpClass(cls): - pool = BurstyPool() + pool = BurstyPool(labels={'testcase': 'session_api'}) cls._db = Config.INSTANCE.database( cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool) operation = cls._db.create() @@ -902,7 +902,7 @@ def test_read_w_index(self): EXTRA_DDL = [ 'CREATE INDEX contacts_by_last_name ON contacts(last_name)', ] - pool = BurstyPool() + pool = BurstyPool(labels={'testcase': 'read_w_index'}) temp_db = Config.INSTANCE.database( 'test_read' + unique_resource_id('_'), ddl_statements=DDL_STATEMENTS + EXTRA_DDL, diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index 458fdd4bb5c1..34b30deb2022 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -596,7 +596,7 @@ def test_drop_success(self): self.assertEqual( metadata, [('google-cloud-resource-prefix', database.name)]) - def test_session_factory(self): + def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session client = _Client() @@ -609,6 +609,23 @@ def test_session_factory(self): self.assertTrue(isinstance(session, Session)) self.assertIs(session.session_id, None) self.assertIs(session._database, database) + self.assertEqual(session.labels, {}) + + def test_session_factory_w_labels(self): + from google.cloud.spanner_v1.session import Session + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + labels = {'foo': 'bar'} + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + session = database.session(labels=labels) + + self.assertTrue(isinstance(session, Session)) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.labels, labels) def test_snapshot_defaults(self): from google.cloud.spanner_v1.database import SnapshotCheckout diff --git a/spanner/tests/unit/test_pool.py b/spanner/tests/unit/test_pool.py index 5eecdef9b9ee..03c776a55fed 100644 --- a/spanner/tests/unit/test_pool.py +++ b/spanner/tests/unit/test_pool.py @@ -16,6 +16,20 @@ from functools import total_ordering import unittest +import mock + + +def _make_database(name='name'): + from google.cloud.spanner_v1.database import Database + + return mock.create_autospec(Database, instance=True) + + +def _make_session(): + from google.cloud.spanner_v1.database import Session + + return mock.create_autospec(Session, instance=True) + class TestAbstractSessionPool(unittest.TestCase): @@ -30,10 +44,17 @@ def _make_one(self, *args, **kwargs): def test_ctor_defaults(self): pool = self._make_one() self.assertIsNone(pool._database) + self.assertEqual(pool.labels, {}) + + def test_ctor_explicit(self): + labels = {'foo': 'bar'} + pool = self._make_one(labels=labels) + self.assertIsNone(pool._database) + self.assertEqual(pool.labels, labels) def test_bind_abstract(self): pool = self._make_one() - database = _Database('name') + database = _make_database('name') with self.assertRaises(NotImplementedError): pool.bind(database) @@ -53,6 +74,31 @@ def test_clear_abstract(self): with self.assertRaises(NotImplementedError): pool.clear() + def test__new_session_wo_labels(self): + pool = self._make_one() + database = pool._database = _make_database('name') + session = _make_session() + database.session.return_value = session + + new_session = pool._new_session() + + self.assertIs(new_session, session) + database.session.assert_called_once_with() + + def test__new_session_w_labels(self): + labels = {'foo': 'bar'} + pool = self._make_one(labels=labels) + database = pool._database = _make_database('name') + session = _make_session() + database.session.return_value = session + + new_session = pool._new_session() + + self.assertIs(new_session, session) + database.session.assert_called_once_with( + labels=labels, + ) + def test_session_wo_kwargs(self): from google.cloud.spanner_v1.pool import SessionCheckout @@ -90,13 +136,16 @@ def test_ctor_defaults(self): self.assertEqual(pool.size, 10) self.assertEqual(pool.default_timeout, 10) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, {}) def test_ctor_explicit(self): - pool = self._make_one(size=4, default_timeout=30) + labels = {'foo': 'bar'} + pool = self._make_one(size=4, default_timeout=30, labels=labels) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) self.assertEqual(pool.default_timeout, 30) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, labels) def test_bind(self): pool = self._make_one() @@ -222,12 +271,15 @@ def test_ctor_defaults(self): self.assertIsNone(pool._database) self.assertEqual(pool.target_size, 10) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, {}) def test_ctor_explicit(self): - pool = self._make_one(target_size=4) + labels = {'foo': 'bar'} + pool = self._make_one(target_size=4, labels=labels) self.assertIsNone(pool._database) self.assertEqual(pool.target_size, 4) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, labels) def test_get_empty(self): pool = self._make_one() @@ -340,14 +392,18 @@ def test_ctor_defaults(self): self.assertEqual(pool.default_timeout, 10) self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, {}) def test_ctor_explicit(self): - pool = self._make_one(size=4, default_timeout=30, ping_interval=1800) + labels = {'foo': 'bar'} + pool = self._make_one( + size=4, default_timeout=30, ping_interval=1800, labels=labels) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) self.assertEqual(pool.default_timeout, 30) self.assertEqual(pool._delta.seconds, 1800) self.assertTrue(pool._sessions.empty()) + self.assertEqual(pool.labels, labels) def test_bind(self): pool = self._make_one() @@ -567,15 +623,19 @@ def test_ctor_defaults(self): self.assertEqual(pool._delta.seconds, 3000) self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) + self.assertEqual(pool.labels, {}) def test_ctor_explicit(self): - pool = self._make_one(size=4, default_timeout=30, ping_interval=1800) + labels = {'foo': 'bar'} + pool = self._make_one( + size=4, default_timeout=30, ping_interval=1800, labels=labels) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) self.assertEqual(pool.default_timeout, 30) self.assertEqual(pool._delta.seconds, 1800) self.assertTrue(pool._sessions.empty()) self.assertTrue(pool._pending_sessions.empty()) + self.assertEqual(pool.labels, labels) def test_bind(self): pool = self._make_one() diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py index 5c1d2e82bef4..b165f3dda85e 100644 --- a/spanner/tests/unit/test_session.py +++ b/spanner/tests/unit/test_session.py @@ -23,7 +23,7 @@ def _make_rpc_error(error_cls, trailing_metadata=None): grpc_error = mock.create_autospec(grpc.Call, instance=True) grpc_error.trailing_metadata.return_value = trailing_metadata - raise error_cls('error', errors=(grpc_error,)) + return error_cls('error', errors=(grpc_error,)) class TestSession(unittest.TestCase): @@ -44,14 +44,42 @@ def _getTargetClass(self): def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) - def test_constructor(self): - database = _Database(self.DATABASE_NAME) + @staticmethod + def _make_database(name=DATABASE_NAME): + from google.cloud.spanner_v1.database import Database + + database = mock.create_autospec(Database, instance=True) + database.name = name + return database + + @staticmethod + def _make_session_pb(name, labels=None): + from google.cloud.spanner_v1.proto.spanner_pb2 import Session + + return Session(name=name, labels=labels) + + def _make_spanner_api(self): + from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient + + return mock.Mock(autospec=SpannerClient, instance=True) + + def test_constructor_wo_labels(self): + database = self._make_database() session = self._make_one(database) self.assertIs(session.session_id, None) self.assertIs(session._database, database) + self.assertEqual(session.labels, {}) + + def test_constructor_w_labels(self): + database = self._make_database() + labels = {'foo': 'bar'} + session = self._make_one(database, labels=labels) + self.assertIs(session.session_id, None) + self.assertIs(session._database, database) + self.assertEqual(session.labels, labels) def test___lt___(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() lhs = self._make_one(database) lhs._session_id = b'123' rhs = self._make_one(database) @@ -59,28 +87,31 @@ def test___lt___(self): self.assertTrue(lhs < rhs) def test_name_property_wo_session_id(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) + with self.assertRaises(ValueError): (session.name) def test_name_property_w_session_id(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = self.SESSION_ID self.assertEqual(session.name, self.SESSION_NAME) def test_create_w_session_id(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = self.SESSION_ID + with self.assertRaises(ValueError): session.create() def test_create_ok(self): - session_pb = _SessionPB(self.SESSION_NAME) - gax_api = _SpannerApi(_create_session_response=session_pb) - database = _Database(self.DATABASE_NAME) + session_pb = self._make_session_pb(self.SESSION_NAME) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) @@ -88,16 +119,36 @@ def test_create_ok(self): self.assertEqual(session.session_id, self.SESSION_ID) - database_name, metadata = gax_api._create_session_called_with - self.assertEqual(database_name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + gax_api.create_session.assert_called_once_with( + database.name, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def test_create_w_labels(self): + labels = {'foo': 'bar'} + session_pb = self._make_session_pb(self.SESSION_NAME, labels=labels) + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database() + database.spanner_api = gax_api + session = self._make_one(database, labels=labels) + + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + + gax_api.create_session.assert_called_once_with( + database.name, + session={'labels': labels}, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_error(self): from google.api_core.exceptions import Unknown - gax_api = _SpannerApi(_rpc_error=Unknown('error')) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.create_session.side_effect = Unknown('error') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) @@ -105,44 +156,49 @@ def test_create_error(self): session.create() def test_exists_wo_session_id(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) self.assertFalse(session.exists()) def test_exists_hit(self): - session_pb = _SessionPB(self.SESSION_NAME) - gax_api = _SpannerApi(_get_session_response=session_pb) - database = _Database(self.DATABASE_NAME) + session_pb = self._make_session_pb(self.SESSION_NAME) + gax_api = self._make_spanner_api() + gax_api.get_session.return_value = session_pb + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID self.assertTrue(session.exists()) - session_name, metadata = gax_api._get_session_called_with - self.assertEqual(session_name, self.SESSION_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + gax_api.get_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_miss(self): - gax_api = _SpannerApi() - database = _Database(self.DATABASE_NAME) + from google.api_core.exceptions import NotFound + + gax_api = self._make_spanner_api() + gax_api.get_session.side_effect = NotFound('testing') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID self.assertFalse(session.exists()) - session_name, metadata = gax_api._get_session_called_with - self.assertEqual(session_name, self.SESSION_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + gax_api.get_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_error(self): from google.api_core.exceptions import Unknown - gax_api = _SpannerApi(_rpc_error=Unknown('error')) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.get_session.side_effect = Unknown('testing') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID @@ -150,31 +206,39 @@ def test_exists_error(self): with self.assertRaises(Unknown): session.exists() + gax_api.get_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_delete_wo_session_id(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) + with self.assertRaises(ValueError): session.delete() def test_delete_hit(self): - gax_api = _SpannerApi(_delete_session_ok=True) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.delete_session.return_value = None + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID session.delete() - session_name, metadata = gax_api._delete_session_called_with - self.assertEqual(session_name, self.SESSION_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + gax_api.delete_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_delete_miss(self): from google.cloud.exceptions import NotFound - gax_api = _SpannerApi(_delete_session_ok=False) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.delete_session.side_effect = NotFound('testing') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID @@ -182,16 +246,17 @@ def test_delete_miss(self): with self.assertRaises(NotFound): session.delete() - session_name, metadata = gax_api._delete_session_called_with - self.assertEqual(session_name, self.SESSION_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + gax_api.delete_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_delete_error(self): from google.api_core.exceptions import Unknown - gax_api = _SpannerApi(_rpc_error=Unknown('error')) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.delete_session.side_effect = Unknown('testing') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) session._session_id = self.SESSION_ID @@ -199,8 +264,13 @@ def test_delete_error(self): with self.assertRaises(Unknown): session.delete() + gax_api.delete_session.assert_called_once_with( + self.SESSION_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_snapshot_not_created(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) with self.assertRaises(ValueError): @@ -209,7 +279,7 @@ def test_snapshot_not_created(self): def test_snapshot_created(self): from google.cloud.spanner_v1.snapshot import Snapshot - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' # emulate 'session.create()' @@ -223,7 +293,7 @@ def test_snapshot_created(self): def test_snapshot_created_w_multi_use(self): from google.cloud.spanner_v1.snapshot import Snapshot - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' # emulate 'session.create()' @@ -241,15 +311,13 @@ def test_read_not_created(self): COLUMNS = ['email', 'first_name', 'last_name', 'age'] KEYS = ['bharney@example.com', 'phred@example.com'] KEYSET = KeySet(keys=KEYS) - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) with self.assertRaises(ValueError): session.read(TABLE_NAME, COLUMNS, KEYSET) def test_read(self): - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey from google.cloud.spanner_v1.keyset import KeySet TABLE_NAME = 'citizens' @@ -258,87 +326,81 @@ def test_read(self): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' - _read_with = [] - expected = object() - - class _Snapshot(object): - - def __init__(self, session, **kwargs): - self._session = session - self._kwargs = kwargs.copy() - - def read(self, table, columns, keyset, index='', limit=0): - _read_with.append( - (table, columns, keyset, index, limit)) - return expected - - with _Monkey(MUT, Snapshot=_Snapshot): + with mock.patch( + 'google.cloud.spanner_v1.session.Snapshot') as snapshot: found = session.read( TABLE_NAME, COLUMNS, KEYSET, index=INDEX, limit=LIMIT) - self.assertIs(found, expected) + self.assertIs(found, snapshot().read.return_value) - self.assertEqual(len(_read_with), 1) - (table, columns, key_set, index, limit) = _read_with[0] - - self.assertEqual(table, TABLE_NAME) - self.assertEqual(columns, COLUMNS) - self.assertEqual(key_set, KEYSET) - self.assertEqual(index, INDEX) - self.assertEqual(limit, LIMIT) + snapshot().read.assert_called_once_with( + TABLE_NAME, + COLUMNS, + KEYSET, + INDEX, + LIMIT, + ) def test_execute_sql_not_created(self): SQL = 'SELECT first_name, age FROM citizens' - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) with self.assertRaises(ValueError): session.execute_sql(SQL) def test_execute_sql_defaults(self): - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey - SQL = 'SELECT first_name, age FROM citizens' - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' - _executed_sql_with = [] - expected = object() + with mock.patch( + 'google.cloud.spanner_v1.session.Snapshot') as snapshot: + found = session.execute_sql(SQL) + + self.assertIs(found, snapshot().execute_sql.return_value) - class _Snapshot(object): + snapshot().execute_sql.assert_called_once_with( + SQL, + None, + None, + None, + ) - def __init__(self, session, **kwargs): - self._session = session - self._kwargs = kwargs.copy() + def test_execute_sql_explicit(self): + from google.protobuf.struct_pb2 import Struct, Value + from google.cloud.spanner_v1.proto.type_pb2 import STRING - def execute_sql( - self, sql, params=None, param_types=None, query_mode=None): - _executed_sql_with.append( - (sql, params, param_types, query_mode)) - return expected + SQL = 'SELECT first_name, age FROM citizens' + database = self._make_database() + session = self._make_one(database) + session._session_id = 'DEADBEEF' - with _Monkey(MUT, Snapshot=_Snapshot): - found = session.execute_sql(SQL) + params = Struct(fields={'foo': Value(string_value='bar')}) + param_types = {'foo': STRING} - self.assertIs(found, expected) + with mock.patch( + 'google.cloud.spanner_v1.session.Snapshot') as snapshot: + found = session.execute_sql( + SQL, params, param_types, 'PLAN') - self.assertEqual(len(_executed_sql_with), 1) - sql, params, param_types, query_mode = _executed_sql_with[0] + self.assertIs(found, snapshot().execute_sql.return_value) - self.assertEqual(sql, SQL) - self.assertEqual(params, None) - self.assertEqual(param_types, None) - self.assertEqual(query_mode, None) + snapshot().execute_sql.assert_called_once_with( + SQL, + params, + param_types, + 'PLAN', + ) def test_batch_not_created(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) with self.assertRaises(ValueError): @@ -347,7 +409,7 @@ def test_batch_not_created(self): def test_batch_created(self): from google.cloud.spanner_v1.batch import Batch - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -357,7 +419,7 @@ def test_batch_created(self): self.assertIs(batch._session, session) def test_transaction_not_created(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) with self.assertRaises(ValueError): @@ -366,7 +428,7 @@ def test_transaction_not_created(self): def test_transaction_created(self): from google.cloud.spanner_v1.transaction import Transaction - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -377,7 +439,7 @@ def test_transaction_created(self): self.assertIs(session._transaction, transaction) def test_transaction_w_existing_txn(self): - database = _Database(self.DATABASE_NAME) + database = self._make_database() session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -389,7 +451,7 @@ def test_transaction_w_existing_txn(self): def test_run_in_transaction_callback_raises_non_gax_error(self): from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud.spanner_v1.transaction import Transaction TABLE_NAME = 'citizens' @@ -400,14 +462,13 @@ def test_run_in_transaction_callback_raises_non_gax_error(self): ] TRANSACTION_ID = b'FACEDACE' transaction_pb = TransactionPB(id=TRANSACTION_ID) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _rollback_response=None, - ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.rollback.return_value = None + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -431,10 +492,24 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ()) self.assertEqual(kw, {}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + gax_api.begin_transaction.assert_called_once_with( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + gax_api.rollback.assert_called_once_with( + self.SESSION_NAME, + TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_run_in_transaction_callback_raises_non_abort_rpc_error(self): from google.api_core.exceptions import Cancelled from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud.spanner_v1.transaction import Transaction TABLE_NAME = 'citizens' @@ -445,14 +520,13 @@ def test_run_in_transaction_callback_raises_non_abort_rpc_error(self): ] TRANSACTION_ID = b'FACEDACE' transaction_pb = TransactionPB(id=TRANSACTION_ID) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _rollback_response=None, - ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.rollback.return_value = None + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -473,11 +547,21 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ()) self.assertEqual(kw, {}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + gax_api.begin_transaction.assert_called_once_with( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + gax_api.rollback.assert_not_called() + def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.transaction import Transaction @@ -493,14 +577,13 @@ def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_response=response, - ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.return_value = response + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -520,6 +603,21 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ('abc',)) self.assertEqual(kw, {'some_arg': 'def'}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + gax_api.begin_transaction.assert_called_once_with( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + gax_api.commit.assert_called_once_with( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_run_in_transaction_w_commit_error(self): from google.api_core.exceptions import Unknown from google.cloud.spanner_v1.transaction import Transaction @@ -530,14 +628,15 @@ def test_run_in_transaction_w_commit_error(self): ['phred@exammple.com', 'Phred', 'Phlyntstone', 32], ['bharney@example.com', 'Bharney', 'Rhubble', 31], ] - gax_api = _SpannerApi( - _commit_error=True) - database = _Database(self.DATABASE_NAME) + TRANSACTION_ID = b'FACEDACE' + gax_api = self._make_spanner_api() + gax_api.commit.side_effect = Unknown('error') + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID begun_txn = session._transaction = Transaction(session) - begun_txn._transaction_id = b'FACEDACE' + begun_txn._transaction_id = TRANSACTION_ID assert session._transaction._transaction_id @@ -558,11 +657,20 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ()) self.assertEqual(kw, {}) + gax_api.begin_transaction.assert_not_called() + gax_api.commit.assert_called_once_with( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_run_in_transaction_w_abort_no_retry_metadata(self): import datetime + from google.api_core.exceptions import Aborted from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.transaction import Transaction @@ -577,16 +685,15 @@ def test_run_in_transaction_w_abort_no_retry_metadata(self): transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) + aborted = _make_rpc_error(Aborted, trailing_metadata=[]) response = CommitResponse(commit_timestamp=now_pb) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_abort_count=1, - _commit_response=response, - ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.side_effect = [aborted, response] + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -605,16 +712,36 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ('abc',)) self.assertEqual(kw, {'some_arg': 'def'}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + self.assertEqual( + gax_api.begin_transaction.call_args_list, + [mock.call( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + self.assertEqual( + gax_api.commit.call_args_list, + [mock.call( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + def test_run_in_transaction_w_abort_w_retry_metadata(self): import datetime + from google.api_core.exceptions import Aborted + from google.protobuf.duration_pb2 import Duration + from google.rpc.error_details_pb2 import RetryInfo from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.transaction import Transaction - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey TABLE_NAME = 'citizens' COLUMNS = ['email', 'first_name', 'last_name', 'age'] @@ -625,21 +752,28 @@ def test_run_in_transaction_w_abort_w_retry_metadata(self): TRANSACTION_ID = b'FACEDACE' RETRY_SECONDS = 12 RETRY_NANOS = 3456 + retry_info = RetryInfo( + retry_delay=Duration( + seconds=RETRY_SECONDS, + nanos=RETRY_NANOS)) + trailing_metadata = [ + ('google.rpc.retryinfo-bin', retry_info.SerializeToString()), + ] + aborted = _make_rpc_error( + Aborted, + trailing_metadata=trailing_metadata, + ) transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_abort_count=1, - _commit_abort_retry_seconds=RETRY_SECONDS, - _commit_abort_retry_nanos=RETRY_NANOS, - _commit_response=response, - ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.side_effect = [aborted, response] + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -647,14 +781,12 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - time_module = _FauxTimeModule() - - with _Monkey(MUT, time=time_module): + with mock.patch('time.sleep') as sleep_mock: session.run_in_transaction(unit_of_work, 'abc', some_arg='def') - self.assertEqual(time_module._slept, - RETRY_SECONDS + RETRY_NANOS / 1.0e9) + sleep_mock.assert_called_once_with(RETRY_SECONDS + RETRY_NANOS / 1.0e9) self.assertEqual(len(called_with), 2) + for index, (txn, args, kw) in enumerate(called_with): self.assertIsInstance(txn, Transaction) if index == 1: @@ -664,17 +796,36 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ('abc',)) self.assertEqual(kw, {'some_arg': 'def'}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + self.assertEqual( + gax_api.begin_transaction.call_args_list, + [mock.call( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + self.assertEqual( + gax_api.commit.call_args_list, + [mock.call( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + def test_run_in_transaction_w_callback_raises_abort_wo_metadata(self): import datetime from google.api_core.exceptions import Aborted + from google.protobuf.duration_pb2 import Duration + from google.rpc.error_details_pb2 import RetryInfo from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.transaction import Transaction - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey TABLE_NAME = 'citizens' COLUMNS = ['email', 'first_name', 'last_name', 'age'] @@ -689,33 +840,33 @@ def test_run_in_transaction_w_callback_raises_abort_wo_metadata(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_abort_retry_seconds=RETRY_SECONDS, - _commit_abort_retry_nanos=RETRY_NANOS, - _commit_response=response, - ) - database = _Database(self.DATABASE_NAME) + retry_info = RetryInfo( + retry_delay=Duration( + seconds=RETRY_SECONDS, + nanos=RETRY_NANOS)) + trailing_metadata = [ + ('google.rpc.retryinfo-bin', retry_info.SerializeToString()), + ] + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.side_effect = [response] + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) if len(called_with) < 2: - raise _make_rpc_error( - Aborted, gax_api._trailing_metadata()) + raise _make_rpc_error(Aborted, trailing_metadata) txn.insert(TABLE_NAME, COLUMNS, VALUES) - time_module = _FauxTimeModule() - - with _Monkey(MUT, time=time_module): + with mock.patch('time.sleep') as sleep_mock: session.run_in_transaction(unit_of_work) - self.assertEqual(time_module._slept, - RETRY_SECONDS + RETRY_NANOS / 1.0e9) + sleep_mock.assert_called_once_with(RETRY_SECONDS + RETRY_NANOS / 1.0e9) self.assertEqual(len(called_with), 2) for index, (txn, args, kw) in enumerate(called_with): self.assertIsInstance(txn, Transaction) @@ -726,16 +877,34 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ()) self.assertEqual(kw, {}) + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + self.assertEqual( + gax_api.begin_transaction.call_args_list, + [mock.call( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + gax_api.commit.assert_called_once_with( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_run_in_transaction_w_abort_w_retry_metadata_deadline(self): import datetime from google.api_core.exceptions import Aborted + from google.protobuf.duration_pb2 import Duration + from google.rpc.error_details_pb2 import RetryInfo from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) + from google.cloud.spanner_v1.transaction import Transaction from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey TABLE_NAME = 'citizens' COLUMNS = ['email', 'first_name', 'last_name', 'age'] @@ -750,17 +919,24 @@ def test_run_in_transaction_w_abort_w_retry_metadata_deadline(self): now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_abort_count=1, - _commit_abort_retry_seconds=RETRY_SECONDS, - _commit_abort_retry_nanos=RETRY_NANOS, - _commit_response=response, + retry_info = RetryInfo( + retry_delay=Duration( + seconds=RETRY_SECONDS, + nanos=RETRY_NANOS)) + trailing_metadata = [ + ('google.rpc.retryinfo-bin', retry_info.SerializeToString()), + ] + aborted = _make_rpc_error( + Aborted, + trailing_metadata=trailing_metadata, ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.side_effect = [aborted, response] + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -768,23 +944,44 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - time_module = _FauxTimeModule() - time_module._times = [1, 1.5] + # retry once w/ timeout_secs=1 + def _time(_results=[1, 1.5]): + return _results.pop(0) - with _Monkey(MUT, time=time_module): - with self.assertRaises(Aborted): - session.run_in_transaction( - unit_of_work, 'abc', timeout_secs=1) + with mock.patch('time.time', _time): + with mock.patch('time.sleep') as sleep_mock: + with self.assertRaises(Aborted): + session.run_in_transaction( + unit_of_work, 'abc', timeout_secs=1) + + sleep_mock.assert_not_called() - self.assertIsNone(time_module._slept) self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIsInstance(txn, Transaction) + self.assertIsNone(txn.committed) + self.assertEqual(args, ('abc',)) + self.assertEqual(kw, {}) + + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + gax_api.begin_transaction.assert_called_once_with( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + gax_api.commit.assert_called_once_with( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_run_in_transaction_w_timeout(self): from google.api_core.exceptions import Aborted - from google.cloud.spanner_v1 import session as MUT - from google.cloud._testing import _Monkey from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) from google.cloud.spanner_v1.transaction import Transaction TABLE_NAME = 'citizens' @@ -795,14 +992,17 @@ def test_run_in_transaction_w_timeout(self): ] TRANSACTION_ID = b'FACEDACE' transaction_pb = TransactionPB(id=TRANSACTION_ID) - gax_api = _SpannerApi( - _begin_transaction_response=transaction_pb, - _commit_abort_count=1e6, + aborted = _make_rpc_error( + Aborted, + trailing_metadata=[], ) - database = _Database(self.DATABASE_NAME) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.side_effect = aborted + database = self._make_database() database.spanner_api = gax_api session = self._make_one(database) - session._session_id = 'DEADBEEF' + session._session_id = self.SESSION_ID called_with = [] @@ -810,14 +1010,17 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - time_module = _FauxTimeModule() - time_module._times = [1, 1.5, 2.5] # retry once w/ timeout_secs=1 + # retry once w/ timeout_secs=1 + def _time(_results=[1, 1.5, 2.5]): + return _results.pop(0) + + with mock.patch('time.time', _time): + with mock.patch('time.sleep') as sleep_mock: + with self.assertRaises(Aborted): + session.run_in_transaction(unit_of_work, timeout_secs=1) - with _Monkey(MUT, time=time_module): - with self.assertRaises(Aborted): - session.run_in_transaction(unit_of_work, timeout_secs=1) + sleep_mock.assert_not_called() - self.assertEqual(time_module._slept, None) self.assertEqual(len(called_with), 2) for txn, args, kw in called_with: self.assertIsInstance(txn, Transaction) @@ -825,109 +1028,21 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ()) self.assertEqual(kw, {}) - -class _Database(object): - - def __init__(self, name): - self.name = name - - -class _SpannerApi(object): - - _commit_abort_count = 0 - _commit_abort_retry_seconds = None - _commit_abort_retry_nanos = None - _commit_error = False - _rpc_error = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def create_session(self, database, metadata=None): - if self._rpc_error is not None: - raise self._rpc_error - - self._create_session_called_with = database, metadata - return self._create_session_response - - def get_session(self, name, metadata=None): - from google.api_core.exceptions import NotFound - - if self._rpc_error is not None: - raise self._rpc_error - - self._get_session_called_with = name, metadata - try: - return self._get_session_response - except AttributeError: - raise NotFound('miss') - - def delete_session(self, name, metadata=None): - from google.api_core.exceptions import NotFound - - if self._rpc_error is not None: - raise self._rpc_error - - self._delete_session_called_with = name, metadata - if not self._delete_session_ok: - raise NotFound('miss') - - def begin_transaction(self, session, options_, metadata=None): - self._begun = (session, options_, metadata) - return self._begin_transaction_response - - def _trailing_metadata(self): - from google.protobuf.duration_pb2 import Duration - from google.rpc.error_details_pb2 import RetryInfo - - if self._commit_abort_retry_nanos is None: - return [] - - retry_info = RetryInfo( - retry_delay=Duration( - seconds=self._commit_abort_retry_seconds, - nanos=self._commit_abort_retry_nanos)) - return [ - ('google.rpc.retryinfo-bin', retry_info.SerializeToString()), - ] - - def commit(self, session, mutations, - transaction_id='', single_use_transaction=None, metadata=None): - from google.api_core.exceptions import Unknown, Aborted - - assert single_use_transaction is None - self._committed = (session, mutations, transaction_id, metadata) - if self._commit_error: - raise Unknown('error') - if self._commit_abort_count > 0: - self._commit_abort_count -= 1 - raise _make_rpc_error( - Aborted, trailing_metadata=self._trailing_metadata()) - return self._commit_response - - def rollback(self, session, transaction_id, metadata=None): - self._rolled_back = (session, transaction_id, metadata) - return self._rollback_response - - -class _SessionPB(object): - - def __init__(self, name): - self.name = name - - -class _FauxTimeModule(object): - - _slept = None - _times = () - - def time(self): - import time - - if len(self._times) > 0: - return self._times.pop(0) - - return time.time() - - def sleep(self, seconds): - self._slept = seconds + expected_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + ) + self.assertEqual( + gax_api.begin_transaction.call_args_list, + [mock.call( + self.SESSION_NAME, + expected_options, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2) + self.assertEqual( + gax_api.commit.call_args_list, + [mock.call( + self.SESSION_NAME, + txn._mutations, + transaction_id=TRANSACTION_ID, + metadata=[('google-cloud-resource-prefix', database.name)], + )] * 2)