Skip to content

Commit

Permalink
feat: Implementation of client side statements that return (#1046)
Browse files Browse the repository at this point in the history
* Implementation of client side statements that return

* Small fix

* Incorporated comments

* Added tests for exception in commit and rollback

* Fix in tests

* Skipping few tests from running in emulator

* Few fixes

* Refactoring

* Incorporated comments

* Incorporating comments
  • Loading branch information
ankiaga authored Dec 12, 2023
1 parent 95b8e74 commit bb5fa1f
Show file tree
Hide file tree
Showing 11 changed files with 581 additions and 259 deletions.
66 changes: 60 additions & 6 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,27 @@

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)
from google.cloud.spanner_v1 import (
Type,
StructType,
TypeCode,
ResultSetMetadata,
PartialResultSet,
)

from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1.streamed import StreamedResultSet

CONNECTION_CLOSED_ERROR = "This connection is closed"
TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as a transaction has not been started."
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
Expand All @@ -32,9 +49,46 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
return connection.begin()
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
return connection.rollback()
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
if statement_type == ClientSideStatementType.COMMIT:
connection.commit()
return None
if statement_type == ClientSideStatementType.BEGIN:
connection.begin()
return None
if statement_type == ClientSideStatementType.ROLLBACK:
connection.rollback()
return None
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
if connection._transaction is None:
committed_timestamp = None
else:
committed_timestamp = connection._transaction.committed
return _get_streamed_result_set(
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
TypeCode.TIMESTAMP,
committed_timestamp,
)
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
if connection._snapshot is None:
read_timestamp = None
else:
read_timestamp = connection._snapshot._transaction_read_timestamp
return _get_streamed_result_set(
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
TypeCode.TIMESTAMP,
read_timestamp,
)


def _get_streamed_result_set(column_name, type_code, column_value):
struct_type_pb = StructType(
fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
)

result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
if column_value is not None:
result_set.values.extend([_make_value_pb(column_value)])
return StreamedResultSet(iter([result_set]))
23 changes: 16 additions & 7 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(COMMIT_TIMESTAMP)", re.IGNORECASE
)
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)


def parse_stmt(query):
Expand All @@ -37,16 +43,19 @@ def parse_stmt(query):
:rtype: ParsedStatement
:returns: ParsedStatement object.
"""
client_side_statement_type = None
if RE_COMMIT.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
)
client_side_statement_type = ClientSideStatementType.BEGIN
if RE_ROLLBACK.match(query):
client_side_statement_type = ClientSideStatementType.ROLLBACK
if RE_SHOW_COMMIT_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
StatementType.CLIENT_SIDE, query, client_side_statement_type
)
return None
81 changes: 38 additions & 43 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
Expand All @@ -35,7 +36,7 @@


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not started"
"This method is non-operational as a transaction has not been started."
)
MAX_INTERNAL_RETRIES = 50

Expand Down Expand Up @@ -107,6 +108,9 @@ def __init__(self, instance, database=None, read_only=False):
self._staleness = None
self.request_priority = None
self._transaction_begin_marked = False
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
self._spanner_transaction_started = False

@property
def autocommit(self):
Expand Down Expand Up @@ -140,26 +144,15 @@ def database(self):
return self._database

@property
def _spanner_transaction_started(self):
"""Flag: whether transaction started at Spanner. This means that we had
made atleast one call to Spanner. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner
Returns:
bool: True if Spanner transaction started, False otherwise.
"""
@deprecated(
reason="This method is deprecated. Use _spanner_transaction_started field"
)
def inside_transaction(self):
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
) or (self._snapshot is not None)

@property
def inside_transaction(self):
"""Deprecated property which won't be supported in future versions.
Please use spanner_transaction_started property instead."""
return self._spanner_transaction_started
)

@property
def _client_transaction_started(self):
Expand Down Expand Up @@ -277,7 +270,8 @@ def _release_session(self):
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self.database._pool.put(self._session)
if self._session is not None:
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
Expand All @@ -293,7 +287,7 @@ def retry_transaction(self):
"""
attempt = 0
while True:
self._transaction = None
self._spanner_transaction_started = False
attempt += 1
if attempt > MAX_INTERNAL_RETRIES:
raise
Expand All @@ -319,7 +313,6 @@ def _rerun_previous_statements(self):
status, res = transaction.batch_update(statements)

if status.code == ABORTED:
self.connection._transaction = None
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
Expand Down Expand Up @@ -363,6 +356,8 @@ def transaction_checkout(self):
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._snapshot = None
self._spanner_transaction_started = True
self._transaction.begin()

return self._transaction
Expand All @@ -377,11 +372,13 @@ def snapshot_checkout(self):
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and self._client_transaction_started:
if not self._snapshot:
if not self._spanner_transaction_started:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
)
self._transaction = None
self._snapshot.begin()
self._spanner_transaction_started = True

return self._snapshot

Expand All @@ -391,7 +388,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self._spanner_transaction_started and not self.read_only:
if self._spanner_transaction_started and not self._read_only:
self._transaction.rollback()

if self._own_pool and self.database:
Expand All @@ -405,13 +402,15 @@ def begin(self):
Marks the transaction as started.
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
:raises: :class:`OperationalError`: if there is an existing transaction
that has been started
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already started")
if self._spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
"Beginning a new transaction is not allowed when a transaction "
"is already running"
)
self._transaction_begin_marked = True

Expand All @@ -430,41 +429,37 @@ def commit(self):
return

self.run_prior_DDL_statements()
if self._spanner_transaction_started:
try:
if self.read_only:
self._snapshot = None
else:
self._transaction.commit()

self._release_session()
self._statements = []
self._transaction_begin_marked = False
except Aborted:
self.retry_transaction()
self.commit()
try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
except Aborted:
self.retry_transaction()
self.commit()
finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

def rollback(self):
"""Rolls back any pending transaction.
This is a no-op if there is no active client transaction.
"""

if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

if self._spanner_transaction_started:
if self.read_only:
self._snapshot = None
else:
try:
if self._spanner_transaction_started and not self._read_only:
self._transaction.rollback()

finally:
self._release_session()
self._statements = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

@check_not_closed
def cursor(self):
Expand Down
Loading

0 comments on commit bb5fa1f

Please sign in to comment.