diff --git a/google/cloud/spanner_v1/gapic_metadata.json b/google/cloud/spanner_v1/gapic_metadata.json index ea51736a55..f5957c633a 100644 --- a/google/cloud/spanner_v1/gapic_metadata.json +++ b/google/cloud/spanner_v1/gapic_metadata.json @@ -15,6 +15,11 @@ "batch_create_sessions" ] }, + "BatchWrite": { + "methods": [ + "batch_write" + ] + }, "BeginTransaction": { "methods": [ "begin_transaction" @@ -95,6 +100,11 @@ "batch_create_sessions" ] }, + "BatchWrite": { + "methods": [ + "batch_write" + ] + }, "BeginTransaction": { "methods": [ "begin_transaction" @@ -175,6 +185,11 @@ "batch_create_sessions" ] }, + "BatchWrite": { + "methods": [ + "batch_write" + ] + }, "BeginTransaction": { "methods": [ "begin_transaction" diff --git a/google/cloud/spanner_v1/services/spanner/async_client.py b/google/cloud/spanner_v1/services/spanner/async_client.py index 977970ce7e..7c2e950793 100644 --- a/google/cloud/spanner_v1/services/spanner/async_client.py +++ b/google/cloud/spanner_v1/services/spanner/async_client.py @@ -1973,6 +1973,143 @@ async def sample_partition_read(): # Done; return the response. return response + def batch_write( + self, + request: Optional[Union[spanner.BatchWriteRequest, dict]] = None, + *, + session: Optional[str] = None, + mutation_groups: Optional[ + MutableSequence[spanner.BatchWriteRequest.MutationGroup] + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[spanner.BatchWriteResponse]]: + r"""Batches the supplied mutation groups in a collection + of efficient transactions. All mutations in a group are + committed atomically. However, mutations across groups + can be committed non-atomically in an unspecified order + and thus, they must be independent of each other. + Partial failure is possible, i.e., some groups may have + been committed successfully, while some may have failed. + The results of individual batches are streamed into the + response as the batches are applied. + + BatchWrite requests are not replay protected, meaning + that each mutation group may be applied more than once. + Replays of non-idempotent mutations may have undesirable + effects. For example, replays of an insert mutation may + produce an already exists error or if you use generated + or commit timestamp-based keys, it may result in + additional rows being added to the mutation's table. We + recommend structuring your mutation groups to be + idempotent to avoid this issue. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import spanner_v1 + + async def sample_batch_write(): + # Create a client + client = spanner_v1.SpannerAsyncClient() + + # Initialize request argument(s) + mutation_groups = spanner_v1.MutationGroup() + mutation_groups.mutations.insert.table = "table_value" + + request = spanner_v1.BatchWriteRequest( + session="session_value", + mutation_groups=mutation_groups, + ) + + # Make the request + stream = await client.batch_write(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.spanner_v1.types.BatchWriteRequest, dict]]): + The request object. The request for + [BatchWrite][google.spanner.v1.Spanner.BatchWrite]. + session (:class:`str`): + Required. The session in which the + batch request is to be run. + + This corresponds to the ``session`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + mutation_groups (:class:`MutableSequence[google.cloud.spanner_v1.types.BatchWriteRequest.MutationGroup]`): + Required. The groups of mutations to + be applied. + + This corresponds to the ``mutation_groups`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.spanner_v1.types.BatchWriteResponse]: + The result of applying a batch of + mutations. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([session, mutation_groups]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = spanner.BatchWriteRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if session is not None: + request.session = session + if mutation_groups: + request.mutation_groups.extend(mutation_groups) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_write, + default_timeout=3600.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("session", request.session),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + async def __aenter__(self) -> "SpannerAsyncClient": return self diff --git a/google/cloud/spanner_v1/services/spanner/client.py b/google/cloud/spanner_v1/services/spanner/client.py index 59dc4f222c..03907a1b0b 100644 --- a/google/cloud/spanner_v1/services/spanner/client.py +++ b/google/cloud/spanner_v1/services/spanner/client.py @@ -2119,6 +2119,143 @@ def sample_partition_read(): # Done; return the response. return response + def batch_write( + self, + request: Optional[Union[spanner.BatchWriteRequest, dict]] = None, + *, + session: Optional[str] = None, + mutation_groups: Optional[ + MutableSequence[spanner.BatchWriteRequest.MutationGroup] + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[spanner.BatchWriteResponse]: + r"""Batches the supplied mutation groups in a collection + of efficient transactions. All mutations in a group are + committed atomically. However, mutations across groups + can be committed non-atomically in an unspecified order + and thus, they must be independent of each other. + Partial failure is possible, i.e., some groups may have + been committed successfully, while some may have failed. + The results of individual batches are streamed into the + response as the batches are applied. + + BatchWrite requests are not replay protected, meaning + that each mutation group may be applied more than once. + Replays of non-idempotent mutations may have undesirable + effects. For example, replays of an insert mutation may + produce an already exists error or if you use generated + or commit timestamp-based keys, it may result in + additional rows being added to the mutation's table. We + recommend structuring your mutation groups to be + idempotent to avoid this issue. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import spanner_v1 + + def sample_batch_write(): + # Create a client + client = spanner_v1.SpannerClient() + + # Initialize request argument(s) + mutation_groups = spanner_v1.MutationGroup() + mutation_groups.mutations.insert.table = "table_value" + + request = spanner_v1.BatchWriteRequest( + session="session_value", + mutation_groups=mutation_groups, + ) + + # Make the request + stream = client.batch_write(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.spanner_v1.types.BatchWriteRequest, dict]): + The request object. The request for + [BatchWrite][google.spanner.v1.Spanner.BatchWrite]. + session (str): + Required. The session in which the + batch request is to be run. + + This corresponds to the ``session`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + mutation_groups (MutableSequence[google.cloud.spanner_v1.types.BatchWriteRequest.MutationGroup]): + Required. The groups of mutations to + be applied. + + This corresponds to the ``mutation_groups`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]: + The result of applying a batch of + mutations. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([session, mutation_groups]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a spanner.BatchWriteRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, spanner.BatchWriteRequest): + request = spanner.BatchWriteRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if session is not None: + request.session = session + if mutation_groups is not None: + request.mutation_groups = mutation_groups + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_write] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("session", request.session),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def __enter__(self) -> "SpannerClient": return self diff --git a/google/cloud/spanner_v1/services/spanner/transports/base.py b/google/cloud/spanner_v1/services/spanner/transports/base.py index 668191c5f2..27006d8fbc 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/base.py +++ b/google/cloud/spanner_v1/services/spanner/transports/base.py @@ -322,6 +322,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=30.0, client_info=client_info, ), + self.batch_write: gapic_v1.method.wrap_method( + self.batch_write, + default_timeout=3600.0, + client_info=client_info, + ), } def close(self): @@ -473,6 +478,15 @@ def partition_read( ]: raise NotImplementedError() + @property + def batch_write( + self, + ) -> Callable[ + [spanner.BatchWriteRequest], + Union[spanner.BatchWriteResponse, Awaitable[spanner.BatchWriteResponse]], + ]: + raise NotImplementedError() + @property def kind(self) -> str: raise NotImplementedError() diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc.py b/google/cloud/spanner_v1/services/spanner/transports/grpc.py index 7236f0ed27..86d9ba4133 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc.py @@ -755,6 +755,50 @@ def partition_read( ) return self._stubs["partition_read"] + @property + def batch_write( + self, + ) -> Callable[[spanner.BatchWriteRequest], spanner.BatchWriteResponse]: + r"""Return a callable for the batch write method over gRPC. + + Batches the supplied mutation groups in a collection + of efficient transactions. All mutations in a group are + committed atomically. However, mutations across groups + can be committed non-atomically in an unspecified order + and thus, they must be independent of each other. + Partial failure is possible, i.e., some groups may have + been committed successfully, while some may have failed. + The results of individual batches are streamed into the + response as the batches are applied. + + BatchWrite requests are not replay protected, meaning + that each mutation group may be applied more than once. + Replays of non-idempotent mutations may have undesirable + effects. For example, replays of an insert mutation may + produce an already exists error or if you use generated + or commit timestamp-based keys, it may result in + additional rows being added to the mutation's table. We + recommend structuring your mutation groups to be + idempotent to avoid this issue. + + Returns: + Callable[[~.BatchWriteRequest], + ~.BatchWriteResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_write" not in self._stubs: + self._stubs["batch_write"] = self.grpc_channel.unary_stream( + "/google.spanner.v1.Spanner/BatchWrite", + request_serializer=spanner.BatchWriteRequest.serialize, + response_deserializer=spanner.BatchWriteResponse.deserialize, + ) + return self._stubs["batch_write"] + def close(self): self.grpc_channel.close() diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py b/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py index 62a975c319..d0755e3a67 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py @@ -771,6 +771,50 @@ def partition_read( ) return self._stubs["partition_read"] + @property + def batch_write( + self, + ) -> Callable[[spanner.BatchWriteRequest], Awaitable[spanner.BatchWriteResponse]]: + r"""Return a callable for the batch write method over gRPC. + + Batches the supplied mutation groups in a collection + of efficient transactions. All mutations in a group are + committed atomically. However, mutations across groups + can be committed non-atomically in an unspecified order + and thus, they must be independent of each other. + Partial failure is possible, i.e., some groups may have + been committed successfully, while some may have failed. + The results of individual batches are streamed into the + response as the batches are applied. + + BatchWrite requests are not replay protected, meaning + that each mutation group may be applied more than once. + Replays of non-idempotent mutations may have undesirable + effects. For example, replays of an insert mutation may + produce an already exists error or if you use generated + or commit timestamp-based keys, it may result in + additional rows being added to the mutation's table. We + recommend structuring your mutation groups to be + idempotent to avoid this issue. + + Returns: + Callable[[~.BatchWriteRequest], + Awaitable[~.BatchWriteResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_write" not in self._stubs: + self._stubs["batch_write"] = self.grpc_channel.unary_stream( + "/google.spanner.v1.Spanner/BatchWrite", + request_serializer=spanner.BatchWriteRequest.serialize, + response_deserializer=spanner.BatchWriteResponse.deserialize, + ) + return self._stubs["batch_write"] + def close(self): return self.grpc_channel.close() diff --git a/google/cloud/spanner_v1/services/spanner/transports/rest.py b/google/cloud/spanner_v1/services/spanner/transports/rest.py index d7157886a5..5e32bfaf2a 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/rest.py +++ b/google/cloud/spanner_v1/services/spanner/transports/rest.py @@ -78,6 +78,14 @@ def post_batch_create_sessions(self, response): logging.log(f"Received response: {response}") return response + def pre_batch_write(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_write(self, response): + logging.log(f"Received response: {response}") + return response + def pre_begin_transaction(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -211,6 +219,27 @@ def post_batch_create_sessions( """ return response + def pre_batch_write( + self, request: spanner.BatchWriteRequest, metadata: Sequence[Tuple[str, str]] + ) -> Tuple[spanner.BatchWriteRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for batch_write + + Override in a subclass to manipulate the request or metadata + before they are sent to the Spanner server. + """ + return request, metadata + + def post_batch_write( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for batch_write + + Override in a subclass to manipulate the response + after it is returned by the Spanner server but before + it is returned to user code. + """ + return response + def pre_begin_transaction( self, request: spanner.BeginTransactionRequest, @@ -681,6 +710,101 @@ def __call__( resp = self._interceptor.post_batch_create_sessions(resp) return resp + class _BatchWrite(SpannerRestStub): + def __hash__(self): + return hash("BatchWrite") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: spanner.BatchWriteRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the batch write method over HTTP. + + Args: + request (~.spanner.BatchWriteRequest): + The request object. The request for + [BatchWrite][google.spanner.v1.Spanner.BatchWrite]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.spanner.BatchWriteResponse: + The result of applying a batch of + mutations. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{session=projects/*/instances/*/databases/*/sessions/*}:batchWrite", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_batch_write(request, metadata) + pb_request = spanner.BatchWriteRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator(response, spanner.BatchWriteResponse) + resp = self._interceptor.post_batch_write(resp) + return resp + class _BeginTransaction(SpannerRestStub): def __hash__(self): return hash("BeginTransaction") @@ -2056,6 +2180,14 @@ def batch_create_sessions( # In C++ this would require a dynamic_cast return self._BatchCreateSessions(self._session, self._host, self._interceptor) # type: ignore + @property + def batch_write( + self, + ) -> Callable[[spanner.BatchWriteRequest], spanner.BatchWriteResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchWrite(self._session, self._host, self._interceptor) # type: ignore + @property def begin_transaction( self, diff --git a/google/cloud/spanner_v1/types/__init__.py b/google/cloud/spanner_v1/types/__init__.py index df0960d9d9..f4f619f6c4 100644 --- a/google/cloud/spanner_v1/types/__init__.py +++ b/google/cloud/spanner_v1/types/__init__.py @@ -36,6 +36,8 @@ from .spanner import ( BatchCreateSessionsRequest, BatchCreateSessionsResponse, + BatchWriteRequest, + BatchWriteResponse, BeginTransactionRequest, CommitRequest, CreateSessionRequest, @@ -81,6 +83,8 @@ "ResultSetStats", "BatchCreateSessionsRequest", "BatchCreateSessionsResponse", + "BatchWriteRequest", + "BatchWriteResponse", "BeginTransactionRequest", "CommitRequest", "CreateSessionRequest", diff --git a/google/cloud/spanner_v1/types/spanner.py b/google/cloud/spanner_v1/types/spanner.py index 310cf8e31f..dfd83ac165 100644 --- a/google/cloud/spanner_v1/types/spanner.py +++ b/google/cloud/spanner_v1/types/spanner.py @@ -53,6 +53,8 @@ "BeginTransactionRequest", "CommitRequest", "RollbackRequest", + "BatchWriteRequest", + "BatchWriteResponse", }, ) @@ -1329,4 +1331,83 @@ class RollbackRequest(proto.Message): ) +class BatchWriteRequest(proto.Message): + r"""The request for [BatchWrite][google.spanner.v1.Spanner.BatchWrite]. + + Attributes: + session (str): + Required. The session in which the batch + request is to be run. + request_options (google.cloud.spanner_v1.types.RequestOptions): + Common options for this request. + mutation_groups (MutableSequence[google.cloud.spanner_v1.types.BatchWriteRequest.MutationGroup]): + Required. The groups of mutations to be + applied. + """ + + class MutationGroup(proto.Message): + r"""A group of mutations to be committed together. Related + mutations should be placed in a group. For example, two + mutations inserting rows with the same primary key prefix in + both parent and child tables are related. + + Attributes: + mutations (MutableSequence[google.cloud.spanner_v1.types.Mutation]): + Required. The mutations in this group. + """ + + mutations: MutableSequence[mutation.Mutation] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=mutation.Mutation, + ) + + session: str = proto.Field( + proto.STRING, + number=1, + ) + request_options: "RequestOptions" = proto.Field( + proto.MESSAGE, + number=3, + message="RequestOptions", + ) + mutation_groups: MutableSequence[MutationGroup] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=MutationGroup, + ) + + +class BatchWriteResponse(proto.Message): + r"""The result of applying a batch of mutations. + + Attributes: + indexes (MutableSequence[int]): + The mutation groups applied in this batch. The values index + into the ``mutation_groups`` field in the corresponding + ``BatchWriteRequest``. + status (google.rpc.status_pb2.Status): + An ``OK`` status indicates success. Any other status + indicates a failure. + commit_timestamp (google.protobuf.timestamp_pb2.Timestamp): + The commit timestamp of the transaction that applied this + batch. Present if ``status`` is ``OK``, absent otherwise. + """ + + indexes: MutableSequence[int] = proto.RepeatedField( + proto.INT32, + number=1, + ) + status: status_pb2.Status = proto.Field( + proto.MESSAGE, + number=2, + message=status_pb2.Status, + ) + commit_timestamp: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/samples/generated_samples/snippet_metadata_google.spanner.v1.json b/samples/generated_samples/snippet_metadata_google.spanner.v1.json index a8e8be3ae3..4384d19e2a 100644 --- a/samples/generated_samples/snippet_metadata_google.spanner.v1.json +++ b/samples/generated_samples/snippet_metadata_google.spanner.v1.json @@ -180,6 +180,175 @@ ], "title": "spanner_v1_generated_spanner_batch_create_sessions_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.spanner_v1.SpannerAsyncClient", + "shortName": "SpannerAsyncClient" + }, + "fullName": "google.cloud.spanner_v1.SpannerAsyncClient.batch_write", + "method": { + "fullName": "google.spanner.v1.Spanner.BatchWrite", + "service": { + "fullName": "google.spanner.v1.Spanner", + "shortName": "Spanner" + }, + "shortName": "BatchWrite" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.spanner_v1.types.BatchWriteRequest" + }, + { + "name": "session", + "type": "str" + }, + { + "name": "mutation_groups", + "type": "MutableSequence[google.cloud.spanner_v1.types.BatchWriteRequest.MutationGroup]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]", + "shortName": "batch_write" + }, + "description": "Sample for BatchWrite", + "file": "spanner_v1_generated_spanner_batch_write_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "spanner_v1_generated_Spanner_BatchWrite_async", + "segments": [ + { + "end": 56, + "start": 27, + "type": "FULL" + }, + { + "end": 56, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 57, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "spanner_v1_generated_spanner_batch_write_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.spanner_v1.SpannerClient", + "shortName": "SpannerClient" + }, + "fullName": "google.cloud.spanner_v1.SpannerClient.batch_write", + "method": { + "fullName": "google.spanner.v1.Spanner.BatchWrite", + "service": { + "fullName": "google.spanner.v1.Spanner", + "shortName": "Spanner" + }, + "shortName": "BatchWrite" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.spanner_v1.types.BatchWriteRequest" + }, + { + "name": "session", + "type": "str" + }, + { + "name": "mutation_groups", + "type": "MutableSequence[google.cloud.spanner_v1.types.BatchWriteRequest.MutationGroup]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]", + "shortName": "batch_write" + }, + "description": "Sample for BatchWrite", + "file": "spanner_v1_generated_spanner_batch_write_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "spanner_v1_generated_Spanner_BatchWrite_sync", + "segments": [ + { + "end": 56, + "start": 27, + "type": "FULL" + }, + { + "end": 56, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 57, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "spanner_v1_generated_spanner_batch_write_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/samples/generated_samples/spanner_v1_generated_spanner_batch_write_async.py b/samples/generated_samples/spanner_v1_generated_spanner_batch_write_async.py new file mode 100644 index 0000000000..39352562b1 --- /dev/null +++ b/samples/generated_samples/spanner_v1_generated_spanner_batch_write_async.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchWrite +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-spanner + + +# [START spanner_v1_generated_Spanner_BatchWrite_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import spanner_v1 + + +async def sample_batch_write(): + # Create a client + client = spanner_v1.SpannerAsyncClient() + + # Initialize request argument(s) + mutation_groups = spanner_v1.MutationGroup() + mutation_groups.mutations.insert.table = "table_value" + + request = spanner_v1.BatchWriteRequest( + session="session_value", + mutation_groups=mutation_groups, + ) + + # Make the request + stream = await client.batch_write(request=request) + + # Handle the response + async for response in stream: + print(response) + +# [END spanner_v1_generated_Spanner_BatchWrite_async] diff --git a/samples/generated_samples/spanner_v1_generated_spanner_batch_write_sync.py b/samples/generated_samples/spanner_v1_generated_spanner_batch_write_sync.py new file mode 100644 index 0000000000..4ee88b0cd6 --- /dev/null +++ b/samples/generated_samples/spanner_v1_generated_spanner_batch_write_sync.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchWrite +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-spanner + + +# [START spanner_v1_generated_Spanner_BatchWrite_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import spanner_v1 + + +def sample_batch_write(): + # Create a client + client = spanner_v1.SpannerClient() + + # Initialize request argument(s) + mutation_groups = spanner_v1.MutationGroup() + mutation_groups.mutations.insert.table = "table_value" + + request = spanner_v1.BatchWriteRequest( + session="session_value", + mutation_groups=mutation_groups, + ) + + # Make the request + stream = client.batch_write(request=request) + + # Handle the response + for response in stream: + print(response) + +# [END spanner_v1_generated_Spanner_BatchWrite_sync] diff --git a/scripts/fixup_spanner_v1_keywords.py b/scripts/fixup_spanner_v1_keywords.py index df4d3501f2..b1ba4084df 100644 --- a/scripts/fixup_spanner_v1_keywords.py +++ b/scripts/fixup_spanner_v1_keywords.py @@ -40,6 +40,7 @@ class spannerCallTransformer(cst.CSTTransformer): CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { 'batch_create_sessions': ('database', 'session_count', 'session_template', ), + 'batch_write': ('session', 'mutation_groups', 'request_options', ), 'begin_transaction': ('session', 'options', 'request_options', ), 'commit': ('session', 'transaction_id', 'single_use_transaction', 'mutations', 'return_commit_stats', 'request_options', ), 'create_session': ('database', 'session', ), diff --git a/tests/unit/gapic/spanner_v1/test_spanner.py b/tests/unit/gapic/spanner_v1/test_spanner.py index 8bf8407724..7f593f1953 100644 --- a/tests/unit/gapic/spanner_v1/test_spanner.py +++ b/tests/unit/gapic/spanner_v1/test_spanner.py @@ -3857,6 +3857,292 @@ async def test_partition_read_field_headers_async(): ) in kw["metadata"] +@pytest.mark.parametrize( + "request_type", + [ + spanner.BatchWriteRequest, + dict, + ], +) +def test_batch_write(request_type, transport: str = "grpc"): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([spanner.BatchWriteResponse()]) + response = client.batch_write(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == spanner.BatchWriteRequest() + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, spanner.BatchWriteResponse) + + +def test_batch_write_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + client.batch_write() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == spanner.BatchWriteRequest() + + +@pytest.mark.asyncio +async def test_batch_write_async( + transport: str = "grpc_asyncio", request_type=spanner.BatchWriteRequest +): + client = SpannerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[spanner.BatchWriteResponse()] + ) + response = await client.batch_write(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == spanner.BatchWriteRequest() + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, spanner.BatchWriteResponse) + + +@pytest.mark.asyncio +async def test_batch_write_async_from_dict(): + await test_batch_write_async(request_type=dict) + + +def test_batch_write_field_headers(): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = spanner.BatchWriteRequest() + + request.session = "session_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + call.return_value = iter([spanner.BatchWriteResponse()]) + client.batch_write(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "session=session_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_write_field_headers_async(): + client = SpannerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = spanner.BatchWriteRequest() + + request.session = "session_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[spanner.BatchWriteResponse()] + ) + await client.batch_write(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "session=session_value", + ) in kw["metadata"] + + +def test_batch_write_flattened(): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([spanner.BatchWriteResponse()]) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_write( + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].session + mock_val = "session_value" + assert arg == mock_val + arg = args[0].mutation_groups + mock_val = [ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ] + assert arg == mock_val + + +def test_batch_write_flattened_error(): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_write( + spanner.BatchWriteRequest(), + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + + +@pytest.mark.asyncio +async def test_batch_write_flattened_async(): + client = SpannerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_write), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([spanner.BatchWriteResponse()]) + + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_write( + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].session + mock_val = "session_value" + assert arg == mock_val + arg = args[0].mutation_groups + mock_val = [ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_batch_write_flattened_error_async(): + client = SpannerAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_write( + spanner.BatchWriteRequest(), + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + + @pytest.mark.parametrize( "request_type", [ @@ -7695,6 +7981,315 @@ def test_partition_read_rest_error(): ) +@pytest.mark.parametrize( + "request_type", + [ + spanner.BatchWriteRequest, + dict, + ], +) +def test_batch_write_rest(request_type): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "session": "projects/sample1/instances/sample2/databases/sample3/sessions/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner.BatchWriteResponse( + indexes=[752], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = spanner.BatchWriteResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + json_return_value = "[{}]".format(json_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.batch_write(request) + + assert isinstance(response, Iterable) + response = next(response) + + # Establish that the response is the type that we expect. + assert isinstance(response, spanner.BatchWriteResponse) + assert response.indexes == [752] + + +def test_batch_write_rest_required_fields(request_type=spanner.BatchWriteRequest): + transport_class = transports.SpannerRestTransport + + request_init = {} + request_init["session"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_write._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["session"] = "session_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_write._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "session" in jsonified_request + assert jsonified_request["session"] == "session_value" + + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner.BatchWriteResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = spanner.BatchWriteResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + json_return_value = "[{}]".format(json_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.batch_write(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_write_rest_unset_required_fields(): + transport = transports.SpannerRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_write._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "session", + "mutationGroups", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_write_rest_interceptors(null_interceptor): + transport = transports.SpannerRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.SpannerRestInterceptor(), + ) + client = SpannerClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.SpannerRestInterceptor, "post_batch_write" + ) as post, mock.patch.object( + transports.SpannerRestInterceptor, "pre_batch_write" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner.BatchWriteRequest.pb(spanner.BatchWriteRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = spanner.BatchWriteResponse.to_json( + spanner.BatchWriteResponse() + ) + req.return_value._content = "[{}]".format(req.return_value._content) + + request = spanner.BatchWriteRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner.BatchWriteResponse() + + client.batch_write( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_batch_write_rest_bad_request( + transport: str = "rest", request_type=spanner.BatchWriteRequest +): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "session": "projects/sample1/instances/sample2/databases/sample3/sessions/sample4" + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.batch_write(request) + + +def test_batch_write_rest_flattened(): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner.BatchWriteResponse() + + # get arguments that satisfy an http rule for this method + sample_request = { + "session": "projects/sample1/instances/sample2/databases/sample3/sessions/sample4" + } + + # get truthy value for each flattened field + mock_args = dict( + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = spanner.BatchWriteResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + json_return_value = "[{}]".format(json_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + client.batch_write(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{session=projects/*/instances/*/databases/*/sessions/*}:batchWrite" + % client.transport._host, + args[1], + ) + + +def test_batch_write_rest_flattened_error(transport: str = "rest"): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_write( + spanner.BatchWriteRequest(), + session="session_value", + mutation_groups=[ + spanner.BatchWriteRequest.MutationGroup( + mutations=[ + mutation.Mutation( + insert=mutation.Mutation.Write(table="table_value") + ) + ] + ) + ], + ) + + +def test_batch_write_rest_error(): + client = SpannerClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.SpannerGrpcTransport( @@ -7849,6 +8444,7 @@ def test_spanner_base_transport(): "rollback", "partition_query", "partition_read", + "batch_write", ) for method in methods: with pytest.raises(NotImplementedError): @@ -8161,6 +8757,9 @@ def test_spanner_client_transport_session_collision(transport_name): session1 = client1.transport.partition_read._session session2 = client2.transport.partition_read._session assert session1 != session2 + session1 = client1.transport.batch_write._session + session2 = client2.transport.batch_write._session + assert session1 != session2 def test_spanner_grpc_transport_channel():