Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation for partitioned query in dbapi #1067

Merged
merged 7 commits into from
Jan 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Implementation for partitioned query in dbapi
ankiaga committed Jan 2, 2024
commit b5b0c98354954326e6dce8b9ddf2c4d7ac009217
27 changes: 22 additions & 5 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
ClientSideStatementParamKey,
)
from google.cloud.spanner_v1 import (
Type,
@@ -66,7 +67,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
if connection._transaction is None:
committed_timestamp = None
else:
committed_timestamp = connection._transaction.committed
committed_timestamp = list(connection._transaction.committed)
return _get_streamed_result_set(
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
TypeCode.TIMESTAMP,
@@ -76,7 +77,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
if connection._snapshot is None:
read_timestamp = None
else:
read_timestamp = connection._snapshot._transaction_read_timestamp
read_timestamp = list(connection._snapshot._transaction_read_timestamp)
return _get_streamed_result_set(
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
TypeCode.TIMESTAMP,
@@ -89,14 +90,30 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
return connection.run_batch()
if statement_type == ClientSideStatementType.ABORT_BATCH:
return connection.abort_batch()
if statement_type == ClientSideStatementType.PARTITION_QUERY:
partition_ids = connection.partition_query(parsed_statement)
return _get_streamed_result_set(
"PARTITION",
TypeCode.STRING,
partition_ids,
)
if statement_type == ClientSideStatementType.RUN_PARTITION:
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
return connection.run_partition(
parsed_statement.client_side_statement_params[
ClientSideStatementParamKey.PARTITION_ID
]
)


def _get_streamed_result_set(column_name, type_code, column_values):
    """Build a single-column, client-side ``StreamedResultSet``.

    Used to surface client-side statement results (timestamps, partition
    ids, ...) through the normal cursor result-set machinery.

    :param column_name: name of the single result column.
    :param type_code: ``TypeCode`` of the column.
    :param column_values: iterable of row values, or ``None`` for an empty
        result set (metadata only, no rows).
    :returns: a ``StreamedResultSet`` wrapping one ``PartialResultSet``.
    """
    struct_type_pb = StructType(
        fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
    )

    result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
    if column_values is not None:
        # One value protobuf per row; extend once instead of per-item appends.
        result_set.values.extend(
            [_make_value_pb(column_value) for column_value in column_values]
        )
    return StreamedResultSet(iter([result_set]))
21 changes: 20 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
StatementType,
ClientSideStatementType,
Statement,
ClientSideStatementParamKey,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
@@ -33,6 +34,8 @@
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)


def parse_stmt(query):
@@ -48,6 +51,7 @@ def parse_stmt(query):
:returns: ParsedStatement object.
"""
client_side_statement_type = None
client_side_statement_params = {}
if RE_COMMIT.match(query):
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
@@ -64,8 +68,23 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if RE_PARTITION_QUERY.match(query):
match = re.search(RE_PARTITION_QUERY, query)
client_side_statement_params[
ClientSideStatementParamKey.PARTITIONED_SQL_QUERY
] = match.group(2)
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
if RE_RUN_PARTITION.match(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params[
ClientSideStatementParamKey.PARTITION_ID
] = match.group(3)
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
StatementType.CLIENT_SIDE,
Statement(query),
client_side_statement_type,
client_side_statement_params,
)
return None
54 changes: 53 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
@@ -19,8 +19,16 @@
from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi import partition_helper
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
Statement,
StatementType,
ClientSideStatementParamKey,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
@@ -585,6 +593,50 @@ def abort_batch(self):
self._batch_dml_executor = None
self._batch_mode = BatchMode.NONE

@check_not_closed
def partition_query(
    self,
    parsed_statement: ParsedStatement,
    query_options=None,
):
    """Partition a query and return one opaque id string per partition.

    Implements the client-side ``PARTITION <sql>`` statement: the SQL to
    partition is carried in
    ``parsed_statement.client_side_statement_params[PARTITIONED_SQL_QUERY]``.
    Each returned string encodes the batch transaction id plus one
    partition, and can later be executed with ``RUN PARTITION`` /
    :meth:`run_partition` (possibly on a different connection).

    :param parsed_statement: parsed client-side ``PARTITION`` statement.
    :param query_options: optional query options forwarded to
        ``generate_query_batches``.
    :returns: list of encoded partition-id strings (possibly empty).
    :raises ProgrammingError: if the wrapped statement is not a query.
    """
    statement = parsed_statement.statement
    partitioned_query = parsed_statement.client_side_statement_params[
        ClientSideStatementParamKey.PARTITIONED_SQL_QUERY
    ]
    # Only DQL can be partitioned; partitioned DML goes through a
    # different code path entirely.
    if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
        raise ProgrammingError(
            "Only queries can be partitioned. Invalid statement: " + statement.sql
        )
    batch_snapshot = self._database.batch_snapshot()
    partitions = list(
        batch_snapshot.generate_query_batches(
            partitioned_query,
            statement.params,
            statement.param_types,
            query_options=query_options,
        )
    )
    if not partitions:
        return []
    # The batch transaction id is identical for every partition of this
    # snapshot, so fetch it once instead of once per partition.
    batch_transaction_id = batch_snapshot.get_batch_transaction_id()
    return [
        partition_helper.encode_to_string(batch_transaction_id, partition)
        for partition in partitions
    ]

@check_not_closed
def run_partition(self, batch_transaction_id):
    """Execute one previously created partition and return its results.

    Implements the client-side ``RUN PARTITION <id>`` statement. The
    argument is an opaque string produced by :meth:`partition_query`; it
    bundles the batch transaction identifiers with the partition itself,
    so the partition can be processed by re-attaching to that transaction.

    :param batch_transaction_id: encoded partition-id string.
    :returns: streamed results of processing the partition.
    """
    decoded: PartitionId = partition_helper.decode_from_string(
        batch_transaction_id
    )
    txn_info = decoded.batch_transaction_id
    # Re-attach to the original batch read-only transaction rather than
    # starting a new one, so all partitions see the same snapshot.
    batch_snapshot = self._database.batch_snapshot(
        read_timestamp=txn_info.read_timestamp,
        session_id=txn_info.session_id,
        transaction_id=txn_info.transaction_id,
    )
    return batch_snapshot.process(decoded.batch_result)

def __enter__(self):
return self

16 changes: 10 additions & 6 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
@@ -232,19 +232,23 @@ def classify_statement(query, args=None):
get_param_types(args or None),
ResultsChecksum(),
)
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, statement)
statement_type = _get_statement_type(statement)
return ParsedStatement(statement_type, statement)

if RE_IS_INSERT.match(query):
return ParsedStatement(StatementType.INSERT, statement)

def _get_statement_type(statement):
query = statement.sql
if RE_DDL.match(query):
return StatementType.DDL
if RE_IS_INSERT.match(query):
return StatementType.INSERT
if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
# statements and doesn't yet support WITH for DML statements.
return ParsedStatement(StatementType.QUERY, statement)
return StatementType.QUERY

statement.sql = ensure_where_clause(query)
return ParsedStatement(StatementType.UPDATE, statement)
return StatementType.UPDATE


def sql_pyformat_args_to_spanner(sql, params):
8 changes: 8 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,13 @@ class ClientSideStatementType(Enum):
START_BATCH_DML = 6
RUN_BATCH = 7
ABORT_BATCH = 8
PARTITION_QUERY = 9
RUN_PARTITION = 10


class ClientSideStatementParamKey(Enum):
ankiaga marked this conversation as resolved.
Show resolved Hide resolved
PARTITIONED_SQL_QUERY = 1
PARTITION_ID = 2


@dataclass
@@ -53,3 +60,4 @@ class ParsedStatement:
statement_type: StatementType
statement: Statement
client_side_statement_type: ClientSideStatementType = None
client_side_statement_params: dict[ClientSideStatementParamKey, Any] = None
32 changes: 32 additions & 0 deletions google/cloud/spanner_dbapi/partition_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Helpers to (de)serialize partition information for partitioned queries.

A partition id handed to the user is an opaque string: the pair of
``BatchTransactionId`` and the partition produced by
``generate_query_batches`` is pickled, gzip-compressed, and
base64-encoded so it can travel as plain text and later be used to resume
the partition, possibly on a different connection.
"""
from dataclasses import dataclass
from typing import Any

import base64
import gzip
import pickle


@dataclass
class BatchTransactionId:
    # Identifiers needed to re-attach to an existing batch read-only
    # transaction from another connection.
    transaction_id: str
    session_id: str
    read_timestamp: Any


@dataclass
class PartitionId:
    # The transaction the partition belongs to, plus the opaque partition
    # object returned by generate_query_batches().
    batch_transaction_id: BatchTransactionId
    batch_result: Any


def encode_to_string(batch_transaction_id, batch_result):
    """Encode a (transaction id, partition) pair into an opaque string.

    :param batch_transaction_id: :class:`BatchTransactionId` of the batch
        read-only transaction the partition belongs to.
    :param batch_result: one partition from ``generate_query_batches``.
    :returns: base64 text safe to pass through SQL statements.
    """
    partition_id = PartitionId(batch_transaction_id, batch_result)
    partition_id_bytes = pickle.dumps(partition_id)
    gzip_bytes = gzip.compress(partition_id_bytes)
    return str(base64.b64encode(gzip_bytes), "utf-8")


def decode_from_string(encoded_partition_id):
    """Decode a string produced by :func:`encode_to_string`.

    SECURITY NOTE: this unpickles the decoded payload. Partition-id
    strings must only come from a trusted source (i.e. produced by
    :func:`encode_to_string` by this application) — ``pickle.loads`` on
    attacker-controlled data can execute arbitrary code.

    :param encoded_partition_id: base64 text from :func:`encode_to_string`.
    :returns: the decoded :class:`PartitionId`.
    """
    gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
    partition_id_bytes = gzip.decompress(gzip_bytes)
    return pickle.loads(partition_id_bytes)
52 changes: 47 additions & 5 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

import copy
import functools

import grpc
import logging
import re
@@ -39,6 +40,7 @@
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_dbapi.partition_helper import BatchTransactionId
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import TransactionSelector
from google.cloud.spanner_v1 import TransactionOptions
@@ -746,7 +748,13 @@ def mutation_groups(self):
"""
return MutationGroupsCheckout(self)

def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
def batch_snapshot(
self,
read_timestamp=None,
exact_staleness=None,
session_id=None,
transaction_id=None,
):
"""Return an object which wraps a batch read / query.

:type read_timestamp: :class:`datetime.datetime`
@@ -756,11 +764,21 @@ def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
:param exact_staleness: Execute all reads at a timestamp that is
``exact_staleness`` old.

:type session_id: str
:param session_id: id of the session used in transaction

:type transaction_id: str
:param transaction_id: id of the transaction

:rtype: :class:`~google.cloud.spanner_v1.database.BatchSnapshot`
:returns: new wrapper
"""
return BatchSnapshot(
self, read_timestamp=read_timestamp, exact_staleness=exact_staleness
self,
read_timestamp=read_timestamp,
exact_staleness=exact_staleness,
session_id=session_id,
transaction_id=transaction_id,
)

def run_in_transaction(self, func, *args, **kw):
@@ -1138,10 +1156,19 @@ class BatchSnapshot(object):
``exact_staleness`` old.
"""

def __init__(self, database, read_timestamp=None, exact_staleness=None):
def __init__(
self,
database,
read_timestamp=None,
exact_staleness=None,
session_id=None,
transaction_id=None,
):
self._database = database
self._session_id = session_id
self._session = None
self._snapshot = None
self._transaction_id = transaction_id
self._read_timestamp = read_timestamp
self._exact_staleness = exact_staleness

@@ -1189,7 +1216,10 @@ def _get_session(self):
"""
if self._session is None:
session = self._session = self._database.session()
session.create()
if self._session_id is None:
session.create()
else:
session._session_id = self._session_id
return self._session

def _get_snapshot(self):
@@ -1199,10 +1229,22 @@ def _get_snapshot(self):
read_timestamp=self._read_timestamp,
exact_staleness=self._exact_staleness,
multi_use=True,
transaction_id=self._transaction_id,
)
self._snapshot.begin()
if self._transaction_id is None:
self._snapshot.begin()
return self._snapshot

def get_batch_transaction_id(self):
    """Return the identifiers of the underlying batch transaction.

    Bundles transaction id, session id, and read timestamp so the
    transaction can be re-attached to later (e.g. to run a partition on
    another connection).

    :raises ValueError: if no read-only transaction has begun yet.
    """
    if self._snapshot is None:
        raise ValueError("Read-only transaction not begun")
    snapshot = self._snapshot
    return BatchTransactionId(
        snapshot._transaction_id,
        snapshot._session.session_id,
        snapshot._read_timestamp,
    )

def read(self, *args, **kw):
"""Convenience method: perform read operation via snapshot.

2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
@@ -712,6 +712,7 @@ def __init__(
max_staleness=None,
exact_staleness=None,
multi_use=False,
transaction_id=None,
):
super(Snapshot, self).__init__(session)
opts = [read_timestamp, min_read_timestamp, max_staleness, exact_staleness]
@@ -734,6 +735,7 @@ def __init__(
self._max_staleness = max_staleness
self._exact_staleness = exact_staleness
self._multi_use = multi_use
self._transaction_id = transaction_id

def _make_txn_selector(self):
"""Helper for :meth:`read`."""
Loading