Skip to content

Commit

Permalink
Comments incorporated
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Nov 30, 2023
1 parent ba1e60d commit 792c0dc
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 44 deletions.
6 changes: 5 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)


def execute(connection, parsed_statement: ParsedStatement):
def execute(connection: "Connection", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.
It is an internal method that can make backwards-incompatible changes.
Expand Down
72 changes: 42 additions & 30 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as transaction has not begun"
"This method is non-operational as transaction has not started"
)
MAX_INTERNAL_RETRIES = 50

Expand Down Expand Up @@ -125,7 +125,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit and self.spanner_transaction_started:
if value and not self._autocommit and self._spanner_transaction_started:
self.commit()

self._autocommit = value
Expand All @@ -140,27 +140,33 @@ def database(self):
return self._database

@property
def spanner_transaction_started(self):
"""Flag: whether transaction started at SpanFE. This means that we had
made atleast one call to SpanFE. Property client_transaction_started
def _spanner_transaction_started(self):
"""Flag: whether transaction started at Spanner. This means that we had
made atleast one call to Spanner. Property client_transaction_started
would always be true if this is true as transaction has to start first
at clientside than at Spanner (SpanFE)
at clientside than at Spanner
Returns:
bool: True if SpanFE transaction started, False otherwise.
bool: True if Spanner transaction started, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)
) or (self._snapshot is not None)

@property
def inside_transaction(self):
"""Deprecated property which won't be supported in future versions.
Please use spanner_transaction_started property instead."""
return self._spanner_transaction_started

@property
def client_transaction_started(self):
def _client_transaction_started(self):
"""Flag: whether transaction started at client side.
Returns:
bool: True if transaction begun, False otherwise.
bool: True if transaction started, False otherwise.
"""
return (not self._autocommit) or self._transaction_begin_marked

Expand Down Expand Up @@ -190,7 +196,7 @@ def read_only(self, value):
Args:
value (bool): True for ReadOnly mode, False for ReadWrite.
"""
if self.spanner_transaction_started:
if self._spanner_transaction_started:
raise ValueError(
"Connection read/write mode can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -228,7 +234,7 @@ def staleness(self, value):
Args:
value (dict): Staleness type and value.
"""
if self.spanner_transaction_started:
if self._spanner_transaction_started:
raise ValueError(
"`staleness` option can't be changed while a transaction is in progress. "
"Commit or rollback the current transaction and try again."
Expand Down Expand Up @@ -346,13 +352,16 @@ def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.
this connection yet. Return the started one otherwise.
This method is a no-op if the connection is in autocommit mode and no
explicit transaction has been started
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if self.client_transaction_started:
if not self.spanner_transaction_started:
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -367,7 +376,7 @@ def snapshot_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
:returns: A Cloud Spanner snapshot object, ready to use.
"""
if self.read_only and self.client_transaction_started:
if self.read_only and self._client_transaction_started:
if not self._snapshot:
self._snapshot = Snapshot(
self._session_checkout(), multi_use=True, **self.staleness
Expand All @@ -382,7 +391,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if self.spanner_transaction_started:
if self._spanner_transaction_started and not self.read_only:
self._transaction.rollback()

if self._own_pool and self.database:
Expand All @@ -399,8 +408,8 @@ def begin(self):
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already begun")
if self.spanner_transaction_started:
raise OperationalError("A transaction has already started")
if self._spanner_transaction_started:
raise OperationalError(
"Beginning a new transaction is not allowed when a transaction is already running"
)
Expand All @@ -409,22 +418,23 @@ def begin(self):
def commit(self):
"""Commits any pending transaction to the database.
This method is non-operational in autocommit mode.
This is a no-op if there is no active client transaction.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if not self.client_transaction_started:
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
return

self.run_prior_DDL_statements()
if self.spanner_transaction_started:
if self._spanner_transaction_started:
try:
if not self.read_only:
if self.read_only:
self._snapshot = None
else:
self._transaction.commit()

self._release_session()
Expand All @@ -437,17 +447,19 @@ def commit(self):
def rollback(self):
"""Rolls back any pending transaction.
This is a no-op if there is no active transaction or if the connection
is in autocommit mode.
This is a no-op if there is no active client transaction.
"""
self._snapshot = None

if not self.client_transaction_started:
if not self._client_transaction_started:
warnings.warn(
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
)
elif self._transaction:
if not self.read_only:
return

if self._spanner_transaction_started:
if self.read_only:
self._snapshot = None
else:
self._transaction.rollback()

self._release_session()
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def execute(self, sql, args=None):
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if not self.connection.client_transaction_started:
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
return

Expand All @@ -264,7 +264,7 @@ def execute(self, sql, args=None):

sql, args = sql_pyformat_args_to_spanner(sql, args or None)

if self.connection.client_transaction_started:
if self.connection._client_transaction_started:
statement = Statement(
sql,
args,
Expand Down Expand Up @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
)
statements.append((sql, params, get_param_types(params)))

if not self.connection.client_transaction_started:
if not self.connection._client_transaction_started:
self.connection.database.run_in_transaction(
self._do_batch_update, statements, many_result_set
)
Expand Down Expand Up @@ -397,7 +397,7 @@ def fetchone(self):
try:
res = next(self)
if (
self.connection.client_transaction_started
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
Expand All @@ -418,7 +418,7 @@ def fetchall(self):
try:
for row in self:
if (
self.connection.client_transaction_started
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(row)
Expand Down Expand Up @@ -450,7 +450,7 @@ def fetchmany(self, size=None):
try:
res = next(self)
if (
self.connection.client_transaction_started
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
Expand Down Expand Up @@ -482,7 +482,7 @@ def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and self.connection.client_transaction_started:
if self.connection.read_only and self.connection._client_transaction_started:
# initiate or use the existing multi-use snapshot
self._handle_DQL_with_snapshot(
self.connection.snapshot_checkout(), sql, params
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def test_rowcount(self, dbapi_database, autocommit):
# execute with INSERT
self._cursor.execute(
"INSERT INTO Singers (SingerId, Name) VALUES (%s, %s), (%s, %s)",
[x for row in rows[98:] for x in row],
(x for row in rows[98:] for x in row),
)
assert self._cursor.rowcount == 2

Expand Down
9 changes: 4 additions & 5 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,21 +394,20 @@ def test_as_context_manager(self):
self.assertTrue(connection.is_closed)

def test_begin_cursor_closed(self):
connection = self._make_connection()
connection.close()
self._under_test.close()

with self.assertRaises(InterfaceError):
connection.begin()
self._under_test.begin()

self.assertEqual(connection._transaction_begin_marked, False)
self.assertEqual(self._under_test._transaction_begin_marked, False)

def test_begin_transaction_begin_marked(self):
self._under_test._transaction_begin_marked = True

with self.assertRaises(OperationalError):
self._under_test.begin()

def test_begin_inside_transaction(self):
def test_begin_transaction_started(self):
mock_transaction = mock.MagicMock()
mock_transaction.committed = mock_transaction.rolled_back = False
self._under_test._transaction = mock_transaction
Expand Down

0 comments on commit 792c0dc

Please sign in to comment.