From ffb39158be5a551b698739c003ee6125a11c1c7a Mon Sep 17 00:00:00 2001
From: Astha Mohta <35952883+asthamohta@users.noreply.github.com>
Date: Thu, 6 Apr 2023 17:28:56 +0530
Subject: [PATCH] feat: add databoost enabled property for batch transactions
 (#892)

* proto changes

* changes

* changes

* linting

* changes

* changes

* changes

* changes

* changes

* Changes

* Update google/cloud/spanner_v1/snapshot.py

Co-authored-by: Rajat Bhatta <93644539+rajatbhatta@users.noreply.github.com>

* Update google/cloud/spanner_v1/database.py

Co-authored-by: Rajat Bhatta <93644539+rajatbhatta@users.noreply.github.com>

---------

Co-authored-by: Rajat Bhatta <93644539+rajatbhatta@users.noreply.github.com>
---
 google/cloud/spanner_v1/database.py | 18 ++++++-
 google/cloud/spanner_v1/snapshot.py | 20 +++++++
 samples/samples/batch_sample.py     |  7 ++-
 tests/system/test_session_api.py    | 12 +++--
 tests/unit/test_database.py         | 83 ++++++++++++++++++++++++++++-
 5 files changed, 133 insertions(+), 7 deletions(-)

diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index f919fa2c5e..8e72d6cf8f 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -1101,6 +1101,7 @@ def generate_read_batches(
         index="",
         partition_size_bytes=None,
         max_partitions=None,
+        data_boost_enabled=False,
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
@@ -1135,6 +1136,11 @@
             service uses this as a hint, the actual number of partitions may
             differ.
 
+        :type data_boost_enabled:
+        :param data_boost_enabled:
+            (Optional) If this is for a partitioned read and this field is
+            set ``true``, the request will be executed via offline access.
+
         :type retry: :class:`~google.api_core.retry.Retry`
         :param retry: (Optional) The retry settings for this request.
 
@@ -1162,6 +1168,7 @@
             "columns": columns,
             "keyset": keyset._to_dict(),
             "index": index,
+            "data_boost_enabled": data_boost_enabled,
         }
         for partition in partitions:
             yield {"partition": partition, "read": read_info.copy()}
@@ -1205,6 +1212,7 @@ def generate_query_batches(
         partition_size_bytes=None,
         max_partitions=None,
         query_options=None,
+        data_boost_enabled=False,
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
@@ -1251,6 +1259,11 @@
             If a dict is provided, it must be of the same form as the protobuf
             message :class:`~google.cloud.spanner_v1.types.QueryOptions`
 
+        :type data_boost_enabled:
+        :param data_boost_enabled:
+            (Optional) If this is for a partitioned query and this field is
+            set ``true``, the request will be executed via offline access.
+
         :type retry: :class:`~google.api_core.retry.Retry`
         :param retry: (Optional) The retry settings for this request.
 
@@ -1272,7 +1285,10 @@
             timeout=timeout,
         )
 
-        query_info = {"sql": sql}
+        query_info = {
+            "sql": sql,
+            "data_boost_enabled": data_boost_enabled,
+        }
         if params:
             query_info["params"] = params
             query_info["param_types"] = param_types
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index f1fff8b533..362e5dd1bc 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -167,6 +167,7 @@ def read(
         limit=0,
         partition=None,
         request_options=None,
+        data_boost_enabled=False,
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
@@ -210,6 +211,14 @@
         :type timeout: float
         :param timeout: (Optional) The timeout for this request.
 
+        :type data_boost_enabled:
+        :param data_boost_enabled:
+            (Optional) If this is for a partitioned read and this field is
+            set ``true``, the request will be executed via offline access.
+            If the field is set to ``true`` but the request does not set
+            ``partition_token``, the API will return an
+            ``INVALID_ARGUMENT`` error.
+
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
 
@@ -247,6 +256,7 @@
             limit=limit,
             partition_token=partition,
             request_options=request_options,
+            data_boost_enabled=data_boost_enabled,
         )
         restart = functools.partial(
             api.streaming_read,
@@ -302,6 +312,7 @@
         partition=None,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
+        data_boost_enabled=False,
     ):
         """Perform an ``ExecuteStreamingSql`` API request.
 
@@ -351,6 +362,14 @@
         :type timeout: float
         :param timeout: (Optional) The timeout for this request.
 
+        :type data_boost_enabled:
+        :param data_boost_enabled:
+            (Optional) If this is for a partitioned query and this field is
+            set ``true``, the request will be executed via offline access.
+            If the field is set to ``true`` but the request does not set
+            ``partition_token``, the API will return an
+            ``INVALID_ARGUMENT`` error.
+
         :raises ValueError: for reuse of single-use snapshots, or if a
             transaction ID is already pending for multiple-use snapshots.
 
@@ -400,6 +419,7 @@
             seqno=self._execute_sql_count,
             query_options=query_options,
             request_options=request_options,
+            data_boost_enabled=data_boost_enabled,
         )
         restart = functools.partial(
             api.execute_streaming_sql,
diff --git a/samples/samples/batch_sample.py b/samples/samples/batch_sample.py
index 73d9f5667e..69913ac4b3 100644
--- a/samples/samples/batch_sample.py
+++ b/samples/samples/batch_sample.py
@@ -47,6 +47,10 @@ def run_batch_query(instance_id, database_id):
         table="Singers",
         columns=("SingerId", "FirstName", "LastName"),
         keyset=spanner.KeySet(all_=True),
+        # A Partition object is serializable and can be used from a different process.
+        # DataBoost option is an optional parameter which can also be used for partition read
+        # and query to execute the request via spanner independent compute resources.
+        data_boost_enabled=True,
     )
 
     # Create a pool of workers for the tasks
@@ -87,4 +91,5 @@
 
     args = parser.parse_args()
 
-    run_batch_query(args.instance_id, args.database_id)
+    if args.command == "run_batch_query":
+        run_batch_query(args.instance_id, args.database_id)
diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py
index 6b7afbe525..7d58324b04 100644
--- a/tests/system/test_session_api.py
+++ b/tests/system/test_session_api.py
@@ -1875,7 +1875,7 @@ def test_read_with_range_keys_and_index_open_open(sessions_database):
     assert rows == expected
 
 
-def test_partition_read_w_index(sessions_database):
+def test_partition_read_w_index(sessions_database, not_emulator):
     sd = _sample_data
     row_count = 10
     columns = sd.COLUMNS[1], sd.COLUMNS[2]
@@ -1886,7 +1886,11 @@
     batch_txn = sessions_database.batch_snapshot(read_timestamp=committed)
 
     batches = batch_txn.generate_read_batches(
-        sd.TABLE, columns, spanner_v1.KeySet(all_=True), index="name"
+        sd.TABLE,
+        columns,
+        spanner_v1.KeySet(all_=True),
+        index="name",
+        data_boost_enabled=True,
     )
     for batch in batches:
         p_results_iter = batch_txn.process(batch)
@@ -2494,7 +2498,7 @@ def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgre
     assert math.isnan(float_array[2])
 
 
-def test_partition_query(sessions_database):
+def test_partition_query(sessions_database, not_emulator):
     row_count = 40
     sql = f"SELECT * FROM {_sample_data.TABLE}"
     committed = _set_up_table(sessions_database, row_count)
@@ -2503,7 +2507,7 @@
     all_data_rows = set(_row_data(row_count))
     union = set()
     batch_txn = sessions_database.batch_snapshot(read_timestamp=committed)
-    for batch in batch_txn.generate_query_batches(sql):
+    for batch in batch_txn.generate_query_batches(sql, data_boost_enabled=True):
         p_results_iter = batch_txn.process(batch)
         # Lists aren't hashable so the results need to be converted
         rows = [tuple(result) for result in p_results_iter]
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py
index bff89320c7..030cf5512b 100644
--- a/tests/unit/test_database.py
+++ b/tests/unit/test_database.py
@@ -2114,6 +2114,7 @@ def test_generate_read_batches_w_max_partitions(self):
             "columns": self.COLUMNS,
             "keyset": {"all": True},
             "index": "",
+            "data_boost_enabled": False,
         }
         self.assertEqual(len(batches), len(self.TOKENS))
         for batch, token in zip(batches, self.TOKENS):
@@ -2155,6 +2156,7 @@ def test_generate_read_batches_w_retry_and_timeout_params(self):
             "columns": self.COLUMNS,
             "keyset": {"all": True},
             "index": "",
+            "data_boost_enabled": False,
         }
         self.assertEqual(len(batches), len(self.TOKENS))
         for batch, token in zip(batches, self.TOKENS):
@@ -2195,6 +2197,7 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
             "columns": self.COLUMNS,
             "keyset": {"all": True},
             "index": self.INDEX,
+            "data_boost_enabled": False,
         }
         self.assertEqual(len(batches), len(self.TOKENS))
         for batch, token in zip(batches, self.TOKENS):
@@ -2212,6 +2215,47 @@
             timeout=gapic_v1.method.DEFAULT,
         )
 
+    def test_generate_read_batches_w_data_boost_enabled(self):
+        data_boost_enabled = True
+        keyset = self._make_keyset()
+        database = self._make_database()
+        batch_txn = self._make_one(database)
+        snapshot = batch_txn._snapshot = self._make_snapshot()
+        snapshot.partition_read.return_value = self.TOKENS
+
+        batches = list(
+            batch_txn.generate_read_batches(
+                self.TABLE,
+                self.COLUMNS,
+                keyset,
+                index=self.INDEX,
+                data_boost_enabled=data_boost_enabled,
+            )
+        )
+
+        expected_read = {
+            "table": self.TABLE,
+            "columns": self.COLUMNS,
+            "keyset": {"all": True},
+            "index": self.INDEX,
+            "data_boost_enabled": True,
+        }
+        self.assertEqual(len(batches), len(self.TOKENS))
+        for batch, token in zip(batches, self.TOKENS):
+            self.assertEqual(batch["partition"], token)
+            self.assertEqual(batch["read"], expected_read)
+
+        snapshot.partition_read.assert_called_once_with(
+            table=self.TABLE,
+            columns=self.COLUMNS,
+            keyset=keyset,
+            index=self.INDEX,
+            partition_size_bytes=None,
+            max_partitions=None,
+            retry=gapic_v1.method.DEFAULT,
+            timeout=gapic_v1.method.DEFAULT,
+        )
+
     def test_process_read_batch(self):
         keyset = self._make_keyset()
         token = b"TOKEN"
@@ -2288,7 +2332,11 @@ def test_generate_query_batches_w_max_partitions(self):
             batch_txn.generate_query_batches(sql, max_partitions=max_partitions)
         )
 
-        expected_query = {"sql": sql, "query_options": client._query_options}
+        expected_query = {
+            "sql": sql,
+            "data_boost_enabled": False,
+            "query_options": client._query_options,
+        }
         self.assertEqual(len(batches), len(self.TOKENS))
         for batch, token in zip(batches, self.TOKENS):
             self.assertEqual(batch["partition"], token)
@@ -2326,6 +2374,7 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self):
 
         expected_query = {
             "sql": sql,
+            "data_boost_enabled": False,
             "params": params,
             "param_types": param_types,
             "query_options": client._query_options,
@@ -2372,6 +2421,7 @@ def test_generate_query_batches_w_retry_and_timeout_params(self):
 
         expected_query = {
             "sql": sql,
+            "data_boost_enabled": False,
             "params": params,
             "param_types": param_types,
             "query_options": client._query_options,
@@ -2391,6 +2441,37 @@
             timeout=2.0,
         )
 
+    def test_generate_query_batches_w_data_boost_enabled(self):
+        sql = "SELECT COUNT(*) FROM table_name"
+        client = _Client(self.PROJECT_ID)
+        instance = _Instance(self.INSTANCE_NAME, client=client)
+        database = _Database(self.DATABASE_NAME, instance=instance)
+        batch_txn = self._make_one(database)
+        snapshot = batch_txn._snapshot = self._make_snapshot()
+        snapshot.partition_query.return_value = self.TOKENS
+
+        batches = list(batch_txn.generate_query_batches(sql, data_boost_enabled=True))
+
+        expected_query = {
+            "sql": sql,
+            "data_boost_enabled": True,
+            "query_options": client._query_options,
+        }
+        self.assertEqual(len(batches), len(self.TOKENS))
+        for batch, token in zip(batches, self.TOKENS):
+            self.assertEqual(batch["partition"], token)
+            self.assertEqual(batch["query"], expected_query)
+
+        snapshot.partition_query.assert_called_once_with(
+            sql=sql,
+            params=None,
+            param_types=None,
+            partition_size_bytes=None,
+            max_partitions=None,
+            retry=gapic_v1.method.DEFAULT,
+            timeout=gapic_v1.method.DEFAULT,
+        )
+
     def test_process_query_batch(self):
         sql = (
             "SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age"
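For reference, a minimal sketch of how a caller might use the new flag once this patch is applied. The instance ID, database ID, and table below are placeholders, and the loop mirrors the pattern in samples/samples/batch_sample.py:

from google.cloud import spanner

spanner_client = spanner.Client()
# Placeholder identifiers; point these at an existing instance and database.
database = spanner_client.instance("my-instance").database("my-database")

batch_txn = database.batch_snapshot()
batches = batch_txn.generate_query_batches(
    "SELECT SingerId, FirstName, LastName FROM Singers",
    # New in this patch: execute each partition via Spanner independent
    # compute resources (Data Boost) instead of the instance's own capacity.
    data_boost_enabled=True,
)
for batch in batches:
    for row in batch_txn.process(batch):
        print(row)
batch_txn.close()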