From a32dd63ace57c138958c8c5e7a701eeca7d9557d Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Tue, 10 Oct 2023 12:54:21 +0000 Subject: [PATCH 01/11] feat: Batch Write API implementation and samples --- google/cloud/spanner_v1/__init__.py | 4 + google/cloud/spanner_v1/batch.py | 72 +++++++++++- google/cloud/spanner_v1/database.py | 45 ++++++++ samples/samples/snippets.py | 51 +++++++++ samples/samples/snippets_test.py | 7 ++ tests/system/test_session_api.py | 38 +++++++ tests/unit/test_batch.py | 142 +++++++++++++++++++++++ tests/unit/test_database.py | 171 ++++++++++++++++++++++++++++ 8 files changed, 528 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 039919563f..3b59bb3ef0 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -34,6 +34,8 @@ from .types.result_set import ResultSetStats from .types.spanner import BatchCreateSessionsRequest from .types.spanner import BatchCreateSessionsResponse +from .types.spanner import BatchWriteRequest +from .types.spanner import BatchWriteResponse from .types.spanner import BeginTransactionRequest from .types.spanner import CommitRequest from .types.spanner import CreateSessionRequest @@ -99,6 +101,8 @@ # google.cloud.spanner_v1.types "BatchCreateSessionsRequest", "BatchCreateSessionsResponse", + "BatchWriteRequest", + "BatchWriteResponse", "BeginTransactionRequest", "CommitRequest", "CommitResponse", diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 41e4460c30..252c378695 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -18,6 +18,7 @@ from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import Mutation from google.cloud.spanner_v1 import TransactionOptions +from google.cloud.spanner_v1 import BatchWriteRequest from google.cloud.spanner_v1._helpers import _SessionWrapper from google.cloud.spanner_v1._helpers import _make_list_value_pbs @@ -42,9 +43,9 @@ class _BatchBase(_SessionWrapper): transaction_tag = None _read_only = False - def __init__(self, session): + def __init__(self, session, mutations=None): super(_BatchBase, self).__init__(session) - self._mutations = [] + self._mutations = [] if mutations is None else mutations def _check_state(self): """Helper for :meth:`commit` et al. @@ -215,6 +216,73 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.commit() +class MutationGroups(_SessionWrapper): + """Accumulate mutations for transmission during :meth:`batch_write`. + + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: the session used to perform the commit + """ + + committed = None + + def __init__(self, session): + super(MutationGroups, self).__init__(session) + self._mutation_groups = [] + + def group(self): + """Returns a new mutation_group to which mutations can be added.""" + mutation_group = BatchWriteRequest.MutationGroup() + self._mutation_groups.append(mutation_group) + return _BatchBase(self._session, mutation_group.mutations) + + def batch_write(self, request_options=None): + """Executes batch_write. + + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + + :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]` + :returns: a sequence of responses for each batch. + """ + if self.committed is not None: + raise ValueError("MutationGroups already committed") + + database = self._session._database + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + if database._route_to_leader_enabled: + metadata.append( + _metadata_with_leader_aware_routing(database._route_to_leader_enabled) + ) + trace_attributes = {"num_mutation_groups": len(self._mutation_groups)} + if request_options is None: + request_options = RequestOptions() + elif type(request_options) is dict: + request_options = RequestOptions(request_options) + + request = BatchWriteRequest( + session=self._session.name, + mutation_groups=self._mutation_groups, + request_options=request_options, + ) + with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes): + method = functools.partial( + api.batch_write, + request=request, + metadata=metadata, + ) + response = _retry( + method, + allowed_exceptions={InternalServerError: _check_rst_stream_error}, + ) + self.committed = True + return response + + def _make_write_pb(table, columns, values): """Helper for :meth:`Batch.insert` et al. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index eee34361b3..37865c9faa 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -50,6 +50,7 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1.batch import Batch +from google.cloud.spanner_v1.batch import MutationGroups from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout @@ -734,6 +735,17 @@ def batch(self, request_options=None): """ return BatchCheckout(self, request_options) + def mutation_groups(self): + """Return an object which wraps a mutation_group. + + The wrapper *must* be used as a context manager, with the mutation group + as the value returned by the wrapper. + + :rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout` + :returns: new wrapper + """ + return MutationGroupsCheckout(self) + def batch_snapshot(self, read_timestamp=None, exact_staleness=None): """Return an object which wraps a batch read / query. @@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._database._pool.put(self._session) +class MutationGroupsCheckout(object): + """Context manager for using mutation groups from a database. + + Inside the context manager, checks out a session from the database, + creates mutation groups from it, making the groups available. + + Caller must *not* use the object to perform API requests outside the scope + of the context manager. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: database to use + """ + + def __init__(self, database): + self._database = database + self._session = self._mutation_groups = None + + def __enter__(self): + """Begin ``with`` block.""" + session = self._session = self._database._pool.get() + return MutationGroups(session) + + def __exit__(self, exc_type, exc_val, exc_tb): + """End ``with`` block.""" + if isinstance(exc_val, NotFound): + # If NotFound exception occurs inside the with block + # then we validate if the session still exists. + if not self._session.exists(): + self._session = self._database._pool._new_session() + self._session.create() + self._database._pool.put(self._session) + + class SnapshotCheckout(object): """Context manager for using a snapshot from a database. diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 82fb95a0dd..8593f6393e 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -403,6 +403,54 @@ def insert_data(instance_id, database_id): # [END spanner_insert_data] +# [START spanner_batch_write] +def batch_write(instance_id, database_id): + """Inserts sample data into the given database via BatchWrite API. + + The database and table must already exist and can be created using + `create_database`. + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.mutation_groups() as groups: + group1 = groups.group() + group1.insert_or_update( + table="Singers", + columns=("SingerId", "FirstName", "LastName"), + values=[ + (16, "Scarlet", "Terry"), + ], + ) + + group2 = groups.group() + group2.insert_or_update( + table="Singers", + columns=("SingerId", "FirstName", "LastName"), + values=[ + (17, "Marc", "Richards"), + (18, "Catalina", "Smith"), + ], + ) + group2.insert_or_update( + table="Albums", + columns=("SingerId", "AlbumId", "AlbumTitle"), + values=[ + (17, 1, "Total Junk"), + (18, 2, "Go, Go, Go"), + ], + ) + + for response in groups.batch_write(): + print(response) + + print("Inserted data.") + + +# [END spanner_batch_write] + + # [START spanner_delete_data] def delete_data(instance_id, database_id): """Deletes sample data from the given database. @@ -2677,6 +2725,7 @@ def drop_sequence(instance_id, database_id): subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) + subparsers.add_parser("batch_write", help=batch_write.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) subparsers.add_parser("query_data", help=query_data.__doc__) subparsers.add_parser("read_data", help=read_data.__doc__) @@ -2811,6 +2860,8 @@ def drop_sequence(instance_id, database_id): create_database(args.instance_id, args.database_id) elif args.command == "insert_data": insert_data(args.instance_id, args.database_id) + elif args.command == "batch_write": + batch_write(args.instance_id, args.database_id) elif args.command == "delete_data": delete_data(args.instance_id, args.database_id) elif args.command == "query_data": diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 22b5b6f944..5c6fa0eacd 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database): assert "Inserted data" in out +@pytest.mark.dependency(name="batch_write") +def test_batch_write(capsys, instance_id, sample_database): + snippets.batch_write(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "Inserted data" in out + + @pytest.mark.dependency(depends=["insert_data"]) def test_delete_data(capsys, instance_id, sample_database): snippets.delete_data(instance_id, sample_database.database_id) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index c4ea2ded40..4d7b142e41 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2508,6 +2508,44 @@ def test_partition_query(sessions_database, not_emulator): batch_txn.close() +def test_mutation_groups_insert_or_update_then_query(sessions_database): + sd = _sample_data + ROW_DATA = ( + (1, "Phred", "Phlyntstone", "phred@example.com"), + (2, "Bharney", "Rhubble", "bharney@example.com"), + (3, "Wylma", "Phlyntstone", "wylma@example.com"), + (4, "Pebbles", "Phlyntstone", "pebbles@example.com"), + (5, "Betty", "Rhubble", "betty@example.com"), + (6, "Slate", "Stephenson", "slate@example.com"), + ) + num_groups = 3 + num_mutations_per_group = len(ROW_DATA) // num_groups + + with sessions_database.mutation_groups() as groups: + for i in range(num_groups): + group = groups.group() + for j in range(num_mutations_per_group): + group.insert_or_update( + sd.TABLE, sd.COLUMNS, [ROW_DATA[i * num_mutations_per_group + j]] + ) + # Response indexes received + seen = collections.Counter() + for response in groups.batch_write(): + _check_batch_status(response.status.code) + assert response.commit_timestamp is not None + assert len(response.indexes) > 0 + seen.update(response.indexes) + # All indexes must be in the range [0, num_groups-1] and seen exactly once + assert len(seen) == num_groups + assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items()) + + # Verify the writes by reading from the database + with sessions_database.snapshot() as snapshot: + rows = list(snapshot.execute_sql(sd.SQL)) + + sd._check_rows_data(rows, ROW_DATA) + + class FauxCall: def __init__(self, code, details="FauxCall"): self._code = code diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 856816628f..203c8a0cb5 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -413,6 +413,130 @@ class _BailOut(Exception): self.assertEqual(len(batch._mutations), 1) +class TestMutationGroups(_BaseTest, OpenTelemetryBase): + def _getTargetClass(self): + from google.cloud.spanner_v1.batch import MutationGroups + + return MutationGroups + + def test_ctor(self): + session = _Session() + groups = self._make_one(session) + self.assertIs(groups._session, session) + + def test_batch_write_already_committed(self): + from google.cloud.spanner_v1.keyset import KeySet + + keys = [[0], [1], [2]] + keyset = KeySet(keys=keys) + database = _Database() + database.spanner_api = _FauxSpannerAPI(_batch_write_response=[]) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.delete(TABLE_NAME, keyset=keyset) + groups.batch_write() + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.OK, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + assert groups.committed + # The second call to batch_write should raise an error. + with self.assertRaises(ValueError): + groups.batch_write() + + def test_batch_write_grpc_error(self): + from google.api_core.exceptions import Unknown + from google.cloud.spanner_v1.keyset import KeySet + + keys = [[0], [1], [2]] + keyset = KeySet(keys=keys) + database = _Database() + database.spanner_api = _FauxSpannerAPI(_rpc_error=True) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.delete(TABLE_NAME, keyset=keyset) + + with self.assertRaises(Unknown): + groups.batch_write() + + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.ERROR, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + + def _test_batch_write_with_request_options(self, request_options=None): + import datetime + from google.cloud.spanner_v1 import BatchWriteResponse + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.rpc.status_pb2 import Status + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + status_pb = Status(code=200) + response = BatchWriteResponse( + commit_timestamp=now_pb, indexes=[0], status=status_pb + ) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI(_batch_write_response=[response]) + session = _Session(database) + groups = self._make_one(session) + group = groups.group() + group.insert(TABLE_NAME, COLUMNS, VALUES) + + response_iter = groups.batch_write(request_options) + self.assertEqual(len(response_iter), 1) + self.assertEqual(response_iter[0], response) + + ( + session, + mutation_groups, + actual_request_options, + metadata, + ) = api._batch_request + self.assertEqual(session, self.SESSION_NAME) + self.assertEqual(mutation_groups, groups._mutation_groups) + self.assertEqual( + metadata, + [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + if request_options is None: + expected_request_options = RequestOptions() + elif type(request_options) is dict: + expected_request_options = RequestOptions(request_options) + else: + expected_request_options = request_options + self.assertEqual(actual_request_options, expected_request_options) + + self.assertSpanAttributes( + "CloudSpanner.BatchWrite", + status=StatusCode.OK, + attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), + ) + + def test_batch_write_no_request_options(self): + self._test_batch_write_with_request_options() + + def test_batch_write_w_transaction_tag_success(self): + self._test_batch_write_with_request_options( + RequestOptions(transaction_tag="tag-1-1") + ) + + def test_batch_write_w_transaction_tag_dictionary_success(self): + self._test_batch_write_with_request_options({"transaction_tag": "tag-1-1"}) + + def test_batch_write_w_incorrect_tag_dictionary_error(self): + with self.assertRaises(ValueError): + self._test_batch_write_with_request_options({"incorrect_tag": "tag-1-1"}) + + class _Session(object): def __init__(self, database=None, name=TestBatch.SESSION_NAME): self._database = database @@ -428,6 +552,7 @@ class _FauxSpannerAPI: _create_instance_conflict = False _instance_not_found = False _committed = None + _batch_request = None _rpc_error = False def __init__(self, **kwargs): @@ -451,3 +576,20 @@ def commit( if self._rpc_error: raise Unknown("error") return self._commit_response + + def batch_write( + self, + request=None, + metadata=None, + ): + from google.api_core.exceptions import Unknown + + self._batch_request = ( + request.session, + request.mutation_groups, + request.request_options, + metadata, + ) + if self._rpc_error: + raise Unknown("error") + return self._batch_write_response diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bd368eed11..3ab5cf6d9e 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1231,6 +1231,20 @@ def test_batch(self): self.assertIsInstance(checkout, BatchCheckout) self.assertIs(checkout._database, database) + def test_mutation_groups(self): + from google.cloud.spanner_v1.database import MutationGroupsCheckout + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + checkout = database.mutation_groups() + self.assertIsInstance(checkout, MutationGroupsCheckout) + self.assertIs(checkout._database, database) + def test_batch_snapshot(self): from google.cloud.spanner_v1.database import BatchSnapshot @@ -2679,6 +2693,163 @@ def test_process_w_query_batch(self): ) +class TestMutationGroupsCheckout(_BaseTest): + def _get_target_class(self): + from google.cloud.spanner_v1.database import MutationGroupsCheckout + + return MutationGroupsCheckout + + @staticmethod + def _make_spanner_client(): + from google.cloud.spanner_v1 import SpannerClient + + return mock.create_autospec(SpannerClient) + + def test_ctor(self): + from google.cloud.spanner_v1.batch import MutationGroups + + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + self.assertIs(checkout._database, database) + + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + + self.assertIs(pool._session, session) + + def test_context_mgr_success(self): + import datetime + from google.cloud.spanner_v1 import BatchWriteRequest + from google.cloud.spanner_v1 import BatchWriteResponse + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner_v1.batch import MutationGroups + from google.rpc.status_pb2 import Status + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + status_pb = Status(code=200) + response = BatchWriteResponse( + commit_timestamp=now_pb, indexes=[0], status=status_pb + ) + database = _Database(self.DATABASE_NAME) + api = database.spanner_api = self._make_spanner_client() + api.batch_write.return_value = [response] + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + + request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) + request = BatchWriteRequest( + session=self.SESSION_NAME, + mutation_groups=[], + request_options=request_options, + ) + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + groups.batch_write(request_options) + self.assertEqual(groups.committed, True) + + self.assertIs(pool._session, session) + + api.batch_write.assert_called_once_with( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + + def test_context_mgr_failure(self): + from google.cloud.spanner_v1.batch import MutationGroups + + database = _Database(self.DATABASE_NAME) + pool = database._pool = _Pool() + session = _Session(database) + pool.put(session) + checkout = self._make_one(database) + + class Testing(Exception): + pass + + with self.assertRaises(Testing): + with checkout as groups: + self.assertIsNone(pool._session) + self.assertIsInstance(groups, MutationGroups) + self.assertIs(groups._session, session) + raise Testing() + + self.assertIs(pool._session, session) + + def test_context_mgr_session_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=False) + pool = database._pool = _Pool() + new_session = _Session(database, name="session-2") + new_session.create = mock.MagicMock(return_value=[]) + pool._new_session = mock.MagicMock(return_value=new_session) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Session not found") + # Assert that session-1 was removed from pool and new session was added. + self.assertEqual(pool._session, new_session) + + def test_context_mgr_table_not_found_error(self): + from google.cloud.exceptions import NotFound + + database = _Database(self.DATABASE_NAME) + session = _Session(database, name="session-1") + session.exists = mock.MagicMock(return_value=True) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + + pool.put(session) + checkout = self._make_one(database) + + self.assertEqual(pool._session, session) + with self.assertRaises(NotFound): + with checkout as _: + raise NotFound("Table not found") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + + def test_context_mgr_unknown_error(self): + database = _Database(self.DATABASE_NAME) + session = _Session(database) + pool = database._pool = _Pool() + pool._new_session = mock.MagicMock(return_value=[]) + pool.put(session) + checkout = self._make_one(database) + + class Testing(Exception): + pass + + self.assertEqual(pool._session, session) + with self.assertRaises(Testing): + with checkout as _: + raise Testing("Unknown error.") + # Assert that session-1 was not removed from pool. + self.assertEqual(pool._session, session) + pool._new_session.assert_not_called() + + def _make_instance_api(): from google.cloud.spanner_admin_instance_v1 import InstanceAdminClient From 96e6e15effe958f52b21e52ffa1bb5f67ea8b128 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 2 Nov 2023 11:11:36 +0000 Subject: [PATCH 02/11] Update sample --- samples/samples/snippets.py | 23 +++++++++++++++++------ samples/samples/snippets_test.py | 2 +- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 8593f6393e..f7c403cfc4 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -403,13 +403,15 @@ def insert_data(instance_id, database_id): # [END spanner_insert_data] -# [START spanner_batch_write] +# [START spanner_batch_write_at_least_once] def batch_write(instance_id, database_id): """Inserts sample data into the given database via BatchWrite API. The database and table must already exist and can be created using `create_database`. """ + from google.rpc.code_pb2 import OK + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) @@ -429,7 +431,7 @@ def batch_write(instance_id, database_id): table="Singers", columns=("SingerId", "FirstName", "LastName"), values=[ - (17, "Marc", "Richards"), + (17, "Marc", ""), (18, "Catalina", "Smith"), ], ) @@ -443,12 +445,21 @@ def batch_write(instance_id, database_id): ) for response in groups.batch_write(): - print(response) - - print("Inserted data.") + if response.status.code == OK: + print( + "Mutation group indexes {} have been applied with commit timestamp {}".format( + response.indexes, response.commit_timestamp + ) + ) + else: + print( + "Mutation group indexes {} could not be applied with error {}".format( + response.indexes, response.status + ) + ) -# [END spanner_batch_write] +# [END spanner_batch_write_at_least_once] # [START spanner_delete_data] diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 5c6fa0eacd..85999363bb 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -294,7 +294,7 @@ def test_insert_data(capsys, instance_id, sample_database): def test_batch_write(capsys, instance_id, sample_database): snippets.batch_write(instance_id, sample_database.database_id) out, _ = capsys.readouterr() - assert "Inserted data" in out + assert "could not be applied with error" not in out @pytest.mark.dependency(depends=["insert_data"]) From 770830200256e44b70b0447f38fd86a24d83e3b3 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 16 Nov 2023 13:03:22 +0000 Subject: [PATCH 03/11] review comments --- google/cloud/spanner_v1/batch.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 252c378695..14aeeea02c 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -43,9 +43,9 @@ class _BatchBase(_SessionWrapper): transaction_tag = None _read_only = False - def __init__(self, session, mutations=None): + def __init__(self, session, mutations=[]): super(_BatchBase, self).__init__(session) - self._mutations = [] if mutations is None else mutations + self._mutations = mutations def _check_state(self): """Helper for :meth:`commit` et al. @@ -229,6 +229,15 @@ def __init__(self, session): super(MutationGroups, self).__init__(session) self._mutation_groups = [] + def _check_state(self): + """Checks if the object's state is valid for making API requests. + + :raises: :exc:`ValueError` if the object's state is invalid for making + API requests. + """ + if self.committed is not None: + raise ValueError("MutationGroups already committed") + def group(self): """Returns a new mutation_group to which mutations can be added.""" mutation_group = BatchWriteRequest.MutationGroup() @@ -248,8 +257,7 @@ def batch_write(self, request_options=None): :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]` :returns: a sequence of responses for each batch. """ - if self.committed is not None: - raise ValueError("MutationGroups already committed") + self._check_state() database = self._session._database api = database.spanner_api From 38d9b8344e67ddab60e1a3018004a10bf77104a5 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Fri, 17 Nov 2023 08:40:24 +0000 Subject: [PATCH 04/11] return public class for mutation groups --- google/cloud/spanner_v1/batch.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 14aeeea02c..8d7d2bc91e 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -37,7 +37,10 @@ class _BatchBase(_SessionWrapper): """Accumulate mutations for transmission during :meth:`commit`. :type session: :class:`~google.cloud.spanner_v1.session.Session` - :param session: the session used to perform the commit + :param session: The session used to perform the commit. + + :type mutations: list + :param mutations: The list into which mutations are to be accumulated. """ transaction_tag = None @@ -216,8 +219,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.commit() +class MutationGroup(_BatchBase): + """A container for mutations. + + Clients should use :class:`~google.cloud.spanner_v1.MutationGroups` to + obtain instances instead of directly creating instances. + + :type session: :class:`~google.cloud.spanner_v1.session.Session` + :param session: The session used to perform the commit. + + :type mutations: list + :param mutations: The list into which mutations are to be accumulated. + """ + + def __init__(self, session, mutations): + super(MutationGroup, self).__init__(session, mutations) + + class MutationGroups(_SessionWrapper): - """Accumulate mutations for transmission during :meth:`batch_write`. + """Accumulate mutation groups for transmission during :meth:`batch_write`. :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit @@ -239,10 +259,10 @@ def _check_state(self): raise ValueError("MutationGroups already committed") def group(self): - """Returns a new mutation_group to which mutations can be added.""" + """Returns a new `MutationGroup` to which mutations can be added.""" mutation_group = BatchWriteRequest.MutationGroup() self._mutation_groups.append(mutation_group) - return _BatchBase(self._session, mutation_group.mutations) + return MutationGroup(self._session, mutation_group.mutations) def batch_write(self, request_options=None): """Executes batch_write. From 25e69990058389a271d7d168b95a6fc03daba509 Mon Sep 17 00:00:00 2001 From: Sunny Singh <126051413+sunnsing-google@users.noreply.github.com> Date: Thu, 30 Nov 2023 00:10:50 +0530 Subject: [PATCH 05/11] Update google/cloud/spanner_v1/batch.py Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> --- google/cloud/spanner_v1/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 8d7d2bc91e..cf4c52d654 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -48,7 +48,7 @@ class _BatchBase(_SessionWrapper): def __init__(self, session, mutations=[]): super(_BatchBase, self).__init__(session) - self._mutations = mutations + self._mutations = [] def _check_state(self): """Helper for :meth:`commit` et al. From 7d57ac217c95dc3a05b8f27db2a8ffedf2c43176 Mon Sep 17 00:00:00 2001 From: Sunny Singh <126051413+sunnsing-google@users.noreply.github.com> Date: Thu, 30 Nov 2023 00:10:59 +0530 Subject: [PATCH 06/11] Update google/cloud/spanner_v1/batch.py Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> --- google/cloud/spanner_v1/batch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index cf4c52d654..0625a6f8bd 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -232,8 +232,9 @@ class MutationGroup(_BatchBase): :param mutations: The list into which mutations are to be accumulated. """ - def __init__(self, session, mutations): - super(MutationGroup, self).__init__(session, mutations) + def __init__(self, session, mutations=[]): + super(MutationGroup, self).__init__(session) + self._mutations = mutations class MutationGroups(_SessionWrapper): From 0a2aba454d237dae8d8d0e5d22a488fb288c4c99 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 30 Nov 2023 11:48:37 +0000 Subject: [PATCH 07/11] review comments --- google/cloud/spanner_v1/batch.py | 4 ++-- google/cloud/spanner_v1/database.py | 2 +- tests/system/test_session_api.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0625a6f8bd..98428360c9 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -37,7 +37,7 @@ class _BatchBase(_SessionWrapper): """Accumulate mutations for transmission during :meth:`commit`. :type session: :class:`~google.cloud.spanner_v1.session.Session` - :param session: The session used to perform the commit. + :param session: the session used to perform the commit :type mutations: list :param mutations: The list into which mutations are to be accumulated. @@ -46,7 +46,7 @@ class _BatchBase(_SessionWrapper): transaction_tag = None _read_only = False - def __init__(self, session, mutations=[]): + def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 37865c9faa..758547cf86 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1067,7 +1067,7 @@ class MutationGroupsCheckout(object): def __init__(self, database): self._database = database - self._session = self._mutation_groups = None + self._session = None def __enter__(self): """Begin ``with`` block.""" diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4703845f2d..43fd5e133a 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2521,7 +2521,7 @@ def test_partition_query(sessions_database, not_emulator): batch_txn.close() -def test_mutation_groups_insert_or_update_then_query(sessions_database): +def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_database): sd = _sample_data ROW_DATA = ( (1, "Phred", "Phlyntstone", "phred@example.com"), From e401452a1642e69d28cdb4d8afd0afe8d2a5d61a Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 30 Nov 2023 11:50:53 +0000 Subject: [PATCH 08/11] remove doc --- google/cloud/spanner_v1/batch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 98428360c9..da74bf35f0 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -38,9 +38,6 @@ class _BatchBase(_SessionWrapper): :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: the session used to perform the commit - - :type mutations: list - :param mutations: The list into which mutations are to be accumulated. """ transaction_tag = None From dc94850c64aa967bbb6eabd8e27bfdb6897becde Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Thu, 30 Nov 2023 18:27:06 +0000 Subject: [PATCH 09/11] feat(spanner): nit sample data refactoring --- tests/system/_sample_data.py | 8 ++++++++ tests/system/test_session_api.py | 16 +++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/system/_sample_data.py b/tests/system/_sample_data.py index 2398442aff..9c83f42224 100644 --- a/tests/system/_sample_data.py +++ b/tests/system/_sample_data.py @@ -27,6 +27,14 @@ (2, "Bharney", "Rhubble", "bharney@example.com"), (3, "Wylma", "Phlyntstone", "wylma@example.com"), ) +BATCH_WRITE_ROW_DATA = ( + (1, "Phred", "Phlyntstone", "phred@example.com"), + (2, "Bharney", "Rhubble", "bharney@example.com"), + (3, "Wylma", "Phlyntstone", "wylma@example.com"), + (4, "Pebbles", "Phlyntstone", "pebbles@example.com"), + (5, "Betty", "Rhubble", "betty@example.com"), + (6, "Slate", "Stephenson", "slate@example.com"), +) ALL = spanner_v1.KeySet(all_=True) SQL = "SELECT * FROM contacts ORDER BY contact_id" diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 43fd5e133a..51d2937b47 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2523,23 +2523,17 @@ def test_partition_query(sessions_database, not_emulator): def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_database): sd = _sample_data - ROW_DATA = ( - (1, "Phred", "Phlyntstone", "phred@example.com"), - (2, "Bharney", "Rhubble", "bharney@example.com"), - (3, "Wylma", "Phlyntstone", "wylma@example.com"), - (4, "Pebbles", "Phlyntstone", "pebbles@example.com"), - (5, "Betty", "Rhubble", "betty@example.com"), - (6, "Slate", "Stephenson", "slate@example.com"), - ) num_groups = 3 - num_mutations_per_group = len(ROW_DATA) // num_groups + num_mutations_per_group = len(sd.BATCH_WRITE_ROW_DATA) // num_groups with sessions_database.mutation_groups() as groups: for i in range(num_groups): group = groups.group() for j in range(num_mutations_per_group): group.insert_or_update( - sd.TABLE, sd.COLUMNS, [ROW_DATA[i * num_mutations_per_group + j]] + sd.TABLE, + sd.COLUMNS, + [sd.BATCH_WRITE_ROW_DATA[i * num_mutations_per_group + j]], ) # Response indexes received seen = collections.Counter() @@ -2556,7 +2550,7 @@ def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_data with sessions_database.snapshot() as snapshot: rows = list(snapshot.execute_sql(sd.SQL)) - sd._check_rows_data(rows, ROW_DATA) + sd._check_rows_data(rows, sd.BATCH_WRITE_ROW_DATA) class FauxCall: From a0937cdb7856627c8bb2097f78c11acf3d35c9fc Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Fri, 1 Dec 2023 13:50:14 +0000 Subject: [PATCH 10/11] review comments --- tests/unit/test_database.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 3ab5cf6d9e..cac45a26ac 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -2724,8 +2724,10 @@ def test_ctor(self): def test_context_mgr_success(self): import datetime + from google.cloud.spanner_v1._helpers import _make_list_value_pbs from google.cloud.spanner_v1 import BatchWriteRequest from google.cloud.spanner_v1 import BatchWriteResponse + from google.cloud.spanner_v1 import Mutation from google.cloud._helpers import UTC from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.spanner_v1.batch import MutationGroups @@ -2748,13 +2750,27 @@ def test_context_mgr_success(self): request_options = RequestOptions(transaction_tag=self.TRANSACTION_TAG) request = BatchWriteRequest( session=self.SESSION_NAME, - mutation_groups=[], + mutation_groups=[ + BatchWriteRequest.MutationGroup( + mutations=[ + Mutation( + insert=Mutation.Write( + table="table", + columns=["col"], + values=_make_list_value_pbs([["val"]]), + ) + ) + ] + ) + ], request_options=request_options, ) with checkout as groups: self.assertIsNone(pool._session) self.assertIsInstance(groups, MutationGroups) self.assertIs(groups._session, session) + group = groups.group() + group.insert("table", ["col"], [["val"]]) groups.batch_write(request_options) self.assertEqual(groups.committed, True) From a4cf000e985d7f64fa33bd9a31e0c8777aba8d47 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Fri, 1 Dec 2023 19:24:20 +0000 Subject: [PATCH 11/11] fix test --- tests/system/test_session_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 51d2937b47..30981322cc 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2526,6 +2526,9 @@ def test_mutation_groups_insert_or_update_then_query(not_emulator, sessions_data num_groups = 3 num_mutations_per_group = len(sd.BATCH_WRITE_ROW_DATA) // num_groups + with sessions_database.batch() as batch: + batch.delete(sd.TABLE, sd.ALL) + with sessions_database.mutation_groups() as groups: for i in range(num_groups): group = groups.group()