From a32dd63ace57c138958c8c5e7a701eeca7d9557d Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Tue, 10 Oct 2023 12:54:21 +0000 Subject: [PATCH] feat: Batch Write API implementation and samples --- google/cloud/spanner_v1/__init__.py | 4 + google/cloud/spanner_v1/batch.py | 72 +++++++++++- google/cloud/spanner_v1/database.py | 45 ++++++++ samples/samples/snippets.py | 51 +++++++++ samples/samples/snippets_test.py | 7 ++ tests/system/test_session_api.py | 38 +++++++ tests/unit/test_batch.py | 142 +++++++++++++++++++++++ tests/unit/test_database.py | 171 ++++++++++++++++++++++++++++ 8 files changed, 528 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 039919563f..3b59bb3ef0 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -34,6 +34,8 @@ from .types.result_set import ResultSetStats from .types.spanner import BatchCreateSessionsRequest from .types.spanner import BatchCreateSessionsResponse +from .types.spanner import BatchWriteRequest +from .types.spanner import BatchWriteResponse from .types.spanner import BeginTransactionRequest from .types.spanner import CommitRequest from .types.spanner import CreateSessionRequest @@ -99,6 +101,8 @@ # google.cloud.spanner_v1.types "BatchCreateSessionsRequest", "BatchCreateSessionsResponse", + "BatchWriteRequest", + "BatchWriteResponse", "BeginTransactionRequest", "CommitRequest", "CommitResponse", diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 41e4460c30..252c378695 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -18,6 +18,7 @@ from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import Mutation from google.cloud.spanner_v1 import TransactionOptions +from google.cloud.spanner_v1 import BatchWriteRequest from google.cloud.spanner_v1._helpers import _SessionWrapper from google.cloud.spanner_v1._helpers import _make_list_value_pbs @@ -42,9 +43,9 @@ class _BatchBase(_SessionWrapper): transaction_tag = None _read_only = False - def __init__(self, session): + def __init__(self, session, mutations=None): super(_BatchBase, self).__init__(session) - self._mutations = [] + self._mutations = [] if mutations is None else mutations def _check_state(self): """Helper for :meth:`commit` et al. @@ -215,6 +216,73 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.commit() +class MutationGroups(_SessionWrapper): + """Accumulate mutations for transmission during :meth:`batch_write`. + + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: the session used to perform the commit + """ + + committed = None + + def __init__(self, session): + super(MutationGroups, self).__init__(session) + self._mutation_groups = [] + + def group(self): + """Returns a new mutation_group to which mutations can be added.""" + mutation_group = BatchWriteRequest.MutationGroup() + self._mutation_groups.append(mutation_group) + return _BatchBase(self._session, mutation_group.mutations) + + def batch_write(self, request_options=None): + """Executes batch_write. + + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + + :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]` + :returns: a sequence of responses for each batch. + """ + if self.committed is not None: + raise ValueError("MutationGroups already committed") + + database = self._session._database + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + if database._route_to_leader_enabled: + metadata.append( + _metadata_with_leader_aware_routing(database._route_to_leader_enabled) + ) + trace_attributes = {"num_mutation_groups": len(self._mutation_groups)} + if request_options is None: + request_options = RequestOptions() + elif type(request_options) is dict: + request_options = RequestOptions(request_options) + + request = BatchWriteRequest( + session=self._session.name, + mutation_groups=self._mutation_groups, + request_options=request_options, + ) + with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes): + method = functools.partial( + api.batch_write, + request=request, + metadata=metadata, + ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + self.committed = True + return response + + def _make_write_pb(table, columns, values): """Helper for :meth:`Batch.insert` et al. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index eee34361b3..37865c9faa 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -50,6 +50,7 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1.batch import Batch +from google.cloud.spanner_v1.batch import MutationGroups from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout @@ -734,6 +735,17 @@ def batch(self, request_options=None): """ return BatchCheckout(self, request_options) + def mutation_groups(self): + """Return an object which wraps a mutation_group. + + The wrapper *must* be used as a context manager, with the mutation group + as the value returned by the wrapper. + + :rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout` + :returns: new wrapper + """ + return MutationGroupsCheckout(self) + def batch_snapshot(self, read_timestamp=None, exact_staleness=None): """Return an object which wraps a batch read / query. @@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._database._pool.put(self._session) +class MutationGroupsCheckout(object): + """Context manager for using mutation groups from a database. + + Inside the context manager, checks out a session from the database, + creates mutation groups from it, making the groups available. + + Caller must *not* use the object to perform API requests outside the scope + of the context manager. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: database to use + """ + + def __init__(self, database): + self._database = database + self._session = self._mutation_groups = None + + def __enter__(self): + """Begin ``with`` block.""" + session = self._session = self._database._pool.get() + return MutationGroups(session) + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + if isinstance(exc_val, NotFound): + # If NotFound exception occurs inside the with block + # then we validate if the session still exists. + if not self._session.exists(): + self._session = self._database._pool._new_session() + self._session.create() + self._database._pool.put(self._session) + + class SnapshotCheckout(object): """Context manager for using a snapshot from a database. diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 82fb95a0dd..8593f6393e 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -403,6 +403,54 @@ def insert_data(instance_id, database_id): # [END spanner_insert_data] +# [START spanner_batch_write] +def batch_write(instance_id, database_id): + """Inserts sample data into the given database via BatchWrite API. + + The database and table must already exist and can be created using + `create_database`. + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.mutation_groups() as groups: + group1 = groups.group() + group1.insert_or_update( + table="Singers", + columns=("SingerId", "FirstName", "LastName"), + values=[ + (16, "Scarlet", "Terry"), + ], + ) + + group2 = groups.group() + group2.insert_or_update( + table="Singers", + columns=("SingerId", "FirstName", "LastName"), + values=[ + (17, "Marc", "Richards"), + (18, "Catalina", "Smith"), + ], + ) + group2.insert_or_update( + table="Albums", + columns=("SingerId", "AlbumId", "AlbumTitle"), + values=[ + (17, 1, "Total Junk"), + (18, 2, "Go, Go, Go"), + ], + ) + + for response in groups.batch_write(): + print(response) + + print("Inserted data.") + + +# [END spanner_batch_write] + + # [START spanner_delete_data] def delete_data(instance_id, database_id): """Deletes sample data from the given database. @@ -2677,6 +2725,7 @@ def drop_sequence(instance_id, database_id): subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) + subparsers.add_parser("batch_write", help=batch_write.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) subparsers.add_parser("query_data", help=query_data.__doc__) subparsers.add_parser("read_data", help=read_data.__doc__) @@ -2811,6 +2860,8 @@ def drop_sequence(instance_id, database_id): create_database(args.instance_id, args.database_id) elif args.command == "insert_data": insert_data(args.instance_id, args.database_id) + elif args.command == "batch_write": + batch_write(args.instance_id, args.database_id) elif args.command == "delete_data": delete_data(args.instance_id, args.database_id) elif args.command == "query_data": diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 22b5b6f944..5c6fa0eacd 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database): assert "Inserted data" in out +@pytest.mark.dependency(name="batch_write") +def test_batch_write(capsys, instance_id, sample_database): + snippets.batch_write(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "Inserted data" in out + + @pytest.mark.dependency(depends=["insert_data"]) def test_delete_data(capsys, instance_id, sample_database): snippets.delete_data(instance_id, sample_database.database_id) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index c4ea2ded40..4d7b142e41 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2508,6 +2508,44 @@ def test_partition_query(sessions_database, not_emulator): batch_txn.close() +def test_mutation_groups_insert_or_update_then_query(sessions_database): + sd = _sample_data + ROW_DATA = ( + (1, "Phred", "Phlyntstone", "phred@example.com"), + (2, "Bharney", "Rhubble", "bharney@example.com"), + (3, "Wylma", "Phlyntstone", "wylma@example.com"), + (4, "Pebbles", "Phlyntstone", "pebbles@example.com"), + (5, "Betty", "Rhubble", "betty@example.com"), + (6, "Slate", "Stephenson", "slate@example.com"), + ) + num_groups = 3 + num_mutations_per_group = len(ROW_DATA) // num_groups + + with sessions_database.mutation_groups() as groups: + for i in range(num_groups): + group = groups.group() + for j in range(num_mutations_per_group): + group.insert_or_update( + sd.TABLE, sd.COLUMNS, [ROW_DATA[i * num_mutations_per_group + j]] + ) + # Response indexes received + seen = collections.Counter() + for response in groups.batch_write(): + _check_batch_status(response.status.code) + assert response.commit_timestamp is not None + assert len(response.indexes) > 0 + seen.update(response.indexes) + # All indexes must be in the range [0, num_groups-1] and seen exactly once + assert len(seen) == num_groups + assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items()) + + # Verify the writes by reading from the database + with sessions_database.snapshot() as snapshot: + rows = list(snapshot.execute_sql(sd.SQL)) + + sd._check_rows_data(rows, ROW_DATA) + + class FauxCall: def __init__(self, code, details="FauxCall"): self._code = code diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 856816628f..203c8a0cb5 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -413,6 +413,130 @@ class _BailOut(Exception): self.assertEqual(len(batch._mutations), 1) +class TestMutationGroups(_BaseTest, OpenTelemetryBase): + def _getTargetClass(self): + from google.cloud.spanner_v1.batch import MutationGroups + + return MutationGroups + + def test_ctor(self): + session = _Session() + groups = self._make_one(session) + self.assertIs(groups._session, session) + + def test_batch_write_already_committed(self): + from google.cloud.spanner_v1.keyset import KeySet + + keys = [[0], [1], [2]] + keyset = KeySet(keys=keys) + database = _Database() + database.spanner_api = _FauxSpannerAPI(_batch_write_response=[]) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.delete(TABLE_NAME, keyset=keyset) + groups.batch_write() + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.OK, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + assert groups.committed + # The second call to batch_write should raise an error. + with self.assertRaises(ValueError): + groups.batch_write() + + def test_batch_write_grpc_error(self): + from google.api_core.exceptions import Unknown + from google.cloud.spanner_v1.keyset import KeySet + + keys = [[0], [1], [2]] + keyset = KeySet(keys=keys) + database = _Database() + database.spanner_api = _FauxSpannerAPI(_rpc_error=True) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.delete(TABLE_NAME, keyset=keyset) + + with self.assertRaises(Unknown): + groups.batch_write() + + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.ERROR, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + + def _test_batch_write_with_request_options(self, request_options=None): + import datetime + from google.cloud.spanner_v1 import BatchWriteResponse + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.rpc.status_pb2 import Status + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + status_pb = Status(code=200) + response = BatchWriteResponse( + commit_timestamp=now_pb, indexes=[0], status=status_pb + ) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI(_batch_write_response=[response]) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.insert(TABLE_NAME, COLUMNS, VALUES) + + response_iter = groups.batch_write(request_options) + self.assertEqual(len(response_iter), 1) + self.assertEqual(response_iter[0], response) + + ( + session, + mutation_groups, + actual_request_options, + metadata, + ) = api._batch_request + self.assertEqual(session, self.SESSION_NAME) + self.assertEqual(mutation_groups, groups._mutation_groups) + self.assertEqual( + metadata, + [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + if request_options is None: + expected_request_options = RequestOptions() + elif type(request_options) is dict: + expected_request_options = RequestOptions(request_options) + else: + expected_request_options = request_options + self.assertEqual(actual_request_options, expected_request_options) + + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.OK, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + + def test_batch_write_no_request_options(self): + self._test_batch_write_with_request_options() + + def test_batch_write_w_transaction_tag_success(self): + self._test_batch_write_with_request_options( + RequestOptions(transaction_tag="tag-1-1") + ) + + def test_batch_write_w_transaction_tag_dictionary_success(self): + self._test_batch_write_with_request_options({"transaction_tag": "tag-1-1"}) + + def test_batch_write_w_incorrect_tag_dictionary_error(self): + with self.assertRaises(ValueError): + self._test_batch_write_with_request_options({"incorrect_tag": "tag-1-1"}) + + class _Session(object): def __init__(self, database=None, name=TestBatch.SESSION_NAME): self._database = database @@ -428,6 +552,7 @@ class _FauxSpannerAPI: _create_instance_conflict = False _instance_not_found = False _committed = None + _batch_request = None _rpc_error = False def __init__(self, **kwargs): @@ -451,3 +576,20 @@ def commit( if self._rpc_error: raise Unknown("error") return self._commit_response + + def batch_write( + self, + request=None, + metadata=None, + ): + from google.api_core.exceptions import Unknown + + self._batch_request = ( + request.session, + request.mutation_groups, + request.request_options, + metadata, + ) + if self._rpc_error: + raise Unknown("error") + return self._batch_write_response diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bd368eed11..3ab5cf6d9e 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1231,6 +1231,20 @@ def test_batch(self): self.assertIsInstance(checkout, BatchCheckout) self.assertIs(checkout._database, database) + def test_mutation_groups(self): + from google.cloud.spanner_v1.database import MutationGroupsCheckout + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + checkout = database.mutation_groups() + self.assertIsInstance(checkout, MutationGroupsCheckout) + self.assertIs(checkout._database, database) + def test_batch_snapshot(self): from google.cloud.spanner_v1.database import BatchSnapshot @@ -2679,6 +2693,163 @@ def test_process_w_query_batch(self): ) +class TestMutationGroupsCheckout(_BaseTest): + def _get_target_class(self): + from google.cloud.spanner_v1.database import MutationGroupsCheckout + + return MutationGroupsCheckout + + @staticmethod + def _make_spanner_client(): + from google.cloud.spanner_v1 import SpannerClient + + return mock.create_autospec(SpannerClient) + + def test_ctor(self): + from google.cloud.spanner_v1.batch import MutationGroups + + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + self.assertIs(checkout._database, database) + + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + + self.assertIs(pool._session, session) + + def test_context_mgr_success(self): + import datetime + from google.cloud.spanner_v1 import BatchWriteRequest + from google.cloud.spanner_v1 import BatchWriteResponse + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner_v1.batch import MutationGroups + from google.rpc.status_pb2 import Status + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + status_pb = Status(code=200) + response = BatchWriteResponse( + commit_timestamp=now_pb, indexes=[0], status=status_pb + ) + database = _Database(self.DATABASE_NAME) + api = database.spanner_api = self._make_spanner_client() + api.batch_write.return_value = [response] + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + + request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) + request = BatchWriteRequest( + session=self.SESSION_NAME, + mutation_groups=[], + request_options=request_options, + ) + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + groups.batch_write(request_options) + self.assertEqual(groups.committed, True) + + self.assertIs(pool._session, session) + + api.batch_write.assert_called_once_with( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + + def test_context_mgr_failure(self): + from google.cloud.spanner_v1.batch import MutationGroups + + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + + class Testing(Exception): + pass + + with self.assertRaises(Testing): + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + raise Testing() + + self.assertIs(pool._session, session) + + def test_context_mgr_session_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=False) + pool = database._pool = _Pool() + new_session = _Session(database, name="session-2") + new_session.create = mock.MagicMock(return_value=[]) + pool._new_session = mock.MagicMock(return_value=new_session) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Session not found") + # Assert that session-1 was removed from pool and new session was added. + self.assertEqual(pool._session, new_session) + + def test_context_mgr_table_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=True) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Table not found") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + + def test_context_mgr_unknown_error(self): + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + pool.put(session) + checkout = self._make_one(database) + + class Testing(Exception): + pass + + self.assertEqual(pool._session, session) + with self.assertRaises(Testing): + with checkout as _: + raise Testing("Unknown error.") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + + def _make_instance_api(): from google.cloud.spanner_admin_instance_v1 import InstanceAdminClient