feat: add databoost enabled property for batch transactions (#892)
* proto changes

* changes

* changes

* linting

* changes

* changes

* changes

* changes

* changes

* Changes

* Update google/cloud/spanner_v1/snapshot.py

Co-authored-by: Rajat Bhatta <[email protected]>

* Update google/cloud/spanner_v1/database.py

Co-authored-by: Rajat Bhatta <[email protected]>

---------

Co-authored-by: Rajat Bhatta <[email protected]>
asthamohta and rajatbhatta authored Apr 6, 2023
1 parent 1f4a3ca commit ffb3915
Showing 5 changed files with 133 additions and 7 deletions.
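For orientation, here is a minimal sketch of how the new data_boost_enabled flag is meant to be used from a BatchSnapshot; the instance ID, database ID, table, and query are placeholders, not part of this commit.

    # Minimal sketch only: IDs, table, and query below are placeholders.
    from google.cloud import spanner

    client = spanner.Client()
    database = client.instance("my-instance").database("my-database")

    # BatchSnapshot partitions the query; each partition can be consumed independently.
    batch_txn = database.batch_snapshot()
    batches = batch_txn.generate_query_batches(
        "SELECT SingerId, FirstName FROM Singers",
        data_boost_enabled=True,  # new flag added by this commit; defaults to False
    )
    for batch in batches:
        for row in batch_txn.process(batch):
            print(row)
    batch_txn.close()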
18 changes: 17 additions & 1 deletion google/cloud/spanner_v1/database.py
@@ -1101,6 +1101,7 @@ def generate_read_batches(
index="",
partition_size_bytes=None,
max_partitions=None,
data_boost_enabled=False,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
@@ -1135,6 +1136,11 @@ def generate_read_batches(
service uses this as a hint, the actual number of partitions may
differ.
:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned read and this field is
set to ``true``, the request will be executed via offline access.
:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.
@@ -1162,6 +1168,7 @@ def generate_read_batches(
"columns": columns,
"keyset": keyset._to_dict(),
"index": index,
"data_boost_enabled": data_boost_enabled,
}
for partition in partitions:
yield {"partition": partition, "read": read_info.copy()}
@@ -1205,6 +1212,7 @@ def generate_query_batches(
partition_size_bytes=None,
max_partitions=None,
query_options=None,
data_boost_enabled=False,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
@@ -1251,6 +1259,11 @@ def generate_query_batches(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`
:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned query and this field is
set to ``true``, the request will be executed via offline access.
:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.
@@ -1272,7 +1285,10 @@ def generate_query_batches(
timeout=timeout,
)

query_info = {"sql": sql}
query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types
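Each yielded batch is a plain dict; a hedged sketch of the shapes the two generators now produce (table, columns, and SQL are placeholders) mirrors the read_info / query_info dicts built above:

    # Sketch only: keys mirror read_info / query_info in database.py.
    read_batch = {
        "partition": b"<partition token>",
        "read": {
            "table": "Singers",
            "columns": ("SingerId", "FirstName"),
            "keyset": {"all": True},
            "index": "",
            "data_boost_enabled": True,  # newly propagated flag
        },
    }
    query_batch = {
        "partition": b"<partition token>",
        "query": {
            "sql": "SELECT SingerId, FirstName FROM Singers",
            "data_boost_enabled": True,  # newly propagated flag
            # "params", "param_types", and "query_options" appear when supplied
        },
    }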
20 changes: 20 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
@@ -167,6 +167,7 @@ def read(
limit=0,
partition=None,
request_options=None,
data_boost_enabled=False,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
@@ -210,6 +211,14 @@ def read(
:type timeout: float
:param timeout: (Optional) The timeout for this request.
:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned read and this field is
set to ``true``, the request will be executed via offline access.
If the field is set to ``true`` but the request does not set
``partition_token``, the API will return an
``INVALID_ARGUMENT`` error.
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
@@ -247,6 +256,7 @@ def read(
limit=limit,
partition_token=partition,
request_options=request_options,
data_boost_enabled=data_boost_enabled,
)
restart = functools.partial(
api.streaming_read,
@@ -302,6 +312,7 @@ def execute_sql(
partition=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
data_boost_enabled=False,
):
"""Perform an ``ExecuteStreamingSql`` API request.
@@ -351,6 +362,14 @@ def execute_sql(
:type timeout: float
:param timeout: (Optional) The timeout for this request.
:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned query and this field is
set to ``true``, the request will be executed via offline access.
If the field is set to ``true`` but the request does not set
``partition_token``, the API will return an
``INVALID_ARGUMENT`` error.
:raises ValueError:
for reuse of single-use snapshots, or if a transaction ID is
already pending for multiple-use snapshots.
@@ -400,6 +419,7 @@ def execute_sql(
seqno=self._execute_sql_count,
query_options=query_options,
request_options=request_options,
data_boost_enabled=data_boost_enabled,
)
restart = functools.partial(
api.execute_streaming_sql,
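The snapshot-level docstrings above tie the flag to partitioned requests; a hedged sketch of what they imply when the flag is set without a partition token (assuming google-api-core's usual mapping of INVALID_ARGUMENT to InvalidArgument):

    # Sketch only: per the docstring, data_boost_enabled=True without a
    # partition token is expected to be rejected with INVALID_ARGUMENT.
    from google.api_core import exceptions
    from google.cloud import spanner

    client = spanner.Client()
    database = client.instance("my-instance").database("my-database")

    with database.snapshot() as snapshot:
        try:
            list(snapshot.execute_sql("SELECT 1", data_boost_enabled=True))
        except exceptions.InvalidArgument as exc:
            print("Data Boost requires a partitioned request:", exc)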
7 changes: 6 additions & 1 deletion samples/samples/batch_sample.py
@@ -47,6 +47,10 @@ def run_batch_query(instance_id, database_id):
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
keyset=spanner.KeySet(all_=True),
# A Partition object is serializable and can be used from a different process.
# The data_boost_enabled option is an optional parameter that can be used for
# partitioned read and query requests to execute them via Spanner independent compute resources.
data_boost_enabled=True,
)

# Create a pool of workers for the tasks
@@ -87,4 +91,5 @@ def process(snapshot, partition):

args = parser.parse_args()

run_batch_query(args.instance_id, args.database_id)
if args.command == "run_batch_query":
run_batch_query(args.instance_id, args.database_id)
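The sample hands each partition to a pool of workers; a hedged sketch of that pattern follows (pool size, table, and IDs are assumptions, and the real sample may structure this differently):

    # Sketch only: each partition is independent and can be consumed by a separate worker.
    import concurrent.futures
    from google.cloud import spanner

    client = spanner.Client()
    database = client.instance("my-instance").database("my-database")

    snapshot = database.batch_snapshot()
    partitions = snapshot.generate_read_batches(
        table="Singers",
        columns=("SingerId", "FirstName", "LastName"),
        keyset=spanner.KeySet(all_=True),
        data_boost_enabled=True,
    )

    def process(snapshot, partition):
        # Consume one partition of the batched read.
        for row in snapshot.process(partition):
            print(row)

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(process, snapshot, partition) for partition in partitions]
        concurrent.futures.wait(futures)
    snapshot.close()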
12 changes: 8 additions & 4 deletions tests/system/test_session_api.py
@@ -1875,7 +1875,7 @@ def test_read_with_range_keys_and_index_open_open(sessions_database):
assert rows == expected


def test_partition_read_w_index(sessions_database):
def test_partition_read_w_index(sessions_database, not_emulator):
sd = _sample_data
row_count = 10
columns = sd.COLUMNS[1], sd.COLUMNS[2]
@@ -1886,7 +1886,11 @@ def test_partition_read_w_index(sessions_database, not_emulator):

batch_txn = sessions_database.batch_snapshot(read_timestamp=committed)
batches = batch_txn.generate_read_batches(
sd.TABLE, columns, spanner_v1.KeySet(all_=True), index="name"
sd.TABLE,
columns,
spanner_v1.KeySet(all_=True),
index="name",
data_boost_enabled=True,
)
for batch in batches:
p_results_iter = batch_txn.process(batch)
@@ -2494,7 +2498,7 @@ def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgre
assert math.isnan(float_array[2])


def test_partition_query(sessions_database):
def test_partition_query(sessions_database, not_emulator):
row_count = 40
sql = f"SELECT * FROM {_sample_data.TABLE}"
committed = _set_up_table(sessions_database, row_count)
@@ -2503,7 +2507,7 @@ def test_partition_query(sessions_database, not_emulator):
all_data_rows = set(_row_data(row_count))
union = set()
batch_txn = sessions_database.batch_snapshot(read_timestamp=committed)
for batch in batch_txn.generate_query_batches(sql):
for batch in batch_txn.generate_query_batches(sql, data_boost_enabled=True):
p_results_iter = batch_txn.process(batch)
# Lists aren't hashable so the results need to be converted
rows = [tuple(result) for result in p_results_iter]
83 changes: 82 additions & 1 deletion tests/unit/test_database.py
@@ -2114,6 +2114,7 @@ def test_generate_read_batches_w_max_partitions(self):
"columns": self.COLUMNS,
"keyset": {"all": True},
"index": "",
"data_boost_enabled": False,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
@@ -2155,6 +2156,7 @@ def test_generate_read_batches_w_retry_and_timeout_params(self):
"columns": self.COLUMNS,
"keyset": {"all": True},
"index": "",
"data_boost_enabled": False,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
@@ -2195,6 +2197,7 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
"columns": self.COLUMNS,
"keyset": {"all": True},
"index": self.INDEX,
"data_boost_enabled": False,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
@@ -2212,6 +2215,47 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
timeout=gapic_v1.method.DEFAULT,
)

def test_generate_read_batches_w_data_boost_enabled(self):
data_boost_enabled = True
keyset = self._make_keyset()
database = self._make_database()
batch_txn = self._make_one(database)
snapshot = batch_txn._snapshot = self._make_snapshot()
snapshot.partition_read.return_value = self.TOKENS

batches = list(
batch_txn.generate_read_batches(
self.TABLE,
self.COLUMNS,
keyset,
index=self.INDEX,
data_boost_enabled=data_boost_enabled,
)
)

expected_read = {
"table": self.TABLE,
"columns": self.COLUMNS,
"keyset": {"all": True},
"index": self.INDEX,
"data_boost_enabled": True,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
self.assertEqual(batch["partition"], token)
self.assertEqual(batch["read"], expected_read)

snapshot.partition_read.assert_called_once_with(
table=self.TABLE,
columns=self.COLUMNS,
keyset=keyset,
index=self.INDEX,
partition_size_bytes=None,
max_partitions=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_process_read_batch(self):
keyset = self._make_keyset()
token = b"TOKEN"
@@ -2288,7 +2332,11 @@ def test_generate_query_batches_w_max_partitions(self):
batch_txn.generate_query_batches(sql, max_partitions=max_partitions)
)

expected_query = {"sql": sql, "query_options": client._query_options}
expected_query = {
"sql": sql,
"data_boost_enabled": False,
"query_options": client._query_options,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
self.assertEqual(batch["partition"], token)
@@ -2326,6 +2374,7 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self):

expected_query = {
"sql": sql,
"data_boost_enabled": False,
"params": params,
"param_types": param_types,
"query_options": client._query_options,
@@ -2372,6 +2421,7 @@ def test_generate_query_batches_w_retry_and_timeout_params(self):

expected_query = {
"sql": sql,
"data_boost_enabled": False,
"params": params,
"param_types": param_types,
"query_options": client._query_options,
@@ -2391,6 +2441,37 @@ def test_generate_query_batches_w_retry_and_timeout_params(self):
timeout=2.0,
)

def test_generate_query_batches_w_data_boost_enabled(self):
sql = "SELECT COUNT(*) FROM table_name"
client = _Client(self.PROJECT_ID)
instance = _Instance(self.INSTANCE_NAME, client=client)
database = _Database(self.DATABASE_NAME, instance=instance)
batch_txn = self._make_one(database)
snapshot = batch_txn._snapshot = self._make_snapshot()
snapshot.partition_query.return_value = self.TOKENS

batches = list(batch_txn.generate_query_batches(sql, data_boost_enabled=True))

expected_query = {
"sql": sql,
"data_boost_enabled": True,
"query_options": client._query_options,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
self.assertEqual(batch["partition"], token)
self.assertEqual(batch["query"], expected_query)

snapshot.partition_query.assert_called_once_with(
sql=sql,
params=None,
param_types=None,
partition_size_bytes=None,
max_partitions=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_process_query_batch(self):
sql = (
"SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age"
