Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementing client side statements in dbapi (starting with commit) #1037

Merged
merged 8 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
)


def execute(connection, parsed_statement: ParsedStatement):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"""Executes the client side statements by calling the relevant method.

It is an internal method that can make backwards-incompatible changes.

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
return connection.commit()
42 changes: 42 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
ClientSideStatementType,
)

RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
ankiaga marked this conversation as resolved.
Show resolved Hide resolved


def parse_stmt(query):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"""Parses the sql query to check if it matches with any of the client side
statement regex.

It is an internal method that can make backwards-incompatible changes.

:type query: str
:param query: sql query

:rtype: ParsedStatement
:returns: ParsedStatement object.
"""
if RE_COMMIT.match(query):
return ParsedStatement(
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
)
return None
36 changes: 27 additions & 9 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
from google.cloud.spanner_dbapi.exceptions import OperationalError
from google.cloud.spanner_dbapi.exceptions import ProgrammingError

from google.cloud.spanner_dbapi import _helpers
from google.cloud.spanner_dbapi import _helpers, client_side_statement_executor
from google.cloud.spanner_dbapi._helpers import ColumnInfo
from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE

from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.parsed_statement import StatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

Expand Down Expand Up @@ -210,7 +211,10 @@ def _batch_DDLs(self, sql):
for ddl in sqlparse.split(sql):
if ddl:
ddl = ddl.rstrip(";")
if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
if (
parse_utils.classify_statement(ddl).statement_type
!= StatementType.DDL
):
raise ValueError("Only DDL statements may be batched.")

statements.append(ddl)
Expand Down Expand Up @@ -239,8 +243,12 @@ def execute(self, sql, args=None):
self._handle_DQL(sql, args or None)
return

class_ = parse_utils.classify_stmt(sql)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_statement(sql)
if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
return client_side_statement_executor.execute(
self.connection, parsed_statement
)
if parsed_statement.statement_type == StatementType.DDL:
self._batch_DDLs(sql)
if self.connection.autocommit:
self.connection.run_prior_DDL_statements()
Expand All @@ -251,7 +259,7 @@ def execute(self, sql, args=None):
# self._run_prior_DDL_statements()
self.connection.run_prior_DDL_statements()

if class_ == parse_utils.STMT_UPDATING:
if parsed_statement.statement_type == StatementType.UPDATE:
sql = parse_utils.ensure_where_clause(sql)

sql, args = sql_pyformat_args_to_spanner(sql, args or None)
Expand All @@ -276,7 +284,7 @@ def execute(self, sql, args=None):
self.connection.retry_transaction()
return

if class_ == parse_utils.STMT_NON_UPDATING:
if parsed_statement.statement_type == StatementType.QUERY:
self._handle_DQL(sql, args or None)
else:
self.connection.database.run_in_transaction(
Expand Down Expand Up @@ -309,19 +317,29 @@ def executemany(self, operation, seq_of_params):
self._result_set = None
self._row_count = _UNSET_COUNT

class_ = parse_utils.classify_stmt(operation)
if class_ == parse_utils.STMT_DDL:
parsed_statement = parse_utils.classify_statement(operation)
if parsed_statement.statement_type == StatementType.DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
raise ProgrammingError(
"Executing the following operation: "
+ operation
+ ", with executemany() method is not allowed."
)

# For every operation, we've got to ensure that any prior DDL
# statements were run.
self.connection.run_prior_DDL_statements()

many_result_set = StreamedManyResultSets()

if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
if parsed_statement.statement_type in (
StatementType.INSERT,
StatementType.UPDATE,
):
statements = []

for params in seq_of_params:
Expand Down
39 changes: 37 additions & 2 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import sqlparse
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1 import JsonObject
from . import client_side_statement_parser
from deprecated import deprecated

from .exceptions import Error
from .parsed_statement import ParsedStatement, StatementType
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload

Expand Down Expand Up @@ -174,12 +177,11 @@
RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)


@deprecated(reason="This method is deprecated. Use _classify_stmt method")
Copy link
Contributor

@asottile-sentry asottile-sentry Mar 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this pulls in two new dependencies (deprecated, wrapt) into this library -- this can be replaced entirely with this inside the function body:

warnings.warn("This method is deprecated.  Use classify_statement method instead", stacklevel=2)

this uses warnings directly avoiding the two new dependencies (including one which has a C extension!)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would one be open to a PR to do this instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ankiaga Would you mind taking a look what would be the best solution here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

threw one together: #1120

def classify_stmt(query):
"""Determine SQL query type.

:type query: str
:param query: A SQL query.

:rtype: str
:returns: The query type name.
"""
Expand All @@ -203,6 +205,39 @@ def classify_stmt(query):
return STMT_UPDATING


def classify_statement(query):
"""Determine SQL query type.

It is an internal method that can make backwards-incompatible changes.

:type query: str
:param query: A SQL query.

:rtype: ParsedStatement
:returns: parsed statement attributes.
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
"""
# sqlparse will strip Cloud Spanner comments,
# still, special commenting styles, like
# PostgreSQL dollar quoted comments are not
# supported and will not be stripped.
query = sqlparse.format(query, strip_comments=True).strip()
parsed_statement = client_side_statement_parser.parse_stmt(query)
if parsed_statement is not None:
return parsed_statement
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, query)

if RE_IS_INSERT.match(query):
return ParsedStatement(StatementType.INSERT, query)

if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
# statements and doesn't yet support WITH for DML statements.
return ParsedStatement(StatementType.QUERY, query)

return ParsedStatement(StatementType.UPDATE, query)


def sql_pyformat_args_to_spanner(sql, params):
"""
Transform pyformat set SQL to named arguments for Cloud Spanner.
Expand Down
36 changes: 36 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 20203 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum


class StatementType(Enum):
CLIENT_SIDE = 1
DDL = 2
QUERY = 3
UPDATE = 4
INSERT = 5


class ClientSideStatementType(Enum):
COMMIT = 1
BEGIN = 2


@dataclass
class ParsedStatement:
statement_type: StatementType
query: str
client_side_statement_type: ClientSideStatementType = None
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"proto-plus >= 1.22.0, <2.0.0dev",
"sqlparse >= 0.4.4",
"protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5",
"deprecated >= 1.2.14",
]
extras = {
"tracing": [
Expand Down
79 changes: 52 additions & 27 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from google.cloud import spanner_v1
from google.cloud._helpers import UTC

from google.cloud.spanner_dbapi import Cursor
from google.cloud.spanner_dbapi.connection import connect
from google.cloud.spanner_dbapi.connection import Connection
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
Expand Down Expand Up @@ -72,37 +74,11 @@ def dbapi_database(raw_database):

def test_commit(shared_instance, dbapi_database):
"""Test committing a transaction with several statements."""
want_row = (
1,
"updated-first-name",
"last-name",
"[email protected]",
)
# connect to the test database
conn = Connection(shared_instance, dbapi_database)
cursor = conn.cursor()

# execute several DML statements within one transaction
cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', '[email protected]')
"""
)
cursor.execute(
"""
UPDATE contacts
SET first_name = 'updated-first-name'
WHERE first_name = 'first-name'
"""
)
cursor.execute(
"""
UPDATE contacts
SET email = '[email protected]'
WHERE email = '[email protected]'
"""
)
want_row = _execute_common_precommit_statements(cursor)
conn.commit()

# read the resulting data from the database
Expand All @@ -116,6 +92,25 @@ def test_commit(shared_instance, dbapi_database):
conn.close()


def test_commit_client_side(shared_instance, dbapi_database):
"""Test committing a transaction with several statements."""
# connect to the test database
conn = Connection(shared_instance, dbapi_database)
cursor = conn.cursor()

want_row = _execute_common_precommit_statements(cursor)
cursor.execute("""COMMIT""")

# read the resulting data from the database
cursor.execute("SELECT * FROM contacts")
got_rows = cursor.fetchall()
conn.commit()
cursor.close()
conn.close()

assert got_rows == [want_row]


def test_rollback(shared_instance, dbapi_database):
"""Test rollbacking a transaction with several statements."""
want_row = (2, "first-name", "last-name", "[email protected]")
Expand Down Expand Up @@ -810,3 +805,33 @@ def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
assert cur.fetchone() == (1, "first-name")
assert cur.rowcount == 1
conn.commit()


def _execute_common_precommit_statements(cursor: Cursor):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
# execute several DML statements within one transaction
cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
VALUES (1, 'first-name', 'last-name', '[email protected]')
"""
)
cursor.execute(
"""
UPDATE contacts
SET first_name = 'updated-first-name'
WHERE first_name = 'first-name'
"""
)
cursor.execute(
"""
UPDATE contacts
SET email = '[email protected]'
WHERE email = '[email protected]'
"""
)
return (
1,
"updated-first-name",
"last-name",
"[email protected]",
)
Loading