Skip to content

Commit

Permalink
More plumbing for Database DDL methods
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 20, 2024
1 parent 65757b5 commit a06b00c
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 22 deletions.
53 changes: 42 additions & 11 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,10 @@ def create(self):
database_dialect=self._database_dialect,
proto_descriptors=self._proto_descriptors,
)
future = api.create_database(request=request, metadata=metadata)
future = api.create_database(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

def exists(self):
Expand All @@ -531,7 +534,12 @@ def exists(self):
metadata = _metadata_with_prefix(self.name)

try:
api.get_database_ddl(database=self.name, metadata=metadata)
api.get_database_ddl(
database=self.name,
metadata=self.metadata_with_request_id(
self._next_nth_request, 1, metadata
),
)
except NotFound:
return False
return True
Expand All @@ -548,10 +556,16 @@ def reload(self):
"""
api = self._instance._client.database_admin_api
metadata = _metadata_with_prefix(self.name)
response = api.get_database_ddl(database=self.name, metadata=metadata)
response = api.get_database_ddl(
database=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
self._ddl_statements = tuple(response.statements)
self._proto_descriptors = response.proto_descriptors
response = api.get_database(name=self.name, metadata=metadata)
response = api.get_database(
name=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
self._state = DatabasePB.State(response.state)
self._create_time = response.create_time
self._restore_info = response.restore_info
Expand Down Expand Up @@ -596,7 +610,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
proto_descriptors=proto_descriptors,
)

future = api.update_database_ddl(request=request, metadata=metadata)
future = api.update_database_ddl(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

def update(self, fields):
Expand Down Expand Up @@ -634,7 +651,9 @@ def update(self, fields):
metadata = _metadata_with_prefix(self.name)

future = api.update_database(
database=database_pb, update_mask=field_mask, metadata=metadata
database=database_pb,
update_mask=field_mask,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

return future
Expand All @@ -647,7 +666,10 @@ def drop(self):
"""
api = self._instance._client.database_admin_api
metadata = _metadata_with_prefix(self.name)
api.drop_database(database=self.name, metadata=metadata)
api.drop_database(
database=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

def execute_partitioned_dml(
self,
Expand Down Expand Up @@ -995,7 +1017,7 @@ def restore(self, source):
)
future = api.restore_database(
request=request,
metadata=metadata,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

Expand Down Expand Up @@ -1064,7 +1086,10 @@ def list_database_roles(self, page_size=None):
parent=self.name,
page_size=page_size,
)
return api.list_database_roles(request=request, metadata=metadata)
return api.list_database_roles(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

def table(self, table_id):
"""Factory to create a table object within this database.
Expand Down Expand Up @@ -1148,7 +1173,10 @@ def get_iam_policy(self, policy_version=None):
requested_policy_version=policy_version
),
)
response = api.get_iam_policy(request=request, metadata=metadata)
response = api.get_iam_policy(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return response

def set_iam_policy(self, policy):
Expand All @@ -1170,7 +1198,10 @@ def set_iam_policy(self, policy):
resource=self.name,
policy=policy,
)
response = api.set_iam_policy(request=request, metadata=metadata)
response = api.set_iam_policy(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return response

@property
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,9 @@ def bind(self, database):
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
)
for session_pb in resp.session:
session = self._new_session()
Expand Down
26 changes: 19 additions & 7 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def create(self):
):
session_pb = api.create_session(
request=request,
metadata=metadata,
metadata=self._database.metadata_with_request_id(
self._database._next_nth_request, 1, metadata
),
)
self._session_id = session_pb.name.split("/")[-1]

Expand Down Expand Up @@ -257,7 +259,12 @@ def delete(self):
},
observability_options=observability_options,
):
api.delete_session(name=self.name, metadata=metadata)
api.delete_session(
name=self.name,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
)

def ping(self):
"""Ping the session to keep it alive by executing "SELECT 1".
Expand All @@ -266,13 +273,18 @@ def ping(self):
"""
if self._session_id is None:
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
database = self._database
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
api = database.spanner_api
database = self._database
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
api.execute_sql(request=request, metadata=metadata)
api.execute_sql(
request=request,
metadata=database.metadata_with_request_id(
database._next_nth_request,
1,
_metadata_with_prefix(database.name),
),
)
self._last_use_time = datetime.now()

def snapshot(self, **kw):
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from datetime import datetime, timedelta

import mock
from google.cloud.spanner_v1._helpers import (
_metadata_with_request_id,
AtomicCounter,
)

from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from tests._helpers import (
OpenTelemetryBase,
Expand Down Expand Up @@ -1179,6 +1184,9 @@ def session_id(self):


class _Database(object):
NTH_REQUEST = AtomicCounter()
NTH_CLIENT_ID = AtomicCounter()

def __init__(self, name):
self.name = name
self._sessions = []
Expand Down Expand Up @@ -1233,6 +1241,28 @@ def session(self, **kwargs):
def observability_options(self):
return dict(db_name=self.name)

@property
def _next_nth_request(self):
return self.NTH_REQUEST.increment()

@property
def _nth_client_id(self):
return self.NTH_CLIENT_ID.increment()

def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
client_id = self._nth_client_id
return _metadata_with_request_id(
self._nth_client_id,
self._channel_id,
nth_request,
nth_attempt,
prior_metadata,
)

@property
def _channel_id(self):
return 1


class _Queue(object):
_size = 1
Expand Down
Loading

0 comments on commit a06b00c

Please sign in to comment.