Skip to content

Commit

Permalink
Spanner: add support for session / pool labels (#5734)
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver authored Aug 6, 2018
1 parent 06e860a commit e36e986
Show file tree
Hide file tree
Showing 7 changed files with 609 additions and 342 deletions.
7 changes: 5 additions & 2 deletions spanner/google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,16 @@ def drop(self):
metadata = _metadata_with_prefix(self.name)
api.drop_database(self.name, metadata=metadata)

def session(self):
def session(self, labels=None):
"""Factory to create a session for this database.
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for the session.
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: a session bound to this database.
"""
return Session(self)
return Session(self, labels=labels)

def snapshot(self, **kw):
"""Return an object which wraps a snapshot.
Expand Down
74 changes: 62 additions & 12 deletions spanner/google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,28 @@


class AbstractSessionPool(object):
"""Specifies required API for concrete session pool implementations."""
"""Specifies required API for concrete session pool implementations.
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""
_database = None

def __init__(self, labels=None):
if labels is None:
labels = {}
self._labels = labels

@property
def labels(self):
"""User-assigned labels for sesions created by the pool.
:rtype: dict (str -> str)
:returns: labels assigned by the user
"""
return self._labels

def bind(self, database):
"""Associate the pool with a database.
Expand Down Expand Up @@ -80,6 +98,16 @@ def clear(self):
"""
raise NotImplementedError()

def _new_session(self):
"""Helper for concrete methods creating session instances.
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: new session instance.
"""
if self.labels:
return self._database.session(labels=self.labels)
return self._database.session()

def session(self, **kwargs):
"""Check out a session from the pool.
Expand Down Expand Up @@ -115,11 +143,17 @@ class FixedSizePool(AbstractSessionPool):
:type default_timeout: int
:param default_timeout: default timeout, in seconds, to wait for
a returned session.
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""
DEFAULT_SIZE = 10
DEFAULT_TIMEOUT = 10

def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT):
def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT,
labels=None):
super(FixedSizePool, self).__init__(labels=labels)
self.size = size
self.default_timeout = default_timeout
self._sessions = queue.Queue(size)
Expand All @@ -134,7 +168,7 @@ def bind(self, database):
self._database = database

while not self._sessions.full():
session = database.session()
session = self._new_session()
session.create()
self._sessions.put(session)

Expand Down Expand Up @@ -198,9 +232,14 @@ class BurstyPool(AbstractSessionPool):
:type target_size: int
:param target_size: max pool size
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, target_size=10):
def __init__(self, target_size=10, labels=None):
super(BurstyPool, self).__init__(labels=labels)
self.target_size = target_size
self._database = None
self._sessions = queue.Queue(target_size)
Expand All @@ -224,11 +263,11 @@ def get(self):
try:
session = self._sessions.get_nowait()
except queue.Empty:
session = self._database.session()
session = self._new_session()
session.create()
else:
if not session.exists():
session = self._database.session()
session = self._new_session()
session.create()
return session

Expand Down Expand Up @@ -290,9 +329,15 @@ class PingingPool(AbstractSessionPool):
:type ping_interval: int
:param ping_interval: interval at which to ping sessions.
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, size=10, default_timeout=10, ping_interval=3000):
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
labels=None):
super(PingingPool, self).__init__(labels=labels)
self.size = size
self.default_timeout = default_timeout
self._delta = datetime.timedelta(seconds=ping_interval)
Expand All @@ -308,7 +353,7 @@ def bind(self, database):
self._database = database

for _ in xrange(self.size):
session = database.session()
session = self._new_session()
session.create()
self.put(session)

Expand All @@ -330,7 +375,7 @@ def get(self, timeout=None): # pylint: disable=arguments-differ

if _NOW() > ping_after:
if not session.exists():
session = self._database.session()
session = self._new_session()
session.create()

return session
Expand Down Expand Up @@ -373,7 +418,7 @@ def ping(self):
self._sessions.put((ping_after, session))
break
if not session.exists(): # stale
session = self._database.session()
session = self._new_session()
session.create()
# Re-add to queue with new expiration
self.put(session)
Expand All @@ -400,13 +445,18 @@ class TransactionPingingPool(PingingPool):
:type ping_interval: int
:param ping_interval: interval at which to ping sessions.
:type labels: dict (str -> str) or None
:param labels: (Optional) user-assigned labels for sessions created
by the pool.
"""

def __init__(self, size=10, default_timeout=10, ping_interval=3000):
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
labels=None):
self._pending_sessions = queue.Queue()

super(TransactionPingingPool, self).__init__(
size, default_timeout, ping_interval)
size, default_timeout, ping_interval, labels=labels)

self.begin_pending_transactions()

Expand Down
26 changes: 24 additions & 2 deletions spanner/google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ class Session(object):
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: The database to which the session is bound.
:type labels: dict (str -> str)
:param labels: (Optional) User-assigned labels for the session.
"""

_session_id = None
_transaction = None

def __init__(self, database):
def __init__(self, database, labels=None):
self._database = database
if labels is None:
labels = {}
self._labels = labels

def __lt__(self, other):
return self._session_id < other._session_id
Expand All @@ -60,6 +66,15 @@ def session_id(self):
"""Read-only ID, set by the back-end during :meth:`create`."""
return self._session_id

@property
def labels(self):
"""User-assigned labels for the session.
:rtype: dict (str -> str)
:returns: the labels dict (empty if no labels were assigned.
"""
return self._labels

@property
def name(self):
"""Session name used in requests.
Expand Down Expand Up @@ -93,7 +108,14 @@ def create(self):
raise ValueError('Session ID already set by back-end')
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
session_pb = api.create_session(self._database.name, metadata=metadata)
kw = {}
if self._labels:
kw = {'session': {'labels': self._labels}}
session_pb = api.create_session(
self._database.name,
metadata=metadata,
**kw
)
self._session_id = session_pb.name.split('/')[-1]

def exists(self):
Expand Down
10 changes: 5 additions & 5 deletions spanner/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):

@classmethod
def setUpClass(cls):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'database_api'})
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
operation = cls._db.create()
Expand All @@ -264,7 +264,7 @@ def test_list_databases(self):
self.assertTrue(self._db.name in database_names)

def test_create_database(self):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'create_database'})
temp_db_id = 'temp_db' + unique_resource_id('_')
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
operation = temp_db.create()
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_table_not_found(self):
'https://github.com/GoogleCloudPlatform/google-cloud-python/issues/'
'5629'))
def test_update_database_ddl(self):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'update_database_ddl'})
temp_db_id = 'temp_db' + unique_resource_id('_')
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
create_op = temp_db.create()
Expand Down Expand Up @@ -434,7 +434,7 @@ class TestSessionAPI(unittest.TestCase, _TestData):

@classmethod
def setUpClass(cls):
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'session_api'})
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
operation = cls._db.create()
Expand Down Expand Up @@ -902,7 +902,7 @@ def test_read_w_index(self):
EXTRA_DDL = [
'CREATE INDEX contacts_by_last_name ON contacts(last_name)',
]
pool = BurstyPool()
pool = BurstyPool(labels={'testcase': 'read_w_index'})
temp_db = Config.INSTANCE.database(
'test_read' + unique_resource_id('_'),
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
Expand Down
19 changes: 18 additions & 1 deletion spanner/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def test_drop_success(self):
self.assertEqual(
metadata, [('google-cloud-resource-prefix', database.name)])

def test_session_factory(self):
def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session

client = _Client()
Expand All @@ -609,6 +609,23 @@ def test_session_factory(self):
self.assertTrue(isinstance(session, Session))
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, {})

def test_session_factory_w_labels(self):
from google.cloud.spanner_v1.session import Session

client = _Client()
instance = _Instance(self.INSTANCE_NAME, client=client)
pool = _Pool()
labels = {'foo': 'bar'}
database = self._make_one(self.DATABASE_ID, instance, pool=pool)

session = database.session(labels=labels)

self.assertTrue(isinstance(session, Session))
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, labels)

def test_snapshot_defaults(self):
from google.cloud.spanner_v1.database import SnapshotCheckout
Expand Down
Loading

0 comments on commit e36e986

Please sign in to comment.