Skip to content

Commit

Permalink
feat: Implementing client side statements in dbapi (starting with com…
Browse files Browse the repository at this point in the history
…mit) (#1037)

* Implementing client side statement in dbapi starting with commit

* Fixing comments

* Adding dependency on "deprecated" package

* Fix in setup.py

* Fixing tests

* Lint issue fix

* Resolving comments

* Fixing formatting issue
  • Loading branch information
ankiaga authored Nov 23, 2023
1 parent 07fbc45 commit eb41b0d
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 91 deletions.
29 changes: 29 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)


def execute(connection, parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.
It is an internal method that can make backwards-incompatible changes.
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
42 changes: 42 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
ClientSideStatementType,
)

RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)


def parse_stmt(query):
"""Parses the sql query to check if it matches with any of the client side
statement regex.
It is an internal method that can make backwards-incompatible changes.
:type query: str
:param query: sql query
:rtype: ParsedStatement
:returns: ParsedStatement object.
"""
if RE_COMMIT.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
return None
36 changes: 27 additions & 9 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
from google.cloud.spanner_dbapi.exceptions import OperationalError
from google.cloud.spanner_dbapi.exceptions import ProgrammingError

from google.cloud.spanner_dbapi import _helpers
from google.cloud.spanner_dbapi import _helpers, client_side_statement_executor
from google.cloud.spanner_dbapi._helpers import ColumnInfo
from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE

from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.parsed_statement import StatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

Expand Down Expand Up @@ -210,7 +211,10 @@ def _batch_DDLs(self, sql):
for ddl in sqlparse.split(sql):
if ddl:
ddl = ddl.rstrip(";")
if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
if (
parse_utils.classify_statement(ddl).statement_type
!= StatementType.DDL
):
raise ValueError("Only DDL statements may be batched.")

statements.append(ddl)
Expand Down Expand Up @@ -239,8 +243,12 @@ def execute(self, sql, args=None):
self._handle_DQL(sql, args or None)
return

class_ = parse_utils.classify_stmt(sql)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_statement(sql)
if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
return client_side_statement_executor.execute(
self.connection, parsed_statement
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
self.connection.run_prior_DDL_statements()
Expand All @@ -251,7 +259,7 @@ def execute(self, sql, args=None):
# self._run_prior_DDL_statements()
self.connection.run_prior_DDL_statements()

if class_ == parse_utils.STMT_UPDATING:
if parsed_statement.statement_type == StatementType.UPDATE:
sql = parse_utils.ensure_where_clause(sql)

sql, args = sql_pyformat_args_to_spanner(sql, args or None)
Expand All @@ -276,7 +284,7 @@ def execute(self, sql, args=None):
self.connection.retry_transaction()
return

if class_ == parse_utils.STMT_NON_UPDATING:
if parsed_statement.statement_type == StatementType.QUERY:
self._handle_DQL(sql, args or None)
else:
self.connection.database.run_in_transaction(
Expand Down Expand Up @@ -309,19 +317,29 @@ def executemany(self, operation, seq_of_params):
self._result_set = None
self._row_count = _UNSET_COUNT

class_ = parse_utils.classify_stmt(operation)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_statement(operation)
if parsed_statement.statement_type == StatementType.DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
raise ProgrammingError(
"Executing the following operation: "
+ operation
+ ", with executemany() method is not allowed."
)

# For every operation, we've got to ensure that any prior DDL
# statements were run.
self.connection.run_prior_DDL_statements()

many_result_set = StreamedManyResultSets()

if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
if parsed_statement.statement_type in (
StatementType.INSERT,
StatementType.UPDATE,
):
statements = []

for params in seq_of_params:
Expand Down
39 changes: 37 additions & 2 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import sqlparse
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1 import JsonObject
from . import client_side_statement_parser
from deprecated import deprecated

from .exceptions import Error
from .parsed_statement import ParsedStatement, StatementType
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload

Expand Down Expand Up @@ -174,12 +177,11 @@
RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)


@deprecated(reason="This method is deprecated. Use _classify_stmt method")
def classify_stmt(query):
"""Determine SQL query type.
:type query: str
:param query: A SQL query.
:rtype: str
:returns: The query type name.
"""
Expand All @@ -203,6 +205,39 @@ def classify_stmt(query):
return STMT_UPDATING


def classify_statement(query):
"""Determine SQL query type.
It is an internal method that can make backwards-incompatible changes.
:type query: str
:param query: A SQL query.
:rtype: ParsedStatement
:returns: parsed statement attributes.
"""
# sqlparse will strip Cloud Spanner comments,
# still, special commenting styles, like
# PostgreSQL dollar quoted comments are not
# supported and will not be stripped.
query = sqlparse.format(query, strip_comments=True).strip()
parsed_statement = client_side_statement_parser.parse_stmt(query)
if parsed_statement is not None:
return parsed_statement
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, query)

if RE_IS_INSERT.match(query):
return ParsedStatement(StatementType.INSERT, query)

if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
# statements and doesn't yet support WITH for DML statements.
return ParsedStatement(StatementType.QUERY, query)

return ParsedStatement(StatementType.UPDATE, query)


def sql_pyformat_args_to_spanner(sql, params):
"""
Transform pyformat set SQL to named arguments for Cloud Spanner.
Expand Down
36 changes: 36 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 20203 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum


class StatementType(Enum):
CLIENT_SIDE = 1
DDL = 2
QUERY = 3
UPDATE = 4
INSERT = 5


class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2


@dataclass
class ParsedStatement:
statement_type: StatementType
query: str
client_side_statement_type: ClientSideStatementType = None
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"proto-plus >= 1.22.0, <2.0.0dev",
"sqlparse >= 0.4.4",
"protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5",
"deprecated >= 1.2.14",
]
extras = {
"tracing": [
Expand Down
79 changes: 52 additions & 27 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from google.cloud import spanner_v1
from google.cloud._helpers import UTC

from google.cloud.spanner_dbapi import Cursor
from google.cloud.spanner_dbapi.connection import connect
from google.cloud.spanner_dbapi.connection import Connection
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
Expand Down Expand Up @@ -72,37 +74,11 @@ def dbapi_database(raw_database):

def test_commit(shared_instance, dbapi_database):
"""Test committing a transaction with several statements."""
want_row = (
1,
"updated-first-name",
"last-name",
"[email protected]",
)
# connect to the test database
conn = Connection(shared_instance, dbapi_database)
cursor = conn.cursor()

# execute several DML statements within one transaction
cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', '[email protected]')
"""
)
cursor.execute(
"""
UPDATE contacts
SET first_name = 'updated-first-name'
WHERE first_name = 'first-name'
"""
)
cursor.execute(
"""
UPDATE contacts
SET email = '[email protected]'
WHERE email = '[email protected]'
"""
)
want_row = _execute_common_precommit_statements(cursor)
conn.commit()

# read the resulting data from the database
Expand All @@ -116,6 +92,25 @@ def test_commit(shared_instance, dbapi_database):
conn.close()


def test_commit_client_side(shared_instance, dbapi_database):
"""Test committing a transaction with several statements."""
# connect to the test database
conn = Connection(shared_instance, dbapi_database)
cursor = conn.cursor()

want_row = _execute_common_precommit_statements(cursor)
cursor.execute("""COMMIT""")

# read the resulting data from the database
cursor.execute("SELECT * FROM contacts")
got_rows = cursor.fetchall()
conn.commit()
cursor.close()
conn.close()

assert got_rows == [want_row]


def test_rollback(shared_instance, dbapi_database):
"""Test rollbacking a transaction with several statements."""
want_row = (2, "first-name", "last-name", "[email protected]")
Expand Down Expand Up @@ -810,3 +805,33 @@ def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
assert cur.fetchone() == (1, "first-name")
assert cur.rowcount == 1
conn.commit()


def _execute_common_precommit_statements(cursor: Cursor):
# execute several DML statements within one transaction
cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', '[email protected]')
"""
)
cursor.execute(
"""
UPDATE contacts
SET first_name = 'updated-first-name'
WHERE first_name = 'first-name'
"""
)
cursor.execute(
"""
UPDATE contacts
SET email = '[email protected]'
WHERE email = '[email protected]'
"""
)
return (
1,
"updated-first-name",
"last-name",
"[email protected]",
)
Loading

0 comments on commit eb41b0d

Please sign in to comment.