Skip to content

Commit

Permalink
Pass pool's 'labels' through to created sessions.
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver committed Aug 2, 2018
1 parent 0721de9 commit 69c00b5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
2 changes: 2 additions & 0 deletions spanner/google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def _new_session(self):
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: new session instance.
"""
if self.labels:
return self._database.session(labels=self.labels)
return self._database.session()

def session(self, **kwargs):
Expand Down
45 changes: 38 additions & 7 deletions spanner/tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
from functools import total_ordering
import unittest

import mock


def _make_database(name='name'):
from google.cloud.spanner_v1.database import Database

return mock.create_autospec(Database, instance=True)


def _make_session():
from google.cloud.spanner_v1.database import Session

return mock.create_autospec(Session, instance=True)


class TestAbstractSessionPool(unittest.TestCase):

Expand All @@ -40,7 +54,7 @@ def test_ctor_explicit(self):

def test_bind_abstract(self):
pool = self._make_one()
database = _Database('name')
database = _make_database('name')
with self.assertRaises(NotImplementedError):
pool.bind(database)

Expand All @@ -60,13 +74,30 @@ def test_clear_abstract(self):
with self.assertRaises(NotImplementedError):
pool.clear()

def test__new_session(self):
def test__new_session_wo_labels(self):
pool = self._make_one()
database = pool._database = _Database('name')
sessions = [_Session(database)]
database._sessions.extend(sessions)
session = pool._new_session()
self.assertIsInstance(session, _Session)
database = pool._database = _make_database('name')
session = _make_session()
database.session.return_value = session

new_session = pool._new_session()

self.assertIs(new_session, session)
database.session.assert_called_once_with()

def test__new_session_w_labels(self):
labels = {'foo': 'bar'}
pool = self._make_one(labels=labels)
database = pool._database = _make_database('name')
session = _make_session()
database.session.return_value = session

new_session = pool._new_session()

self.assertIs(new_session, session)
database.session.assert_called_once_with(
labels=labels,
)

def test_session_wo_kwargs(self):
from google.cloud.spanner_v1.pool import SessionCheckout
Expand Down

0 comments on commit 69c00b5

Please sign in to comment.