diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 3ef61eed69..886e28d7f7 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -193,14 +193,14 @@ def bind(self, database): metadata = _metadata_with_prefix(database.name) self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( + database=database.database_id, + session_count=self.size - self._sessions.qsize(), session_template=Session(creator_role=self.database_role), ) while not self._sessions.full(): resp = api.batch_create_sessions( request=request, - database=database.name, - session_count=self.size - self._sessions.qsize(), metadata=metadata, ) for session_pb in resp.session: @@ -406,14 +406,14 @@ def bind(self, database): self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( + database=database.database_id, + session_count=self.size - created_session_count, session_template=Session(creator_role=self.database_role), ) while created_session_count < self.size: resp = api.batch_create_sessions( request=request, - database=database.name, - session_count=self.size - created_session_count, metadata=metadata, ) for session_pb in resp.session: diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 9fac10ed4d..699b3f4a69 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -20,6 +20,7 @@ from google.api_core import exceptions from google.iam.v1 import policy_pb2 from google.cloud import spanner_v1 +from google.cloud.spanner_v1.pool import FixedSizePool, PingingPool from google.type import expr_pb2 from . import _helpers from . import _sample_data @@ -73,6 +74,61 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) assert temp_db.name in database_ids +def test_database_binding_of_fixed_size_pool( + not_emulator, shared_instance, databases_to_delete, not_postgres +): + temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + "CREATE ROLE parent", + "GRANT SELECT ON TABLE contacts TO ROLE parent", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + pool = FixedSizePool( + size=1, + default_timeout=500, + database_role="parent", + ) + database = shared_instance.database(temp_db.name, pool=pool) + assert database._pool.database_role == "parent" + + +def test_database_binding_of_pinging_pool( + not_emulator, shared_instance, databases_to_delete, not_postgres +): + temp_db_id = _helpers.unique_id("binding_db", separator="_") + temp_db = shared_instance.database(temp_db_id) + + create_op = temp_db.create() + databases_to_delete.append(temp_db) + create_op.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + # Create role and grant select permission on table contacts for parent role. + ddl_statements = _helpers.DDL_STATEMENTS + [ + "CREATE ROLE parent", + "GRANT SELECT ON TABLE contacts TO ROLE parent", + ] + operation = temp_db.update_ddl(ddl_statements) + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + pool = PingingPool( + size=1, + default_timeout=500, + ping_interval=100, + database_role="parent", + ) + database = shared_instance.database(temp_db.name, pool=pool) + assert database._pool.database_role == "parent" + + def test_create_database_pitr_invalid_retention_period( not_emulator, # PITR-lite features are not supported by the emulator not_postgres, diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 48cc1434ef..3a9d35bc92 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -956,11 +956,10 @@ def __init__(self, name): self.name = name self._sessions = [] self._database_role = None + self.database_id = name def mock_batch_create_sessions( request=None, - database=None, - session_count=10, timeout=10, metadata=[], labels={}, @@ -969,7 +968,7 @@ def mock_batch_create_sessions( from google.cloud.spanner_v1 import Session database_role = request.session_template.creator_role if request else None - if session_count < 2: + if request.session_count < 2: response = BatchCreateSessionsResponse( session=[Session(creator_role=database_role, labels=labels)] )