feat: Batch Write API implementation and samples #1027

Merged · 20 commits · Dec 3, 2023
Changes from 3 commits
4 changes: 4 additions & 0 deletions google/cloud/spanner_v1/__init__.py
@@ -34,6 +34,8 @@
from .types.result_set import ResultSetStats
from .types.spanner import BatchCreateSessionsRequest
from .types.spanner import BatchCreateSessionsResponse
from .types.spanner import BatchWriteRequest
from .types.spanner import BatchWriteResponse
from .types.spanner import BeginTransactionRequest
from .types.spanner import CommitRequest
from .types.spanner import CreateSessionRequest
@@ -99,6 +101,8 @@
# google.cloud.spanner_v1.types
"BatchCreateSessionsRequest",
"BatchCreateSessionsResponse",
"BatchWriteRequest",
"BatchWriteResponse",
"BeginTransactionRequest",
"CommitRequest",
"CommitResponse",
72 changes: 70 additions & 2 deletions google/cloud/spanner_v1/batch.py
@@ -18,6 +18,7 @@
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1 import BatchWriteRequest

from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
@@ -42,9 +43,9 @@ class _BatchBase(_SessionWrapper):
transaction_tag = None
_read_only = False

def __init__(self, session):
def __init__(self, session, mutations=None):
super(_BatchBase, self).__init__(session)
self._mutations = []
self._mutations = [] if mutations is None else mutations
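        # Accept a caller-owned list so that mutations can accumulate directly
        # into an external container (MutationGroups.group() passes in a
        # BatchWriteRequest.MutationGroup's mutations list).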

def _check_state(self):
"""Helper for :meth:`commit` et al.
@@ -215,6 +216,73 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.commit()


class MutationGroups(_SessionWrapper):
"""Accumulate mutations for transmission during :meth:`batch_write`.

:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the batch write
"""

committed = None

def __init__(self, session):
super(MutationGroups, self).__init__(session)
self._mutation_groups = []

def group(self):
"""Returns a new mutation_group to which mutations can be added."""
mutation_group = BatchWriteRequest.MutationGroup()
self._mutation_groups.append(mutation_group)
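        # The returned _BatchBase shares this group's mutations list, so
        # mutations added through its insert()/update()/delete() helpers land
        # directly in this BatchWriteRequest.MutationGroup.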
return _BatchBase(self._session, mutation_group.mutations)

def batch_write(self, request_options=None):
"""Executes batch_write.

:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
:returns: a sequence of responses for each batch.
"""
if self.committed is not None:
raise ValueError("MutationGroups already committed")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
trace_attributes = {"num_mutation_groups": len(self._mutation_groups)}
if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)

request = BatchWriteRequest(
session=self._session.name,
mutation_groups=self._mutation_groups,
request_options=request_options,
)
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
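        # `response` is a stream of BatchWriteResponse messages, one per batch
        # of applied mutation groups; mark the groups as committed so the
        # check above rejects a second batch_write() call.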
self.committed = True
return response


def _make_write_pb(table, columns, values):
"""Helper for :meth:`Batch.insert` et al.

45 changes: 45 additions & 0 deletions google/cloud/spanner_v1/database.py
@@ -50,6 +50,7 @@
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
@@ -734,6 +735,17 @@ def batch(self, request_options=None):
"""
return BatchCheckout(self, request_options)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.

The wrapper *must* be used as a context manager, with the mutation group
as the value returned by the wrapper.

:rtype: :class:`~google.cloud.spanner_v1.database.MutationGroupsCheckout`
:returns: new wrapper
"""
return MutationGroupsCheckout(self)
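A minimal usage sketch for this entry point (hypothetical values, assuming an existing ``database`` handle; the ``batch_write`` sample below shows a fuller example):

    with database.mutation_groups() as groups:
        group = groups.group()
        group.insert("Singers", ("SingerId", "FirstName"), [(1, "Marc")])
        for response in groups.batch_write():
            print(response.indexes, response.status)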

def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
"""Return an object which wraps a batch read / query.

@@ -1040,6 +1052,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._database._pool.put(self._session)


class MutationGroupsCheckout(object):
"""Context manager for using mutation groups from a database.

Inside the context manager, checks out a session from the database and
creates a :class:`~google.cloud.spanner_v1.batch.MutationGroups` object
from it, making the groups available.

Caller must *not* use the object to perform API requests outside the scope
of the context manager.

:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
"""

def __init__(self, database):
self._database = database
self._session = self._mutation_groups = None

def __enter__(self):
"""Begin ``with`` block."""
session = self._session = self._database._pool.get()
return MutationGroups(session)

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If a NotFound exception was raised inside the ``with`` block,
# verify whether the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
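            # Replace the stale session with a fresh one so that a usable
            # session is returned to the pool below.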
self._database._pool.put(self._session)


class SnapshotCheckout(object):
"""Context manager for using a snapshot from a database.

62 changes: 62 additions & 0 deletions samples/samples/snippets.py
@@ -403,6 +403,65 @@ def insert_data(instance_id, database_id):
# [END spanner_insert_data]


# [START spanner_batch_write_at_least_once]
def batch_write(instance_id, database_id):
"""Inserts sample data into the given database via BatchWrite API.

The database and table must already exist and can be created using
`create_database`.
"""
from google.rpc.code_pb2 import OK

spanner_client = spanner.Client()
instance = spanner_client.instance(instance_id)
database = instance.database(database_id)

with database.mutation_groups() as groups:
group1 = groups.group()
group1.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(16, "Scarlet", "Terry"),
],
)

group2 = groups.group()
group2.insert_or_update(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(17, "Marc", ""),
(18, "Catalina", "Smith"),
],
)
group2.insert_or_update(
table="Albums",
columns=("SingerId", "AlbumId", "AlbumTitle"),
values=[
(17, 1, "Total Junk"),
(18, 2, "Go, Go, Go"),
],
)

for response in groups.batch_write():
if response.status.code == OK:
print(
"Mutation group indexes {} have been applied with commit timestamp {}".format(
response.indexes, response.commit_timestamp
)
)
else:
print(
"Mutation group indexes {} could not be applied with error {}".format(
response.indexes, response.status
)
)


# [END spanner_batch_write_at_least_once]
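Note: as the region tag indicates, BatchWrite provides at-least-once semantics. Each mutation group is applied atomically on its own, but groups are not atomic with respect to one another and a group may be applied more than once, which is why this sample uses the idempotent ``insert_or_update`` rather than ``insert``.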


# [START spanner_delete_data]
def delete_data(instance_id, database_id):
"""Deletes sample data from the given database.
@@ -2677,6 +2736,7 @@ def drop_sequence(instance_id, database_id):
subparsers.add_parser("create_instance", help=create_instance.__doc__)
subparsers.add_parser("create_database", help=create_database.__doc__)
subparsers.add_parser("insert_data", help=insert_data.__doc__)
subparsers.add_parser("batch_write", help=batch_write.__doc__)
subparsers.add_parser("delete_data", help=delete_data.__doc__)
subparsers.add_parser("query_data", help=query_data.__doc__)
subparsers.add_parser("read_data", help=read_data.__doc__)
@@ -2811,6 +2871,8 @@ def drop_sequence(instance_id, database_id):
create_database(args.instance_id, args.database_id)
elif args.command == "insert_data":
insert_data(args.instance_id, args.database_id)
elif args.command == "batch_write":
batch_write(args.instance_id, args.database_id)
elif args.command == "delete_data":
delete_data(args.instance_id, args.database_id)
elif args.command == "query_data":
7 changes: 7 additions & 0 deletions samples/samples/snippets_test.py
@@ -290,6 +290,13 @@ def test_insert_data(capsys, instance_id, sample_database):
assert "Inserted data" in out


@pytest.mark.dependency(name="batch_write")
def test_batch_write(capsys, instance_id, sample_database):
snippets.batch_write(instance_id, sample_database.database_id)
out, _ = capsys.readouterr()
assert "could not be applied with error" not in out


@pytest.mark.dependency(depends=["insert_data"])
def test_delete_data(capsys, instance_id, sample_database):
snippets.delete_data(instance_id, sample_database.database_id)
38 changes: 38 additions & 0 deletions tests/system/test_session_api.py
@@ -2521,6 +2521,44 @@ def test_partition_query(sessions_database, not_emulator):
batch_txn.close()


def test_mutation_groups_insert_or_update_then_query(sessions_database):
sd = _sample_data
ROW_DATA = (
(1, "Phred", "Phlyntstone", "[email protected]"),
(2, "Bharney", "Rhubble", "[email protected]"),
(3, "Wylma", "Phlyntstone", "[email protected]"),
(4, "Pebbles", "Phlyntstone", "[email protected]"),
(5, "Betty", "Rhubble", "[email protected]"),
(6, "Slate", "Stephenson", "[email protected]"),
)
num_groups = 3
num_mutations_per_group = len(ROW_DATA) // num_groups

with sessions_database.mutation_groups() as groups:
for i in range(num_groups):
group = groups.group()
for j in range(num_mutations_per_group):
group.insert_or_update(
sd.TABLE, sd.COLUMNS, [ROW_DATA[i * num_mutations_per_group + j]]
)
# Response indexes received
seen = collections.Counter()
for response in groups.batch_write():
_check_batch_status(response.status.code)
assert response.commit_timestamp is not None
assert len(response.indexes) > 0
seen.update(response.indexes)
# All indexes must be in the range [0, num_groups-1] and seen exactly once
assert len(seen) == num_groups
assert all((0 <= idx < num_groups and ct == 1) for (idx, ct) in seen.items())

# Verify the writes by reading from the database
with sessions_database.snapshot() as snapshot:
rows = list(snapshot.execute_sql(sd.SQL))

sd._check_rows_data(rows, ROW_DATA)


class FauxCall:
def __init__(self, code, details="FauxCall"):
self._code = code