Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spanner: add support for session / pool labels #5734

Merged
merged 10 commits into from
Aug 6, 2018
7 changes: 5 additions & 2 deletions spanner/google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,16 @@ def drop(self):
metadata = _metadata_with_prefix(self.name)
api.drop_database(self.name, metadata=metadata)

def session(self):
def session(self, labels=None):
"""Factory to create a session for this database.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for the session.

:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: a session bound to this database.
"""
return Session(self)
return Session(self, labels=labels)

def snapshot(self, **kw):
"""Return an object which wraps a snapshot.
Expand Down
74 changes: 62 additions & 12 deletions spanner/google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,28 @@


class AbstractSessionPool(object):
"""Specifies required API for concrete session pool implementations."""
"""Specifies required API for concrete session pool implementations.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""
_database = None

def __init__(self, labels=None):
if labels is None:
labels = {}
self._labels = labels

@property

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

def labels(self):
"""User-assigned labels for sesions created by the pool.

:rtype: dict (str -> str)
:returns: labels assigned by the user
"""
return self._labels

def bind(self, database):
"""Associate the pool with a database.

Expand Down Expand Up @@ -80,6 +98,16 @@ def clear(self):
"""
raise NotImplementedError()

def _new_session(self):
"""Helper for concrete methods creating session instances.

:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: new session instance.
"""
if self.labels:
return self._database.session(labels=self.labels)
return self._database.session()

def session(self, **kwargs):
"""Check out a session from the pool.

Expand Down Expand Up @@ -115,11 +143,17 @@ class FixedSizePool(AbstractSessionPool):
:type default_timeout: int
:param default_timeout: default timeout, in seconds, to wait for
a returned session.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""
DEFAULT_SIZE = 10
DEFAULT_TIMEOUT = 10

def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT):
def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT,
labels=None):
super(FixedSizePool, self).__init__(labels=labels)
self.size = size
self.default_timeout = default_timeout
self._sessions = queue.Queue(size)
Expand All @@ -134,7 +168,7 @@ def bind(self, database):
self._database = database

while not self._sessions.full():
session = database.session()
session = self._new_session()
session.create()
self._sessions.put(session)

Expand Down Expand Up @@ -198,9 +232,14 @@ class BurstyPool(AbstractSessionPool):

:type target_size: int
:param target_size: max pool size

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, target_size=10):
def __init__(self, target_size=10, labels=None):
super(BurstyPool, self).__init__(labels=labels)
self.target_size = target_size
self._database = None
self._sessions = queue.Queue(target_size)
Expand All @@ -224,11 +263,11 @@ def get(self):
try:
session = self._sessions.get_nowait()
except queue.Empty:
session = self._database.session()
session = self._new_session()
session.create()
else:
if not session.exists():
session = self._database.session()
session = self._new_session()
session.create()
return session

Expand Down Expand Up @@ -290,9 +329,15 @@ class PingingPool(AbstractSessionPool):

:type ping_interval: int
:param ping_interval: interval at which to ping sessions.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, size=10, default_timeout=10, ping_interval=3000):
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
labels=None):
super(PingingPool, self).__init__(labels=labels)
self.size = size
self.default_timeout = default_timeout
self._delta = datetime.timedelta(seconds=ping_interval)
Expand All @@ -308,7 +353,7 @@ def bind(self, database):
self._database = database

for _ in xrange(self.size):
session = database.session()
session = self._new_session()
session.create()
self.put(session)

Expand All @@ -330,7 +375,7 @@ def get(self, timeout=None): # pylint: disable=arguments-differ

if _NOW() > ping_after:
if not session.exists():
session = self._database.session()
session = self._new_session()
session.create()

return session
Expand Down Expand Up @@ -373,7 +418,7 @@ def ping(self):
self._sessions.put((ping_after, session))
break
if not session.exists(): # stale
session = self._database.session()
session = self._new_session()
session.create()
# Re-add to queue with new expiration
self.put(session)
Expand All @@ -400,13 +445,18 @@ class TransactionPingingPool(PingingPool):

:type ping_interval: int
:param ping_interval: interval at which to ping sessions.

:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, size=10, default_timeout=10, ping_interval=3000):
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
labels=None):
self._pending_sessions = queue.Queue()

super(TransactionPingingPool, self).__init__(
size, default_timeout, ping_interval)
size, default_timeout, ping_interval, labels=labels)

self.begin_pending_transactions()

Expand Down
26 changes: 24 additions & 2 deletions spanner/google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ class Session(object):

:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: The database to which the session is bound.

:type labels: dict (str -> str)
:param labels: (Optional) User-assigned labels for the session.
"""

_session_id = None
_transaction = None

def __init__(self, database):
def __init__(self, database, labels=None):
self._database = database
if labels is None:
labels = {}
self._labels = labels

def __lt__(self, other):
return self._session_id < other._session_id
Expand All @@ -60,6 +66,15 @@ def session_id(self):
"""Read-only ID, set by the back-end during :meth:`create`."""
return self._session_id

@property
def labels(self):
"""User-assigned labels for the session.

:rtype: dict (str -> str)
:returns: the labels dict (empty if no labels were assigned.
"""
return self._labels

@property
def name(self):
"""Session name used in requests.
Expand Down Expand Up @@ -93,7 +108,14 @@ def create(self):
raise ValueError('Session ID already set by back-end')
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
session_pb = api.create_session(self._database.name, metadata=metadata)
kw = {}
if self._labels:
kw = {'session': {'labels': self._labels}}
session_pb = api.create_session(
self._database.name,
metadata=metadata,
**kw
)
self._session_id = session_pb.name.split('/')[-1]

def exists(self):
Expand Down
10 changes: 5 additions & 5 deletions spanner/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):

@classmethod
def setUpClass(cls):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'database_api'})
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
operation = cls._db.create()
Expand All @@ -264,7 +264,7 @@ def test_list_databases(self):
self.assertTrue(self._db.name in database_names)

def test_create_database(self):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'create_database'})
temp_db_id = 'temp_db' + unique_resource_id('_')
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
operation = temp_db.create()
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_table_not_found(self):
'https://github.com/GoogleCloudPlatform/google-cloud-python/issues/'
'5629'))
def test_update_database_ddl(self):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'update_database_ddl'})
temp_db_id = 'temp_db' + unique_resource_id('_')
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
create_op = temp_db.create()
Expand Down Expand Up @@ -434,7 +434,7 @@ class TestSessionAPI(unittest.TestCase, _TestData):

@classmethod
def setUpClass(cls):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'session_api'})
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
operation = cls._db.create()
Expand Down Expand Up @@ -901,7 +901,7 @@ def test_read_w_index(self):
EXTRA_DDL = [
'CREATE INDEX contacts_by_last_name ON contacts(last_name)',
]
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'read_w_index'})
temp_db = Config.INSTANCE.database(
'test_read' + unique_resource_id('_'),
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
Expand Down
19 changes: 18 additions & 1 deletion spanner/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def test_drop_success(self):
self.assertEqual(
metadata, [('google-cloud-resource-prefix', database.name)])

def test_session_factory(self):
def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session

client = _Client()
Expand All @@ -609,6 +609,23 @@ def test_session_factory(self):
self.assertTrue(isinstance(session, Session))
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, {})

def test_session_factory_w_labels(self):
from google.cloud.spanner_v1.session import Session

client = _Client()
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
labels = {'foo': 'bar'}
database = self._make_one(self.DATABASE_ID, instance, pool=pool)

session = database.session(labels=labels)

self.assertTrue(isinstance(session, Session))
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, labels)

def test_snapshot_defaults(self):
from google.cloud.spanner_v1.database import SnapshotCheckout
Expand Down
Loading