From b5b0c98354954326e6dce8b9ddf2c4d7ac009217 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Tue, 2 Jan 2024 18:31:20 +0530 Subject: [PATCH 1/6] feat: Implementation for partitioned query in dbapi --- .../client_side_statement_executor.py | 27 ++++++++-- .../client_side_statement_parser.py | 21 +++++++- google/cloud/spanner_dbapi/connection.py | 54 ++++++++++++++++++- google/cloud/spanner_dbapi/parse_utils.py | 16 +++--- .../cloud/spanner_dbapi/parsed_statement.py | 8 +++ .../cloud/spanner_dbapi/partition_helper.py | 32 +++++++++++ google/cloud/spanner_v1/database.py | 52 ++++++++++++++++-- google/cloud/spanner_v1/snapshot.py | 2 + tests/system/test_dbapi.py | 18 +++++++ tests/unit/spanner_dbapi/test_parse_utils.py | 39 ++++++++++++-- tests/unit/test_database.py | 15 ++++-- 11 files changed, 260 insertions(+), 24 deletions(-) create mode 100644 google/cloud/spanner_dbapi/partition_helper.py diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index 06d0d25948..1246c1bdd0 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -20,6 +20,7 @@ from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, ClientSideStatementType, + ClientSideStatementParamKey, ) from google.cloud.spanner_v1 import ( Type, @@ -66,7 +67,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): if connection._transaction is None: committed_timestamp = None else: - committed_timestamp = connection._transaction.committed + committed_timestamp = list(connection._transaction.committed) return _get_streamed_result_set( ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name, TypeCode.TIMESTAMP, @@ -76,7 +77,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): if connection._snapshot is None: read_timestamp = None else: - read_timestamp = connection._snapshot._transaction_read_timestamp + read_timestamp = list(connection._snapshot._transaction_read_timestamp) return _get_streamed_result_set( ClientSideStatementType.SHOW_READ_TIMESTAMP.name, TypeCode.TIMESTAMP, @@ -89,14 +90,30 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): return connection.run_batch() if statement_type == ClientSideStatementType.ABORT_BATCH: return connection.abort_batch() + if statement_type == ClientSideStatementType.PARTITION_QUERY: + partition_ids = connection.partition_query(parsed_statement) + return _get_streamed_result_set( + "PARTITION", + TypeCode.STRING, + partition_ids, + ) + if statement_type == ClientSideStatementType.RUN_PARTITION: + return connection.run_partition( + parsed_statement.client_side_statement_params[ + ClientSideStatementParamKey.PARTITION_ID + ] + ) -def _get_streamed_result_set(column_name, type_code, column_value): +def _get_streamed_result_set(column_name, type_code, column_values): struct_type_pb = StructType( fields=[StructType.Field(name=column_name, type_=Type(code=type_code))] ) result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb)) - if column_value is not None: - result_set.values.extend([_make_value_pb(column_value)]) + column_values_pb = [] + if column_values is not None: + for column_value in column_values: + column_values_pb.append(_make_value_pb(column_value)) + result_set.values.extend(column_values_pb) return StreamedResultSet(iter([result_set])) diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index 39970259b2..85dbca4eb4 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -19,6 +19,7 @@ StatementType, ClientSideStatementType, Statement, + ClientSideStatementParamKey, ) RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) @@ -33,6 +34,8 @@ RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE) RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE) RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE) +RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE) +RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE) def parse_stmt(query): @@ -48,6 +51,7 @@ def parse_stmt(query): :returns: ParsedStatement object. """ client_side_statement_type = None + client_side_statement_params = {} if RE_COMMIT.match(query): client_side_statement_type = ClientSideStatementType.COMMIT if RE_BEGIN.match(query): @@ -64,8 +68,23 @@ def parse_stmt(query): client_side_statement_type = ClientSideStatementType.RUN_BATCH if RE_ABORT_BATCH.match(query): client_side_statement_type = ClientSideStatementType.ABORT_BATCH + if RE_PARTITION_QUERY.match(query): + match = re.search(RE_PARTITION_QUERY, query) + client_side_statement_params[ + ClientSideStatementParamKey.PARTITIONED_SQL_QUERY + ] = match.group(2) + client_side_statement_type = ClientSideStatementType.PARTITION_QUERY + if RE_RUN_PARTITION.match(query): + match = re.search(RE_RUN_PARTITION, query) + client_side_statement_params[ + ClientSideStatementParamKey.PARTITION_ID + ] = match.group(3) + client_side_statement_type = ClientSideStatementType.RUN_PARTITION if client_side_statement_type is not None: return ParsedStatement( - StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type + StatementType.CLIENT_SIDE, + Statement(query), + client_side_statement_type, + client_side_statement_params, ) return None diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index e635563587..17daf51e84 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -19,8 +19,16 @@ from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_dbapi import partition_helper from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor -from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement +from google.cloud.spanner_dbapi.parse_utils import _get_statement_type +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + Statement, + StatementType, + ClientSideStatementParamKey, +) +from google.cloud.spanner_dbapi.partition_helper import PartitionId from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot @@ -585,6 +593,50 @@ def abort_batch(self): self._batch_dml_executor = None self._batch_mode = BatchMode.NONE + @check_not_closed + def partition_query( + self, + parsed_statement: ParsedStatement, + query_options=None, + ): + statement = parsed_statement.statement + partitioned_query = parsed_statement.client_side_statement_params[ + ClientSideStatementParamKey.PARTITIONED_SQL_QUERY + ] + if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY: + raise ProgrammingError( + "Only queries can be partitioned. Invalid statement: " + statement.sql + ) + batch_snapshot = self._database.batch_snapshot() + partition_ids = [] + partitions = list( + batch_snapshot.generate_query_batches( + partitioned_query, + statement.params, + statement.param_types, + query_options=query_options, + ) + ) + for partition in partitions: + batch_transaction_id = batch_snapshot.get_batch_transaction_id() + partition_ids.append( + partition_helper.encode_to_string(batch_transaction_id, partition) + ) + return partition_ids + + @check_not_closed + def run_partition(self, batch_transaction_id): + partition_id: PartitionId = partition_helper.decode_from_string( + batch_transaction_id + ) + batch_transaction_id = partition_id.batch_transaction_id + batch_snapshot = self._database.batch_snapshot( + read_timestamp=batch_transaction_id.read_timestamp, + session_id=batch_transaction_id.session_id, + transaction_id=batch_transaction_id.transaction_id, + ) + return batch_snapshot.process(partition_id.batch_result) + def __enter__(self): return self diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 76ac951e0c..008f21bf93 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -232,19 +232,23 @@ def classify_statement(query, args=None): get_param_types(args or None), ResultsChecksum(), ) - if RE_DDL.match(query): - return ParsedStatement(StatementType.DDL, statement) + statement_type = _get_statement_type(statement) + return ParsedStatement(statement_type, statement) - if RE_IS_INSERT.match(query): - return ParsedStatement(StatementType.INSERT, statement) +def _get_statement_type(statement): + query = statement.sql + if RE_DDL.match(query): + return StatementType.DDL + if RE_IS_INSERT.match(query): + return StatementType.INSERT if RE_NON_UPDATE.match(query) or RE_WITH.match(query): # As of 13-March-2020, Cloud Spanner only supports WITH for DQL # statements and doesn't yet support WITH for DML statements. - return ParsedStatement(StatementType.QUERY, statement) + return StatementType.QUERY statement.sql = ensure_where_clause(query) - return ParsedStatement(StatementType.UPDATE, statement) + return StatementType.UPDATE def sql_pyformat_args_to_spanner(sql, params): diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 4f633c7b10..2d62a5306a 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -35,6 +35,13 @@ class ClientSideStatementType(Enum): START_BATCH_DML = 6 RUN_BATCH = 7 ABORT_BATCH = 8 + PARTITION_QUERY = 9 + RUN_PARTITION = 10 + + +class ClientSideStatementParamKey(Enum): + PARTITIONED_SQL_QUERY = 1 + PARTITION_ID = 2 @dataclass @@ -53,3 +60,4 @@ class ParsedStatement: statement_type: StatementType statement: Statement client_side_statement_type: ClientSideStatementType = None + client_side_statement_params: dict[ClientSideStatementParamKey, Any] = None diff --git a/google/cloud/spanner_dbapi/partition_helper.py b/google/cloud/spanner_dbapi/partition_helper.py new file mode 100644 index 0000000000..f7ac4e2db6 --- /dev/null +++ b/google/cloud/spanner_dbapi/partition_helper.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Any + +import gzip +import pickle +import base64 + + +def decode_from_string(encoded_partition_id): + gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8")) + partition_id_bytes = gzip.decompress(gzip_bytes) + return pickle.loads(partition_id_bytes) + + +def encode_to_string(batch_transaction_id, batch_result): + partition_id = PartitionId(batch_transaction_id, batch_result) + partition_id_bytes = pickle.dumps(partition_id) + gzip_bytes = gzip.compress(partition_id_bytes) + return str(base64.b64encode(gzip_bytes), "utf-8") + + +@dataclass +class BatchTransactionId: + transaction_id: str + session_id: str + read_timestamp: Any + + +@dataclass +class PartitionId: + batch_transaction_id: BatchTransactionId + batch_result: Any diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 758547cf86..e148921e40 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -16,6 +16,7 @@ import copy import functools + import grpc import logging import re @@ -39,6 +40,7 @@ from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect +from google.cloud.spanner_dbapi.partition_helper import BatchTransactionId from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import TransactionSelector from google.cloud.spanner_v1 import TransactionOptions @@ -746,7 +748,13 @@ def mutation_groups(self): """ return MutationGroupsCheckout(self) - def batch_snapshot(self, read_timestamp=None, exact_staleness=None): + def batch_snapshot( + self, + read_timestamp=None, + exact_staleness=None, + session_id=None, + transaction_id=None, + ): """Return an object which wraps a batch read / query. :type read_timestamp: :class:`datetime.datetime` @@ -756,11 +764,21 @@ def batch_snapshot(self, read_timestamp=None, exact_staleness=None): :param exact_staleness: Execute all reads at a timestamp that is ``exact_staleness`` old. + :type session_id: str + :param session_id: id of the session used in transaction + + :type transaction_id: str + :param transaction_id: id of the transaction + :rtype: :class:`~google.cloud.spanner_v1.database.BatchSnapshot` :returns: new wrapper """ return BatchSnapshot( - self, read_timestamp=read_timestamp, exact_staleness=exact_staleness + self, + read_timestamp=read_timestamp, + exact_staleness=exact_staleness, + session_id=session_id, + transaction_id=transaction_id, ) def run_in_transaction(self, func, *args, **kw): @@ -1138,10 +1156,19 @@ class BatchSnapshot(object): ``exact_staleness`` old. """ - def __init__(self, database, read_timestamp=None, exact_staleness=None): + def __init__( + self, + database, + read_timestamp=None, + exact_staleness=None, + session_id=None, + transaction_id=None, + ): self._database = database + self._session_id = session_id self._session = None self._snapshot = None + self._transaction_id = transaction_id self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness @@ -1189,7 +1216,10 @@ def _get_session(self): """ if self._session is None: session = self._session = self._database.session() - session.create() + if self._session_id is None: + session.create() + else: + session._session_id = self._session_id return self._session def _get_snapshot(self): @@ -1199,10 +1229,22 @@ def _get_snapshot(self): read_timestamp=self._read_timestamp, exact_staleness=self._exact_staleness, multi_use=True, + transaction_id=self._transaction_id, ) - self._snapshot.begin() + if self._transaction_id is None: + self._snapshot.begin() return self._snapshot + def get_batch_transaction_id(self): + snapshot = self._snapshot + if snapshot is None: + raise ValueError("Read-only transaction not begun") + return BatchTransactionId( + snapshot._transaction_id, + snapshot._session.session_id, + snapshot._read_timestamp, + ) + def read(self, *args, **kw): """Convenience method: perform read operation via snapshot. diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 1e515bd8e6..642413ff1f 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -712,6 +712,7 @@ def __init__( max_staleness=None, exact_staleness=None, multi_use=False, + transaction_id=None, ): super(Snapshot, self).__init__(session) opts = [read_timestamp, min_read_timestamp, max_staleness, exact_staleness] @@ -734,6 +735,7 @@ def __init__( self._max_staleness = max_staleness self._exact_staleness = exact_staleness self._multi_use = multi_use + self._transaction_id = transaction_id def _make_txn_selector(self): """Helper for :meth:`read`.""" diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index fdea0b0d17..82679e049e 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -536,6 +536,24 @@ def test_batch_dml_invalid_statements(self): with pytest.raises(OperationalError): self._cursor.execute("run batch") + def test_partitioned_query(self): + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + self._conn.read_only = True + self._cursor.execute("PARTITION SELECT * FROM contacts") + partition_id_rows = self._cursor.fetchall() + assert len(partition_id_rows) == 1 + + partition_id_row = partition_id_rows[0] + self._cursor.execute("RUN PARTITION " + partition_id_row[0]) + rows = self._cursor.fetchall() + assert len(rows) == 10 + self._conn.commit() + def _insert_row(self, i): self._cursor.execute( f""" diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 7f179d6d31..1e9ebc4ab9 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -15,9 +15,16 @@ import sys import unittest -from google.cloud.spanner_dbapi.parsed_statement import StatementType +from google.cloud.spanner_dbapi.parsed_statement import ( + StatementType, + ParsedStatement, + Statement, + ClientSideStatementType, + ClientSideStatementParamKey, +) from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1 import JsonObject +from google.cloud.spanner_dbapi.parse_utils import classify_statement class TestParseUtils(unittest.TestCase): @@ -25,8 +32,6 @@ class TestParseUtils(unittest.TestCase): skip_message = "Subtests are not supported in Python 2" def test_classify_stmt(self): - from google.cloud.spanner_dbapi.parse_utils import classify_statement - cases = ( ("SELECT 1", StatementType.QUERY), ("SELECT s.SongName FROM Songs AS s", StatementType.QUERY), @@ -71,6 +76,34 @@ def test_classify_stmt(self): for query, want_class in cases: self.assertEqual(classify_statement(query).statement_type, want_class) + def test_partition_query_classify_stmt(self): + parsed_statement = classify_statement( + " PARTITION SELECT s.SongName FROM Songs AS s " + ) + self.assertEqual( + parsed_statement, + ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("PARTITION SELECT s.SongName FROM Songs AS s"), + ClientSideStatementType.PARTITION_QUERY, + { + ClientSideStatementParamKey.PARTITIONED_SQL_QUERY: "SELECT s.SongName FROM Songs AS s" + }, + ), + ) + + def test_run_partition_classify_stmt(self): + parsed_statement = classify_statement(" RUN PARTITION bj2bjb2j2bj2ebbh ") + self.assertEqual( + parsed_statement, + ParsedStatement( + StatementType.CLIENT_SIDE, + Statement("RUN PARTITION bj2bjb2j2bj2ebbh"), + ClientSideStatementType.RUN_PARTITION, + {ClientSideStatementParamKey.PARTITION_ID: "bj2bjb2j2bj2ebbh"}, + ), + ) + @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self): from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index cac45a26ac..87c6d8602d 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -2117,7 +2117,10 @@ def test__get_snapshot_new_wo_staleness(self): snapshot = session.snapshot.return_value = self._make_snapshot() self.assertIs(batch_txn._get_snapshot(), snapshot) session.snapshot.assert_called_once_with( - read_timestamp=None, exact_staleness=None, multi_use=True + read_timestamp=None, + exact_staleness=None, + multi_use=True, + transaction_id=None, ) snapshot.begin.assert_called_once_with() @@ -2129,7 +2132,10 @@ def test__get_snapshot_w_read_timestamp(self): snapshot = session.snapshot.return_value = self._make_snapshot() self.assertIs(batch_txn._get_snapshot(), snapshot) session.snapshot.assert_called_once_with( - read_timestamp=timestamp, exact_staleness=None, multi_use=True + read_timestamp=timestamp, + exact_staleness=None, + multi_use=True, + transaction_id=None, ) snapshot.begin.assert_called_once_with() @@ -2141,7 +2147,10 @@ def test__get_snapshot_w_exact_staleness(self): snapshot = session.snapshot.return_value = self._make_snapshot() self.assertIs(batch_txn._get_snapshot(), snapshot) session.snapshot.assert_called_once_with( - read_timestamp=None, exact_staleness=duration, multi_use=True + read_timestamp=None, + exact_staleness=duration, + multi_use=True, + transaction_id=None, ) snapshot.begin.assert_called_once_with() From f0e47aa43f45931b44616dd39a77468015bcfdfb Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 3 Jan 2024 11:54:07 +0530 Subject: [PATCH 2/6] Comments incorporated and added more tests --- google/cloud/spanner_dbapi/connection.py | 8 ++- .../cloud/spanner_dbapi/parsed_statement.py | 6 +- .../cloud/spanner_dbapi/partition_helper.py | 20 +++++- tests/system/test_dbapi.py | 63 +++++++++++++++++++ 4 files changed, 90 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 17daf51e84..7b1966244e 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -607,6 +607,12 @@ def partition_query( raise ProgrammingError( "Only queries can be partitioned. Invalid statement: " + statement.sql ) + if self.read_only is not True and self._client_transaction_started is True: + raise ProgrammingError( + "Partitioned query not supported as the connection is not in " + "read only mode or ReadWrite transaction started" + ) + batch_snapshot = self._database.batch_snapshot() partition_ids = [] partitions = list( @@ -635,7 +641,7 @@ def run_partition(self, batch_transaction_id): session_id=batch_transaction_id.session_id, transaction_id=batch_transaction_id.transaction_id, ) - return batch_snapshot.process(partition_id.batch_result) + return batch_snapshot.process(partition_id.partition_result) def __enter__(self): return self diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 2d62a5306a..02dd2676cf 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -1,4 +1,4 @@ -# Copyright 20203 Google LLC All rights reserved. +# Copyright 2023 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import Any, Dict from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -60,4 +60,4 @@ class ParsedStatement: statement_type: StatementType statement: Statement client_side_statement_type: ClientSideStatementType = None - client_side_statement_params: dict[ClientSideStatementParamKey, Any] = None + client_side_statement_params: Dict[ClientSideStatementParamKey, Any] = None diff --git a/google/cloud/spanner_dbapi/partition_helper.py b/google/cloud/spanner_dbapi/partition_helper.py index f7ac4e2db6..94b396c801 100644 --- a/google/cloud/spanner_dbapi/partition_helper.py +++ b/google/cloud/spanner_dbapi/partition_helper.py @@ -1,3 +1,17 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Any @@ -12,8 +26,8 @@ def decode_from_string(encoded_partition_id): return pickle.loads(partition_id_bytes) -def encode_to_string(batch_transaction_id, batch_result): - partition_id = PartitionId(batch_transaction_id, batch_result) +def encode_to_string(batch_transaction_id, partition_result): + partition_id = PartitionId(batch_transaction_id, partition_result) partition_id_bytes = pickle.dumps(partition_id) gzip_bytes = gzip.compress(partition_id_bytes) return str(base64.b64encode(gzip_bytes), "utf-8") @@ -29,4 +43,4 @@ class BatchTransactionId: @dataclass class PartitionId: batch_transaction_id: BatchTransactionId - batch_result: Any + partition_result: Any diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 82679e049e..5dac502648 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -537,6 +537,7 @@ def test_batch_dml_invalid_statements(self): self._cursor.execute("run batch") def test_partitioned_query(self): + """Test partition query works in read-only mode.""" self._cursor.execute("start batch dml") for i in range(1, 11): self._insert_row(i) @@ -554,6 +555,68 @@ def test_partitioned_query(self): assert len(rows) == 10 self._conn.commit() + def test_partitioned_query_in_rw_transaction(self): + """Test partition query throws exception when connection is not in + read-only mode and neither in auto-commit mode.""" + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + with pytest.raises(ProgrammingError): + self._cursor.execute("PARTITION SELECT * FROM contacts") + + def test_partitioned_query_with_dml_query(self): + """Test partition query throws exception when sql query is a DML query.""" + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + self._conn.read_only = True + with pytest.raises(ProgrammingError): + self._cursor.execute( + f""" + PARTITION INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1111, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + + def test_partitioned_query_in_autocommit_mode(self): + """Test partition query works when connection is not in read-only mode + but is in auto-commit mode.""" + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + self._conn.autocommit = True + self._cursor.execute("PARTITION SELECT * FROM contacts") + partition_id_rows = self._cursor.fetchall() + assert len(partition_id_rows) == 1 + + partition_id_row = partition_id_rows[0] + self._cursor.execute("RUN PARTITION " + partition_id_row[0]) + rows = self._cursor.fetchall() + assert len(rows) == 10 + + def test_partitioned_query_with_client_transaction_started(self): + """Test partition query throws exception when connection is in + auto-commit mode but transaction started using client side statement.""" + self._cursor.execute("start batch dml") + for i in range(1, 11): + self._insert_row(i) + self._cursor.execute("run batch") + self._conn.commit() + + self._conn.autocommit = True + self._cursor.execute("begin transaction") + with pytest.raises(ProgrammingError): + self._cursor.execute("PARTITION SELECT * FROM contacts") + def _insert_row(self, i): self._cursor.execute( f""" From a3cbdc7c719ba829e21ab860235e0b2f86bb65d6 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 3 Jan 2024 12:14:43 +0530 Subject: [PATCH 3/6] Small fix --- .../client_side_statement_executor.py | 27 ++++++++++--------- tests/system/test_dbapi.py | 8 +++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index 1246c1bdd0..01590486f0 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -51,6 +51,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): :param parsed_statement: parsed_statement based on the sql query """ connection = cursor.connection + column_values = [] if connection.is_closed: raise ProgrammingError(CONNECTION_CLOSED_ERROR) statement_type = parsed_statement.client_side_statement_type @@ -64,24 +65,26 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): connection.rollback() return None if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP: - if connection._transaction is None: - committed_timestamp = None - else: - committed_timestamp = list(connection._transaction.committed) + if ( + connection._transaction is not None + and connection._transaction.committed is not None + ): + column_values.append(connection._transaction.committed) return _get_streamed_result_set( ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name, TypeCode.TIMESTAMP, - committed_timestamp, + column_values, ) if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP: - if connection._snapshot is None: - read_timestamp = None - else: - read_timestamp = list(connection._snapshot._transaction_read_timestamp) + if ( + connection._snapshot is not None + and connection._snapshot._transaction_read_timestamp is not None + ): + column_values.append(connection._snapshot._transaction_read_timestamp) return _get_streamed_result_set( ClientSideStatementType.SHOW_READ_TIMESTAMP.name, TypeCode.TIMESTAMP, - read_timestamp, + column_values, ) if statement_type == ClientSideStatementType.START_BATCH_DML: connection.start_batch_dml(cursor) @@ -111,8 +114,8 @@ def _get_streamed_result_set(column_name, type_code, column_values): ) result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb)) - column_values_pb = [] - if column_values is not None: + if len(column_values) > 0: + column_values_pb = [] for column_value in column_values: column_values_pb.append(_make_value_pb(column_value)) result_set.values.extend(column_values_pb) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 5dac502648..4cd1ec2330 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -578,10 +578,10 @@ def test_partitioned_query_with_dml_query(self): self._conn.read_only = True with pytest.raises(ProgrammingError): self._cursor.execute( - f""" - PARTITION INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (1111, 'first-name', 'last-name', 'test.email@domen.ru') - """ + """ + PARTITION INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1111, 'first-name', 'last-name', 'test.email@domen.ru') + """ ) def test_partitioned_query_in_autocommit_mode(self): From 8ec04ffbd1db9ea0a296da0650faa4f1efc05f55 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 4 Jan 2024 14:56:23 +0530 Subject: [PATCH 4/6] Test fix --- tests/system/test_dbapi.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 4cd1ec2330..1d5f3f5e78 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -547,11 +547,12 @@ def test_partitioned_query(self): self._conn.read_only = True self._cursor.execute("PARTITION SELECT * FROM contacts") partition_id_rows = self._cursor.fetchall() - assert len(partition_id_rows) == 1 + assert len(partition_id_rows) > 0 - partition_id_row = partition_id_rows[0] - self._cursor.execute("RUN PARTITION " + partition_id_row[0]) - rows = self._cursor.fetchall() + rows = [] + for partition_id_row in partition_id_rows: + self._cursor.execute("RUN PARTITION " + partition_id_row[0]) + rows = rows + self._cursor.fetchall() assert len(rows) == 10 self._conn.commit() @@ -596,12 +597,14 @@ def test_partitioned_query_in_autocommit_mode(self): self._conn.autocommit = True self._cursor.execute("PARTITION SELECT * FROM contacts") partition_id_rows = self._cursor.fetchall() - assert len(partition_id_rows) == 1 + assert len(partition_id_rows) > 0 - partition_id_row = partition_id_rows[0] - self._cursor.execute("RUN PARTITION " + partition_id_row[0]) - rows = self._cursor.fetchall() + rows = [] + for partition_id_row in partition_id_rows: + self._cursor.execute("RUN PARTITION " + partition_id_row[0]) + rows = rows + self._cursor.fetchall() assert len(rows) == 10 + self._conn.commit() def test_partitioned_query_with_client_transaction_started(self): """Test partition query throws exception when connection is in From fd8db526b8bc679da37b47998751513854b0905a Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 4 Jan 2024 19:34:51 +0530 Subject: [PATCH 5/6] Removing ClientSideStatementParamKey enum --- .../spanner_dbapi/client_side_statement_executor.py | 5 +---- .../spanner_dbapi/client_side_statement_parser.py | 11 +++-------- google/cloud/spanner_dbapi/connection.py | 5 +---- google/cloud/spanner_dbapi/parsed_statement.py | 9 ++------- tests/unit/spanner_dbapi/test_parse_utils.py | 7 ++----- 5 files changed, 9 insertions(+), 28 deletions(-) diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index 01590486f0..4d3408218c 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -20,7 +20,6 @@ from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, ClientSideStatementType, - ClientSideStatementParamKey, ) from google.cloud.spanner_v1 import ( Type, @@ -102,9 +101,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement): ) if statement_type == ClientSideStatementType.RUN_PARTITION: return connection.run_partition( - parsed_statement.client_side_statement_params[ - ClientSideStatementParamKey.PARTITION_ID - ] + parsed_statement.client_side_statement_params[0] ) diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index 85dbca4eb4..04a3cc523c 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -19,7 +19,6 @@ StatementType, ClientSideStatementType, Statement, - ClientSideStatementParamKey, ) RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) @@ -51,7 +50,7 @@ def parse_stmt(query): :returns: ParsedStatement object. """ client_side_statement_type = None - client_side_statement_params = {} + client_side_statement_params = [] if RE_COMMIT.match(query): client_side_statement_type = ClientSideStatementType.COMMIT if RE_BEGIN.match(query): @@ -70,15 +69,11 @@ def parse_stmt(query): client_side_statement_type = ClientSideStatementType.ABORT_BATCH if RE_PARTITION_QUERY.match(query): match = re.search(RE_PARTITION_QUERY, query) - client_side_statement_params[ - ClientSideStatementParamKey.PARTITIONED_SQL_QUERY - ] = match.group(2) + client_side_statement_params.append(match.group(2)) client_side_statement_type = ClientSideStatementType.PARTITION_QUERY if RE_RUN_PARTITION.match(query): match = re.search(RE_RUN_PARTITION, query) - client_side_statement_params[ - ClientSideStatementParamKey.PARTITION_ID - ] = match.group(3) + client_side_statement_params.append(match.group(3)) client_side_statement_type = ClientSideStatementType.RUN_PARTITION if client_side_statement_type is not None: return ParsedStatement( diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 7b1966244e..47680fd550 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -26,7 +26,6 @@ ParsedStatement, Statement, StatementType, - ClientSideStatementParamKey, ) from google.cloud.spanner_dbapi.partition_helper import PartitionId from google.cloud.spanner_v1 import RequestOptions @@ -600,9 +599,7 @@ def partition_query( query_options=None, ): statement = parsed_statement.statement - partitioned_query = parsed_statement.client_side_statement_params[ - ClientSideStatementParamKey.PARTITIONED_SQL_QUERY - ] + partitioned_query = parsed_statement.client_side_statement_params[0] if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY: raise ProgrammingError( "Only queries can be partitioned. Invalid statement: " + statement.sql diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 02dd2676cf..798f5126c3 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass from enum import Enum -from typing import Any, Dict +from typing import Any, List from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -39,11 +39,6 @@ class ClientSideStatementType(Enum): RUN_PARTITION = 10 -class ClientSideStatementParamKey(Enum): - PARTITIONED_SQL_QUERY = 1 - PARTITION_ID = 2 - - @dataclass class Statement: sql: str @@ -60,4 +55,4 @@ class ParsedStatement: statement_type: StatementType statement: Statement client_side_statement_type: ClientSideStatementType = None - client_side_statement_params: Dict[ClientSideStatementParamKey, Any] = None + client_side_statement_params: List[Any] = None diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 1e9ebc4ab9..de7b9a6dce 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -20,7 +20,6 @@ ParsedStatement, Statement, ClientSideStatementType, - ClientSideStatementParamKey, ) from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1 import JsonObject @@ -86,9 +85,7 @@ def test_partition_query_classify_stmt(self): StatementType.CLIENT_SIDE, Statement("PARTITION SELECT s.SongName FROM Songs AS s"), ClientSideStatementType.PARTITION_QUERY, - { - ClientSideStatementParamKey.PARTITIONED_SQL_QUERY: "SELECT s.SongName FROM Songs AS s" - }, + ["SELECT s.SongName FROM Songs AS s"], ), ) @@ -100,7 +97,7 @@ def test_run_partition_classify_stmt(self): StatementType.CLIENT_SIDE, Statement("RUN PARTITION bj2bjb2j2bj2ebbh"), ClientSideStatementType.RUN_PARTITION, - {ClientSideStatementParamKey.PARTITION_ID: "bj2bjb2j2bj2ebbh"}, + ["bj2bjb2j2bj2ebbh"], ), ) From 5eb91f158b63020424e32dcd03b82cad5070bfdf Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 10 Jan 2024 11:55:37 +0530 Subject: [PATCH 6/6] Comments incorporated --- tests/system/test_dbapi.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 1d5f3f5e78..18bde6c94d 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -559,22 +559,12 @@ def test_partitioned_query(self): def test_partitioned_query_in_rw_transaction(self): """Test partition query throws exception when connection is not in read-only mode and neither in auto-commit mode.""" - self._cursor.execute("start batch dml") - for i in range(1, 11): - self._insert_row(i) - self._cursor.execute("run batch") - self._conn.commit() with pytest.raises(ProgrammingError): self._cursor.execute("PARTITION SELECT * FROM contacts") def test_partitioned_query_with_dml_query(self): """Test partition query throws exception when sql query is a DML query.""" - self._cursor.execute("start batch dml") - for i in range(1, 11): - self._insert_row(i) - self._cursor.execute("run batch") - self._conn.commit() self._conn.read_only = True with pytest.raises(ProgrammingError): @@ -604,16 +594,10 @@ def test_partitioned_query_in_autocommit_mode(self): self._cursor.execute("RUN PARTITION " + partition_id_row[0]) rows = rows + self._cursor.fetchall() assert len(rows) == 10 - self._conn.commit() def test_partitioned_query_with_client_transaction_started(self): - """Test partition query throws exception when connection is in - auto-commit mode but transaction started using client side statement.""" - self._cursor.execute("start batch dml") - for i in range(1, 11): - self._insert_row(i) - self._cursor.execute("run batch") - self._conn.commit() + """Test partition query throws exception when connection is not in + read-only mode and transaction started using client side statement.""" self._conn.autocommit = True self._cursor.execute("begin transaction")