diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index e07c935b0f78..8796350c00aa 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -69,7 +69,7 @@ private static class BatchReadOnlyTransactionImpl extends MultiUseReadOnlyTransa super( checkNotNull(session), checkNotNull(bound), - checkNotNull(spanner).getOptions().getGapicSpannerRpc(), + checkNotNull(spanner).getOptions().getSpannerRpcV1(), spanner.getOptions().getPrefetchChunks()); this.sessionName = session.getName(); this.options = session.getOptions(); @@ -82,7 +82,7 @@ private static class BatchReadOnlyTransactionImpl extends MultiUseReadOnlyTransa checkNotNull(session), checkNotNull(batchTransactionId).getTransactionId(), batchTransactionId.getTimestamp(), - checkNotNull(spanner).getOptions().getGapicSpannerRpc(), + checkNotNull(spanner).getOptions().getSpannerRpcV1(), spanner.getOptions().getPrefetchChunks()); this.sessionName = session.getName(); this.options = session.getOptions(); diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 173e9aa0cf0f..e9245f2db3f2 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -139,7 +139,6 @@ class SpannerImpl extends BaseService implements Spanner { } private final Random random = new Random(); - private final SpannerRpc rawGrpcRpc; private final SpannerRpc gapicRpc; private final int defaultPrefetchChunks; @@ -153,12 +152,10 @@ class SpannerImpl extends BaseService implements Spanner { private boolean spannerIsClosed = false; SpannerImpl( - SpannerRpc rawGrpcRpc, SpannerRpc gapicRpc, int defaultPrefetchChunks, SpannerOptions options) { super(options); - this.rawGrpcRpc = rawGrpcRpc; this.gapicRpc = gapicRpc; this.defaultPrefetchChunks = defaultPrefetchChunks; this.dbAdminClient = new DatabaseAdminClientImpl(options.getProjectId(), gapicRpc); @@ -169,7 +166,6 @@ class SpannerImpl extends BaseService implements Spanner { SpannerImpl(SpannerOptions options) { this( options.getSpannerRpcV1(), - options.getGapicSpannerRpc(), options.getPrefetchChunks(), options); } @@ -336,12 +332,10 @@ public void close() { } catch (InterruptedException | ExecutionException e) { throw SpannerExceptionFactory.newSpannerException(e); } - for (ManagedChannel channel : getOptions().getRpcChannels()) { - try { - channel.shutdown(); - } catch (RuntimeException e) { - logger.log(Level.WARNING, "Failed to close channel", e); - } + try { + gapicRpc.shutdown(); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to close channels", e); } } @@ -1067,18 +1061,17 @@ ResultSet executeQueryInternalWithOptions( new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, QUERY) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - return new CloseableServerStreamIterator( + GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks); + SpannerRpc.StreamingCall call = rpc.executeQuery( resumeToken == null ? request : request.toBuilder().setResumeToken(resumeToken).build(), - null, - session.options)); - - // TODO(hzyi): make resume work - // Let resume fail for now. Gapic has its own resume, but in order not - // to introduce too much change at a time, we decide to plumb up - // ServerStream first and then figure out how to make resume work + stream.consumer(), + session.options); + call.request(prefetchChunks); + stream.setCall(call); + return stream; } }; return new GrpcResultSet(stream, this, queryMode); @@ -1178,18 +1171,17 @@ ResultSet readInternalWithOptions( new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, READ) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - return new CloseableServerStreamIterator( + GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks); + SpannerRpc.StreamingCall call = rpc.read( resumeToken == null ? request : request.toBuilder().setResumeToken(resumeToken).build(), - null, - session.options)); - - // TODO(hzyi): make resume work - // Let resume fail for now. Gapic has its own resume, but in order not - // to introduce too much change at a time, we decide to plumb up - // ServerStream first and then figure out how to make resume work + stream.consumer(), + session.options); + call.request(prefetchChunks); + stream.setCall(call); + return stream; } }; GrpcResultSet resultSet = diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index fbab39630f5c..7338ea34ccd2 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -16,6 +16,9 @@ package com.google.cloud.spanner; +import com.google.api.gax.grpc.GrpcInterceptorProvider; +import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; +import com.google.api.gax.rpc.TransportChannelProvider; import com.google.cloud.ServiceDefaults; import com.google.cloud.ServiceOptions; import com.google.cloud.ServiceRpc; @@ -23,7 +26,6 @@ import com.google.cloud.grpc.GrpcTransportOptions; import com.google.cloud.spanner.spi.SpannerRpcFactory; import com.google.cloud.spanner.spi.v1.GapicSpannerRpc; -import com.google.cloud.spanner.spi.v1.GrpcSpannerRpc; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; @@ -53,7 +55,8 @@ public class SpannerOptions extends ServiceOptions { "https://www.googleapis.com/auth/spanner.admin", "https://www.googleapis.com/auth/spanner.data"); private static final int MAX_CHANNELS = 256; - private static final RpcChannelFactory DEFAULT_RPC_CHANNEL_FACTORY = new NettyRpcChannelFactory(); + private static final int MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + private static final int MAX_HEADER_LIST_SIZE = 32 * 1024; //bytes /** Default implementation of {@code SpannerFactory}. */ private static class DefaultSpannerFactory implements SpannerFactory { @@ -71,11 +74,12 @@ private static class DefaultSpannerRpcFactory implements SpannerRpcFactory { @Override public ServiceRpc create(SpannerOptions options) { - return new GrpcSpannerRpc(options); + return new GapicSpannerRpc(options); } } - private final List rpcChannels; + private final TransportChannelProvider channelProvider; + private final GrpcInterceptorProvider interceptorProvider; private final SessionPoolOptions sessionPoolOptions; private final int prefetchChunks; private final int numChannels; @@ -83,17 +87,15 @@ public ServiceRpc create(SpannerOptions options) { private SpannerOptions(Builder builder) { super(SpannerFactory.class, SpannerRpcFactory.class, builder, new SpannerDefaults()); - numChannels = builder.numChannels; - String userAgent = getUserAgent(); - RpcChannelFactory defaultRpcChannelFactory = - userAgent == null - ? DEFAULT_RPC_CHANNEL_FACTORY - : new NettyRpcChannelFactory(userAgent); - rpcChannels = - createChannels( - getHost(), - MoreObjects.firstNonNull(builder.rpcChannelFactory, defaultRpcChannelFactory), - numChannels); + numChannels = builder.numChannels; + Preconditions.checkArgument( + numChannels >= 1 && numChannels <= MAX_CHANNELS, + "Number of channels must fall in the range [1, %s], found: %s", + MAX_CHANNELS, + numChannels); + + channelProvider = builder.channelProvider; + interceptorProvider = builder.interceptorProvider; sessionPoolOptions = builder.sessionPoolOptions != null ? builder.sessionPoolOptions @@ -107,10 +109,11 @@ public static class Builder extends ServiceOptions.Builder< Spanner, SpannerOptions, SpannerOptions.Builder> { private static final int DEFAULT_PREFETCH_CHUNKS = 4; - private RpcChannelFactory rpcChannelFactory; + private TransportChannelProvider channelProvider; + private GrpcInterceptorProvider interceptorProvider; + /** By default, we create 4 channels per {@link SpannerOptions} */ private int numChannels = 4; - private int prefetchChunks = DEFAULT_PREFETCH_CHUNKS; private SessionPoolOptions sessionPoolOptions; private ImmutableMap sessionLabels; @@ -123,6 +126,8 @@ private Builder() {} this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; this.sessionLabels = options.sessionLabels; + this.channelProvider = options.channelProvider; + this.interceptorProvider = options.interceptorProvider; } @Override @@ -134,9 +139,21 @@ public Builder setTransportOptions(TransportOptions transportOptions) { return super.setTransportOptions(transportOptions); } - /** Sets the factory for creating gRPC channels. If not set, a default will be used. */ - public Builder setRpcChannelFactory(RpcChannelFactory factory) { - this.rpcChannelFactory = factory; + /** + * Sets the {@code ChannelProvider}. {@link GapicSpannerRpc} would create a default + * one if none is provided. + */ + public Builder setChannelProvider(TransportChannelProvider channelProvider) { + this.channelProvider = channelProvider; + return this; + } + + /** + * Sets the {@code GrpcInterceptorProvider}. {@link GapicSpannerRpc} would create + * a default one if none is provided. + */ + public Builder setInterceptorProvider(GrpcInterceptorProvider interceptorProvider) { + this.interceptorProvider = interceptorProvider; return this; } @@ -197,14 +214,6 @@ public SpannerOptions build() { } } - /** - * Interface for gRPC channel creation. Most users won't need to use this, as the default covers - * typical deployment scenarios. - */ - public interface RpcChannelFactory { - ManagedChannel newChannel(String host, int port); - } - /** Returns default instance of {@code SpannerOptions}. */ public static SpannerOptions getDefaultInstance() { return newBuilder().build(); @@ -214,8 +223,12 @@ public static Builder newBuilder() { return new Builder(); } - public List getRpcChannels() { - return rpcChannels; + public TransportChannelProvider getChannelProvider() { + return channelProvider; + } + + public GrpcInterceptorProvider getInterceptorProvider() { + return interceptorProvider; } public int getNumChannels() { @@ -238,88 +251,11 @@ public static GrpcTransportOptions getDefaultGrpcTransportOptions() { return GrpcTransportOptions.newBuilder().build(); } - /** - * Returns the default RPC channel factory used when none is specified. This may be useful for - * callers that wish to add interceptors to gRPC channels used by the Cloud Spanner client - * library. - */ - public static RpcChannelFactory getDefaultRpcChannelFactory() { - return DEFAULT_RPC_CHANNEL_FACTORY; - } - @Override protected String getDefaultHost() { return DEFAULT_HOST; } - private static List createChannels( - String rootUrl, RpcChannelFactory factory, int numChannels) { - Preconditions.checkArgument( - numChannels >= 1 && numChannels <= MAX_CHANNELS, - "Number of channels must fall in the range [1, %s], found: %s", - MAX_CHANNELS, - numChannels); - ImmutableList.Builder builder = ImmutableList.builder(); - for (int i = 0; i < numChannels; i++) { - builder.add(createChannel(rootUrl, factory)); - } - return builder.build(); - } - - private static ManagedChannel createChannel(String rootUrl, RpcChannelFactory factory) { - URL url; - try { - url = new URL(rootUrl); - } catch (MalformedURLException e) { - throw new IllegalArgumentException("Invalid host: " + rootUrl, e); - } - ManagedChannel channel = - factory.newChannel(url.getHost(), url.getPort() > 0 ? url.getPort() : url.getDefaultPort()); - return channel; - } - - static class NettyRpcChannelFactory implements RpcChannelFactory { - private static final int MAX_MESSAGE_SIZE = 100 * 1024 * 1024; - private static final int MAX_HEADER_LIST_SIZE = 32 * 1024; //bytes - private final String userAgent; - private final List interceptors; - - NettyRpcChannelFactory() { - this(null); - } - - NettyRpcChannelFactory(String userAgent) { - this(userAgent, ImmutableList.of()); - } - - NettyRpcChannelFactory(String userAgent, List interceptors) { - this.userAgent = userAgent; - this.interceptors = interceptors; - } - - @Override - public ManagedChannel newChannel(String host, int port) { - NettyChannelBuilder builder = - NettyChannelBuilder.forAddress(host, port) - .sslContext(newSslContext()) - .intercept(interceptors) - .maxHeaderListSize(MAX_HEADER_LIST_SIZE) - .maxMessageSize(MAX_MESSAGE_SIZE); - if (userAgent != null) { - builder.userAgent(userAgent); - } - return builder.build(); - } - - private static SslContext newSslContext() { - try { - return GrpcSslContexts.forClient().ciphers(null).build(); - } catch (SSLException e) { - throw new RuntimeException("SSL configuration failed: " + e.getMessage(), e); - } - } - } - private static class SpannerDefaults implements ServiceDefaults { @@ -348,10 +284,6 @@ protected SpannerRpc getSpannerRpcV1() { return (SpannerRpc) getRpc(); } - protected SpannerRpc getGapicSpannerRpc() { - return GapicSpannerRpc.create(this); - } - @SuppressWarnings("unchecked") @Override public Builder toBuilder() { diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index fba7beab6ab6..ad8f654b249c 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -18,19 +18,24 @@ import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import com.google.common.base.Preconditions; import com.google.api.core.ApiFunction; import com.google.api.gax.core.CredentialsProvider; import com.google.api.gax.core.GaxProperties; +import com.google.api.gax.core.InstantiatingExecutorProvider; import com.google.api.gax.grpc.GaxGrpcProperties; import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.api.gax.longrunning.OperationFuture; import com.google.api.gax.rpc.ApiClientHeaderProvider; +import com.google.api.gax.rpc.FixedTransportChannelProvider; import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.StreamController; import com.google.api.pathtemplate.PathTemplate; import com.google.cloud.ServiceOptions; import com.google.cloud.grpc.GrpcTransportOptions; @@ -91,7 +96,6 @@ import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; import io.grpc.Context; -import java.io.IOException; import java.util.List; import java.util.Map; import java.util.concurrent.CancellationException; @@ -106,6 +110,7 @@ public class GapicSpannerRpc implements SpannerRpc { PathTemplate.create("projects/{project}"); private static final int MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + // TODO(hzyi): change the stub names to be more intuitive private final SpannerStub stub; private final InstanceAdminStub instanceStub; private final DatabaseAdminStub databaseStub; @@ -114,17 +119,19 @@ public class GapicSpannerRpc implements SpannerRpc { private final SpannerMetadataProvider metadataProvider; public static GapicSpannerRpc create(SpannerOptions options) { - try { - return new GapicSpannerRpc(options); - } catch (IOException e) { - throw new IllegalStateException(e); - } + return new GapicSpannerRpc(options); } - public GapicSpannerRpc(SpannerOptions options) throws IOException { + public GapicSpannerRpc(SpannerOptions options) { this.projectId = options.getProjectId(); this.projectName = PROJECT_NAME_TEMPLATE.instantiate("project", this.projectId); + // TODO(hzyi): inject userAgent to headerProvider so that it + // can be picked up by ChannelProvider + + // create a metadataProvider which combines both internal headers and + // per-method-call extra headers for channelProvider to inject the headers + // for rpc calls ApiClientHeaderProvider.Builder internalHeaderProviderBuilder = ApiClientHeaderProvider.newBuilder(); ApiClientHeaderProvider internalHeaderProvider = @@ -142,17 +149,24 @@ public GapicSpannerRpc(SpannerOptions options) throws IOException { mergedHeaderProvider.getHeaders(), internalHeaderProviderBuilder.getResourceHeaderKey()); - // TODO(pongad): add watchdog - - // TODO(hzyi): make this channelProvider configurable through SpannerOptions + // First check if SpannerOptions provides a TransportChannerProvider. Create one + // with information gathered from SpannerOptions if none is provided TransportChannelProvider channelProvider = - InstantiatingGrpcChannelProvider - .newBuilder() - .setEndpoint(options.getEndpoint()) - .setMaxInboundMessageSize(MAX_MESSAGE_SIZE) - .setPoolSize(options.getNumChannels()) - .setInterceptorProvider(new SpannerInterceptorProvider()) - .build(); + MoreObjects.firstNonNull( + options.getChannelProvider(), + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint(options.getEndpoint()) + .setMaxInboundMessageSize(MAX_MESSAGE_SIZE) + .setPoolSize(options.getNumChannels()) + + // Then check if SpannerOptions provides an InterceptorProvider. Create a default + // SpannerInterceptorProvider if none is provided + .setInterceptorProvider( + MoreObjects.firstNonNull( + options.getInterceptorProvider(), SpannerInterceptorProvider.createDefault())) + .setHeaderProvider(mergedHeaderProvider) + .setExecutorProvider(InstantiatingExecutorProvider.newBuilder().build()) + .build()); CredentialsProvider credentialsProvider = GrpcTransportOptions.setUpCredentialsProvider(options); @@ -399,17 +413,47 @@ public void deleteSession(String sessionName, @Nullable Map options) } @Override - public ServerStream read( + public StreamingCall read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options) { GrpcCallContext context = newCallContext(options, request.getSession()); - return stub.streamingReadCallable().call(request, context); + SpannerResponseObserver responseObserver = new SpannerResponseObserver(consumer); + stub.streamingReadCallable().call(request, responseObserver, context); + final StreamController controller = responseObserver.getController(); + return new StreamingCall() { + @Override + public void request(int numMessage) { + controller.request(numMessage); + } + + // TODO(hzyi): streamController currently does not support cancel with message. Add + // this in gax and update this method later + @Override + public void cancel(String message) { + controller.cancel(); + } + }; } @Override - public ServerStream executeQuery( + public StreamingCall executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options) { GrpcCallContext context = newCallContext(options, request.getSession()); - return stub.executeStreamingSqlCallable().call(request, context); + SpannerResponseObserver responseObserver = new SpannerResponseObserver(consumer); + stub.executeStreamingSqlCallable().call(request, responseObserver, context); + final StreamController controller = responseObserver.getController(); + return new StreamingCall() { + @Override + public void request(int numMessage) { + controller.request(numMessage); + } + + // TODO(hzyi): streamController currently does not support cancel with message. Add + // this in gax and update this method later + @Override + public void cancel(String message) { + controller.cancel(); + } + }; } @Override @@ -470,4 +514,52 @@ private GrpcCallContext newCallContext(@Nullable Map options, String metadataProvider.newExtraHeaders(resource, projectName)); return context; } + + public void shutdown() { + this.stub.close(); + this.instanceStub.close(); + this.databaseStub.close(); + } + + /** + * A {@code ResponseObserver} that exposes the {@code StreamController} and delegates callbacks + * to the {@link ResultStreamConsumer}. + */ + private static class SpannerResponseObserver implements ResponseObserver { + private StreamController controller; + private ResultStreamConsumer consumer; + + public SpannerResponseObserver(ResultStreamConsumer consumer) { + this.consumer = consumer; + } + + @Override + public void onStart(StreamController controller) { + + // Disable the auto flow control to allow client library + // set the number of messages it prefers to request + controller.disableAutoInboundFlowControl(); + this.controller = controller; + } + + @Override + public void onResponse(PartialResultSet response) { + consumer.onPartialResultSet(response); + } + + @Override + public void onError(Throwable t) { + consumer.onError(SpannerExceptionFactory.newSpannerException(t)); + } + + @Override + public void onComplete() { + consumer.onCompleted(); + } + + StreamController getController() { + return Preconditions.checkNotNull(this.controller); + } + } + } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java deleted file mode 100644 index 3858a09ceedf..000000000000 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java +++ /dev/null @@ -1,647 +0,0 @@ -/* - * Copyright 2017 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.spanner.spi.v1; - -import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; - -import com.google.api.gax.core.GaxProperties; -import com.google.api.gax.grpc.GaxGrpcProperties; -import com.google.api.gax.longrunning.OperationFuture; -import com.google.api.gax.rpc.ApiClientHeaderProvider; -import com.google.api.gax.rpc.HeaderProvider; -import com.google.api.gax.rpc.ServerStream; -import com.google.api.pathtemplate.PathTemplate; -import com.google.cloud.NoCredentials; -import com.google.cloud.ServiceOptions; -import com.google.cloud.spanner.SpannerException; -import com.google.cloud.spanner.SpannerExceptionFactory; -import com.google.cloud.spanner.SpannerOptions; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; -import com.google.common.collect.ImmutableList; -import com.google.longrunning.GetOperationRequest; -import com.google.longrunning.OperationsGrpc; -import com.google.protobuf.Empty; -import com.google.protobuf.FieldMask; -import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; -import com.google.spanner.admin.database.v1.CreateDatabaseRequest; -import com.google.spanner.admin.database.v1.Database; -import com.google.spanner.admin.database.v1.DatabaseAdminGrpc; -import com.google.spanner.admin.database.v1.DropDatabaseRequest; -import com.google.spanner.admin.database.v1.GetDatabaseDdlRequest; -import com.google.spanner.admin.database.v1.GetDatabaseRequest; -import com.google.spanner.admin.database.v1.ListDatabasesRequest; -import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; -import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; -import com.google.spanner.admin.instance.v1.CreateInstanceMetadata; -import com.google.spanner.admin.instance.v1.CreateInstanceRequest; -import com.google.spanner.admin.instance.v1.DeleteInstanceRequest; -import com.google.spanner.admin.instance.v1.GetInstanceConfigRequest; -import com.google.spanner.admin.instance.v1.GetInstanceRequest; -import com.google.spanner.admin.instance.v1.Instance; -import com.google.spanner.admin.instance.v1.InstanceAdminGrpc; -import com.google.spanner.admin.instance.v1.InstanceConfig; -import com.google.spanner.admin.instance.v1.ListInstanceConfigsRequest; -import com.google.spanner.admin.instance.v1.ListInstanceConfigsResponse; -import com.google.spanner.admin.instance.v1.ListInstancesRequest; -import com.google.spanner.admin.instance.v1.ListInstancesResponse; -import com.google.spanner.admin.instance.v1.UpdateInstanceMetadata; -import com.google.spanner.admin.instance.v1.UpdateInstanceRequest; -import com.google.spanner.v1.BeginTransactionRequest; -import com.google.spanner.v1.CommitRequest; -import com.google.spanner.v1.CommitResponse; -import com.google.spanner.v1.CreateSessionRequest; -import com.google.spanner.v1.DeleteSessionRequest; -import com.google.spanner.v1.ExecuteSqlRequest; -import com.google.spanner.v1.PartialResultSet; -import com.google.spanner.v1.PartitionQueryRequest; -import com.google.spanner.v1.PartitionReadRequest; -import com.google.spanner.v1.PartitionResponse; -import com.google.spanner.v1.ReadRequest; -import com.google.spanner.v1.RollbackRequest; -import com.google.spanner.v1.Session; -import com.google.spanner.v1.SpannerGrpc; -import com.google.spanner.v1.Transaction; -import io.grpc.CallCredentials; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.ClientInterceptors; -import io.grpc.Context; -import io.grpc.ForwardingClientCall; -import io.grpc.ForwardingClientCallListener; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.ServiceDescriptor; -import io.grpc.Status; -import io.grpc.auth.MoreCallCredentials; -import io.grpc.stub.AbstractStub; -import io.grpc.stub.ClientCallStreamObserver; -import io.grpc.stub.ClientCalls; -import io.grpc.stub.ClientResponseObserver; -import io.opencensus.trace.Tracing; -import io.opencensus.trace.export.SampledSpanStore; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.annotation.Nullable; - -import com.google.longrunning.Operation; - -/** Implementation of Cloud Spanner remote calls using gRPC. */ -public class GrpcSpannerRpc implements SpannerRpc { - - static { - setupTracingConfig(); - } - - private static final Logger logger = Logger.getLogger(GrpcSpannerRpc.class.getName()); - - private static final PathTemplate PROJECT_NAME_TEMPLATE = - PathTemplate.create("projects/{project}"); - - private final Random random = new Random(); - private final List channels; - private final String projectId; - private final String projectName; - private final CallCredentials credentials; - private final SpannerMetadataProvider metadataProvider; - - public GrpcSpannerRpc(SpannerOptions options) { - this.projectId = options.getProjectId(); - this.projectName = PROJECT_NAME_TEMPLATE.instantiate("project", this.projectId); - this.credentials = callCredentials(options); - ImmutableList.Builder channelsBuilder = ImmutableList.builder(); - ImmutableList.Builder stubsBuilder = ImmutableList.builder(); - for (Channel channel : options.getRpcChannels()) { - channel = - ClientInterceptors.intercept( - channel, - new LoggingInterceptor(Level.FINER), - WatchdogInterceptor.newDefaultWatchdogInterceptor(), - new SpannerErrorInterceptor()); - channelsBuilder.add(channel); - stubsBuilder.add(withCredentials(SpannerGrpc.newFutureStub(channel), credentials)); - } - this.channels = channelsBuilder.build(); - - ApiClientHeaderProvider.Builder internalHeaderProviderBuilder = - ApiClientHeaderProvider.newBuilder(); - ApiClientHeaderProvider internalHeaderProvider = - internalHeaderProviderBuilder - .setClientLibToken( - ServiceOptions.getGoogApiClientLibName(), - GaxProperties.getLibraryVersion(options.getClass())) - .setTransportToken( - GaxGrpcProperties.getGrpcTokenName(), GaxGrpcProperties.getGrpcVersion()) - .build(); - - HeaderProvider mergedHeaderProvider = options.getMergedHeaderProvider(internalHeaderProvider); - this.metadataProvider = - SpannerMetadataProvider.create( - mergedHeaderProvider.getHeaders(), - internalHeaderProviderBuilder.getResourceHeaderKey()); - } - - private static CallCredentials callCredentials(SpannerOptions options) { - if (options.getCredentials() == null) { - return null; - } - if (options.getCredentials().equals(NoCredentials.getInstance())) { - return null; - } - return MoreCallCredentials.from(options.getScopedCredentials()); - } - - private > S withCredentials(S stub, CallCredentials credentials) { - if (credentials == null) { - return stub; - } - return stub.withCallCredentials(credentials); - } - - private String projectName() { - return projectName; - } - - @Override - public Paginated listInstanceConfigs(int pageSize, @Nullable String pageToken) - throws SpannerException { - ListInstanceConfigsRequest.Builder request = - ListInstanceConfigsRequest.newBuilder().setParent(projectName()).setPageSize(0); - if (pageToken != null) { - request.setPageToken(pageToken); - } - ListInstanceConfigsResponse response = - get( - doUnaryCall( - InstanceAdminGrpc.getListInstanceConfigsMethod(), - request.build(), - projectName(), - null)); - return new Paginated<>(response.getInstanceConfigsList(), response.getNextPageToken()); - } - - @Override - public InstanceConfig getInstanceConfig(String instanceConfigName) throws SpannerException { - GetInstanceConfigRequest request = - GetInstanceConfigRequest.newBuilder().setName(instanceConfigName).build(); - return get( - doUnaryCall(InstanceAdminGrpc.getGetInstanceConfigMethod(), request, projectName(), null)); - } - - @Override - public Paginated listInstances( - int pageSize, @Nullable String pageToken, @Nullable String filter) throws SpannerException { - ListInstancesRequest.Builder request = - ListInstancesRequest.newBuilder().setParent(projectName()).setPageSize(pageSize); - if (pageToken != null) { - request.setPageToken(pageToken); - } - if (filter != null) { - request.setFilter(filter); - } - ListInstancesResponse response = - get( - doUnaryCall( - InstanceAdminGrpc.getListInstancesMethod(), request.build(), projectName(), null)); - return new Paginated<>(response.getInstancesList(), response.getNextPageToken()); - } - - @Override - public OperationFuture createInstance( - String parent, String instanceId, Instance instance) throws SpannerException { - throw new UnsupportedOperationException("Not implemented: createInstance"); - } - - @Override - public OperationFuture updateInstance( - Instance instance, FieldMask fieldMask) throws SpannerException { - throw new UnsupportedOperationException("Not implemented: createInstance"); - } - - @Override - public Instance getInstance(String instanceName) throws SpannerException { - return get( - doUnaryCall( - InstanceAdminGrpc.getGetInstanceMethod(), - GetInstanceRequest.newBuilder().setName(instanceName).build(), - instanceName, - null)); - } - - @Override - public void deleteInstance(String instanceName) throws SpannerException { - get( - doUnaryCall( - InstanceAdminGrpc.getDeleteInstanceMethod(), - DeleteInstanceRequest.newBuilder().setName(instanceName).build(), - instanceName, - null)); - } - - @Override - public Paginated listDatabases( - String instanceName, int pageSize, @Nullable String pageToken) throws SpannerException { - ListDatabasesRequest.Builder builder = - ListDatabasesRequest.newBuilder().setParent(instanceName).setPageSize(pageSize); - if (pageToken != null) { - builder.setPageToken(pageToken); - } - com.google.spanner.admin.database.v1.ListDatabasesResponse response = - get( - doUnaryCall( - DatabaseAdminGrpc.getListDatabasesMethod(), builder.build(), instanceName, null)); - return new Paginated<>(response.getDatabasesList(), response.getNextPageToken()); - } - - @Override - public OperationFuture createDatabase( - String instanceName, String createDatabaseStatement, Iterable additionalStatements) { - throw new UnsupportedOperationException("Not Implemented: createDatabase"); - } - - @Override - public OperationFuture updateDatabaseDdl( - String databaseName, Iterable updateStatements, @Nullable String operationId) - throws SpannerException { - throw new UnsupportedOperationException("Not Implemented: updateDatabaseDdl"); - } - - @Override - public void dropDatabase(String databaseName) throws SpannerException { - get( - doUnaryCall( - DatabaseAdminGrpc.getDropDatabaseMethod(), - DropDatabaseRequest.newBuilder().setDatabase(databaseName).build(), - databaseName, - null)); - } - - @Override - public List getDatabaseDdl(String databaseName) throws SpannerException { - GetDatabaseDdlRequest request = - GetDatabaseDdlRequest.newBuilder().setDatabase(databaseName).build(); - return get(doUnaryCall(DatabaseAdminGrpc.getGetDatabaseDdlMethod(), request, databaseName, null)) - .getStatementsList(); - } - - @Override - public Database getDatabase(String databaseName) throws SpannerException { - return get( - doUnaryCall( - DatabaseAdminGrpc.getGetDatabaseMethod(), - GetDatabaseRequest.newBuilder().setName(databaseName).build(), - databaseName, - null)); - } - - @Override - public Operation getOperation(String name) throws SpannerException { - GetOperationRequest request = GetOperationRequest.newBuilder().setName(name).build(); - return get(doUnaryCall(OperationsGrpc.getGetOperationMethod(), request, name, null)); - } - - @Override - public Session createSession( - String databaseName, @Nullable Map labels, @Nullable Map options) { - CreateSessionRequest.Builder request = - CreateSessionRequest.newBuilder().setDatabase(databaseName); - if (labels != null && !labels.isEmpty()) { - Session.Builder session = Session.newBuilder().putAllLabels(labels); - request.setSession(session); - } - return get( - doUnaryCall( - SpannerGrpc.getCreateSessionMethod(), - request.build(), - databaseName, - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public void deleteSession(String sessionName, @Nullable Map options) { - DeleteSessionRequest request = DeleteSessionRequest.newBuilder().setName(sessionName).build(); - get( - doUnaryCall( - SpannerGrpc.getDeleteSessionMethod(), - request, - sessionName, - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public ServerStream read( - ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - throw new UnsupportedOperationException("Not implemented: read"); - } - - @Override - public ServerStream executeQuery( - ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - throw new UnsupportedOperationException("Not implemented: executeQuery"); - } - - @Override - public Transaction beginTransaction( - BeginTransactionRequest request, @Nullable Map options) { - return get( - doUnaryCall( - SpannerGrpc.getBeginTransactionMethod(), - request, - request.getSession(), - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public CommitResponse commit(CommitRequest commitRequest, @Nullable Map options) { - return get( - doUnaryCall( - SpannerGrpc.getCommitMethod(), - commitRequest, - commitRequest.getSession(), - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public void rollback(RollbackRequest request, @Nullable Map options) { - get( - doUnaryCall( - SpannerGrpc.getRollbackMethod(), - request, - request.getSession(), - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public PartitionResponse partitionQuery( - PartitionQueryRequest request, @Nullable Map options) - throws SpannerException { - return get( - doUnaryCall( - SpannerGrpc.getPartitionQueryMethod(), - request, - request.getSession(), - Option.CHANNEL_HINT.getLong(options))); - } - - @Override - public PartitionResponse partitionRead( - PartitionReadRequest request, @Nullable Map options) - throws SpannerException { - return get( - doUnaryCall( - SpannerGrpc.getPartitionReadMethod(), - request, - request.getSession(), - Option.CHANNEL_HINT.getLong(options))); - } - - /** Gets the result of an async RPC call, handling any exceptions encountered. */ - private static T get(final Future future) throws SpannerException { - final Context context = Context.current(); - try { - return future.get(); - } catch (InterruptedException e) { - // We are the sole consumer of the future, so cancel it. - future.cancel(true); - throw SpannerExceptionFactory.propagateInterrupt(e); - } catch (ExecutionException | CancellationException e) { - throw newSpannerException(context, e); - } - } - - private Future doUnaryCall( - MethodDescriptor method, - ReqT request, - @Nullable String resource, - @Nullable Long channelHint) { - CallOptions callOptions = - credentials == null - ? CallOptions.DEFAULT - : CallOptions.DEFAULT.withCallCredentials(credentials); - final ClientCall call = - new MetadataClientCall<>( - pick(channelHint, channels).newCall(method, callOptions), - metadataProvider.newMetadata(resource, projectName())); - return ClientCalls.futureUnaryCall(call, request); - } - - private StreamingCall doStreamingCall( - MethodDescriptor method, - T request, - ResultStreamConsumer consumer, - @Nullable String resource, - @Nullable Long channelHint) { - final Context context = Context.current(); - // TODO: Add deadline based on context. - CallOptions callOptions = - credentials == null - ? CallOptions.DEFAULT - : CallOptions.DEFAULT.withCallCredentials(credentials); - final ClientCall call = - new MetadataClientCall<>( - pick(channelHint, channels).newCall(method, callOptions), - metadataProvider.newMetadata(resource, projectName())); - ResultSetStreamObserver observer = new ResultSetStreamObserver(consumer, context, call); - ClientCalls.asyncServerStreamingCall(call, request, observer); - return observer; - } - - @VisibleForTesting - static class MetadataClientCall - extends ForwardingClientCall.SimpleForwardingClientCall { - private final Metadata extraMetadata; - - MetadataClientCall(ClientCall call, Metadata extraMetadata) { - super(call); - this.extraMetadata = extraMetadata; - } - - @Override - public void start(Listener responseListener, Metadata metadata) { - metadata.merge(extraMetadata); - super.start(responseListener, metadata); - } - } - - private T pick(@Nullable Long hint, List elements) { - long hintVal = Math.abs(hint != null ? hint : random.nextLong()); - long index = hintVal % elements.size(); - return elements.get((int) index); - } - - /** - * This is a one time setup for grpcz pages. This adds all of the methods to the Tracing - * environment required to show a consistent set of methods relating to Cloud Bigtable on the - * grpcz page. If HBase artifacts are present, this will add tracing metadata for HBase methods. - * - * TODO: Remove this when we depend on gRPC 1.8 - */ - private static void setupTracingConfig() { - SampledSpanStore store = Tracing.getExportComponent().getSampledSpanStore(); - if (store == null) { - // Tracing implementation is not linked. - return; - } - List descriptors = new ArrayList<>(); - addDescriptor(descriptors, SpannerGrpc.getServiceDescriptor()); - addDescriptor(descriptors, DatabaseAdminGrpc.getServiceDescriptor()); - addDescriptor(descriptors, InstanceAdminGrpc.getServiceDescriptor()); - store.registerSpanNamesForCollection(descriptors); - } - - /** - * Reads a list of {@link MethodDescriptor}s from a {@link ServiceDescriptor} and creates a list - * of Open Census tags. - */ - private static void addDescriptor(List descriptors, ServiceDescriptor serviceDescriptor) { - for (MethodDescriptor method : serviceDescriptor.getMethods()) { - // This is added by a grpc ClientInterceptor - descriptors.add("Sent." + method.getFullMethodName().replace('/', '.')); - } - } - - private static class ResultSetStreamObserver - implements ClientResponseObserver, StreamingCall { - private final ResultStreamConsumer consumer; - private final Context context; - private final ClientCall call; - private volatile ClientCallStreamObserver requestStream; - - public ResultSetStreamObserver( - ResultStreamConsumer consumer, Context context, ClientCall call) { - this.consumer = consumer; - this.context = context; - this.call = call; - } - - @Override - public void beforeStart(final ClientCallStreamObserver requestStream) { - this.requestStream = requestStream; - requestStream.disableAutoInboundFlowControl(); - } - - @Override - public void onNext(PartialResultSet value) { - consumer.onPartialResultSet(value); - } - - @Override - public void onError(Throwable t) { - consumer.onError(newSpannerException(context, t)); - } - - @Override - public void onCompleted() { - consumer.onCompleted(); - } - - @Override - public void request(int numMessages) { - requestStream.request(numMessages); - } - - @Override - public void cancel(@Nullable String message) { - call.cancel(message, null); - } - } - - static class LoggingInterceptor implements ClientInterceptor { - private final Level level; - - LoggingInterceptor(Level level) { - this.level = level; - } - - private class CallLogger { - private final MethodDescriptor method; - - CallLogger(MethodDescriptor method) { - this.method = method; - } - - void log(String message) { - logger.log( - level, - "{0}[{1}]: {2}", - new Object[] { - method.getFullMethodName(), - Integer.toHexString(System.identityHashCode(this)), - message - }); - } - - void logfmt(String message, Object... params) { - log(String.format(message, params)); - } - } - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - if (!logger.isLoggable(level)) { - return next.newCall(method, callOptions); - } - - final CallLogger callLogger = new CallLogger(method); - callLogger.log("Start"); - return new ForwardingClientCall.SimpleForwardingClientCall( - next.newCall(method, callOptions)) { - @Override - public void start(Listener responseListener, Metadata headers) { - super.start( - new ForwardingClientCallListener.SimpleForwardingClientCallListener( - responseListener) { - @Override - public void onMessage(RespT message) { - callLogger.logfmt("Received:\n%s", message); - super.onMessage(message); - } - - @Override - public void onClose(Status status, Metadata trailers) { - callLogger.logfmt("Closed with status %s and trailers %s", status, trailers); - super.onClose(status, trailers); - } - }, - headers); - } - - @Override - public void sendMessage(ReqT message) { - callLogger.logfmt("Send:\n%s", message); - super.sendMessage(message); - } - - @Override - public void cancel(@Nullable String message, @Nullable Throwable cause) { - callLogger.logfmt("Cancelled with message %s", message); - super.cancel(message, cause); - } - }; - } - } -} diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerInterceptorProvider.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerInterceptorProvider.java index 09e625e2107a..c51966837d00 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerInterceptorProvider.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerInterceptorProvider.java @@ -15,6 +15,7 @@ */ package com.google.cloud.spanner.spi.v1; +import com.google.api.core.InternalApi; import com.google.api.gax.grpc.GrpcInterceptorProvider; import com.google.common.collect.ImmutableList; import io.grpc.ClientInterceptor; @@ -23,21 +24,39 @@ import java.util.logging.Logger; /** - * For internal use only. - * An interceptor provider that provides a list of grpc interceptors for {@code GapicSpannerRpc} - * to handle logging and error augmentation by intercepting grpc calls. + * For internal use only. An interceptor provider that provides a list of grpc interceptors for + * {@code GapicSpannerRpc} to handle logging and error augmentation by intercepting grpc calls. */ -class SpannerInterceptorProvider implements GrpcInterceptorProvider { +@InternalApi("Exposed for testing") +public class SpannerInterceptorProvider implements GrpcInterceptorProvider { - private static final List clientInterceptors = + private static final List defaultInterceptors = ImmutableList.of( new SpannerErrorInterceptor(), - new LoggingInterceptor(Logger.getLogger(GrpcSpannerRpc.class.getName()), Level.FINER), + new LoggingInterceptor(Logger.getLogger(GapicSpannerRpc.class.getName()), Level.FINER), WatchdogInterceptor.newDefaultWatchdogInterceptor()); + private final List clientInterceptors; + + private SpannerInterceptorProvider(List clientInterceptors) { + this.clientInterceptors = clientInterceptors; + } + + public static SpannerInterceptorProvider createDefault() { + return new SpannerInterceptorProvider(defaultInterceptors); + } + + public SpannerInterceptorProvider with(ClientInterceptor clientInterceptor) { + List interceptors = + ImmutableList.builder() + .addAll(this.clientInterceptors) + .add(clientInterceptor) + .build(); + return new SpannerInterceptorProvider(interceptors); + } + @Override public List getInterceptors() { return clientInterceptors; } - } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 1e8c26a06119..5a50dbfbbbad 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -204,10 +204,10 @@ Session createSession(String databaseName, @Nullable Map labels, void deleteSession(String sessionName, @Nullable Map options) throws SpannerException; - ServerStream read( + StreamingCall read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options); - ServerStream executeQuery( + StreamingCall executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options); Transaction beginTransaction(BeginTransactionRequest request, @Nullable Map options) @@ -225,4 +225,6 @@ PartitionResponse partitionQuery( PartitionResponse partitionRead( PartitionReadRequest request, @Nullable Map options) throws SpannerException; + + public void shutdown(); } diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BatchClientImplTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BatchClientImplTest.java index a17537bfead0..15698b658781 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BatchClientImplTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BatchClientImplTest.java @@ -48,7 +48,6 @@ public final class BatchClientImplTest { private static final ByteString TXN_ID = ByteString.copyFromUtf8("my-txn"); private static final String TIMESTAMP = "2017-11-15T10:54:20Z"; - @Mock private SpannerRpc rawGrpcRpc; @Mock private SpannerRpc gapicRpc; @Mock private SpannerOptions spannerOptions; @Captor private ArgumentCaptor> optionsCaptor; @@ -60,7 +59,7 @@ public final class BatchClientImplTest { public void setUp() { initMocks(this); DatabaseId db = DatabaseId.of(DB_NAME); - SpannerImpl spanner = new SpannerImpl(rawGrpcRpc, gapicRpc, 1, spannerOptions); + SpannerImpl spanner = new SpannerImpl(gapicRpc, 1, spannerOptions); client = new BatchClientImpl(db, spanner); } @@ -72,7 +71,7 @@ public void testBatchReadOnlyTxnWithBound() throws Exception { com.google.protobuf.Timestamp timestamp = Timestamps.parse(TIMESTAMP); Transaction txnMetadata = Transaction.newBuilder().setId(TXN_ID).setReadTimestamp(timestamp).build(); - when(spannerOptions.getGapicSpannerRpc()).thenReturn(gapicRpc); + when(spannerOptions.getSpannerRpcV1()).thenReturn(gapicRpc); when(gapicRpc.beginTransaction(Mockito.any(), optionsCaptor.capture())) .thenReturn(txnMetadata); diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GceTestEnvConfig.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GceTestEnvConfig.java index 422d9b43f525..b49fb1ef5d70 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GceTestEnvConfig.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GceTestEnvConfig.java @@ -18,7 +18,7 @@ import static com.google.common.base.Preconditions.checkState; -import com.google.common.collect.ImmutableList; +import com.google.cloud.spanner.spi.v1.SpannerInterceptorProvider; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -55,10 +55,9 @@ public GceTestEnvConfig() { } options = builder - .setRpcChannelFactory( - new SpannerOptions.NettyRpcChannelFactory( - null, - ImmutableList.of(new GrpcErrorInjector(errorProbability)))) + .setInterceptorProvider( + SpannerInterceptorProvider.createDefault() + .with(new GrpcErrorInjector(errorProbability))) .build(); } diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index e0cee1bcb506..09bcfe22a0fa 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -71,7 +71,7 @@ public class SessionImplTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - SpannerImpl spanner = new SpannerImpl(rpc, rpc, 1, spannerOptions); + SpannerImpl spanner = new SpannerImpl(rpc, 1, spannerOptions); String dbName = "projects/p1/instances/i1/databases/d1"; String sessionName = dbName + "/sessions/s1"; DatabaseId db = DatabaseId.of(dbName); @@ -282,15 +282,18 @@ public void request(int numMessages) {} } private void mockRead(final PartialResultSet myResultSet) { - ServerStreamingCallable serverStreamingCallable = - new ServerStreamingStashCallable(Arrays.asList(myResultSet)); - final ServerStream mockServerStream = serverStreamingCallable.call(null); - Mockito.when( - rpc.read( - Mockito.any(), - Mockito.any(), - Mockito.eq(options))) - .thenReturn(mockServerStream); + final ArgumentCaptor consumer = + ArgumentCaptor.forClass(SpannerRpc.ResultStreamConsumer.class); + Mockito.when(rpc.read(Mockito.any(), consumer.capture(), Mockito.eq(options))) + .then( + new Answer() { + @Override + public SpannerRpc.StreamingCall answer(InvocationOnMock invocation) throws Throwable { + consumer.getValue().onPartialResultSet(myResultSet); + consumer.getValue().onCompleted(); + return new NoOpStreamingCall(); + } + }); } @Test diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java index a075a4615384..f95009f66a2a 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java @@ -43,7 +43,7 @@ public class SpannerImplTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - impl = new SpannerImpl(rpc, rpc, 1, spannerOptions); + impl = new SpannerImpl(rpc, 1, spannerOptions); } @Test diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java index 731cd1270d8d..f2d4552acdc2 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerOptionsTest.java @@ -39,14 +39,6 @@ public class SpannerOptionsTest { @Rule public ExpectedException thrown = ExpectedException.none(); - private static class TestChannelFactory implements SpannerOptions.RpcChannelFactory { - @Override - public ManagedChannel newChannel(String host, int port) { - // Disable SSL to avoid a dependency on ALPN/NPN. - return NettyChannelBuilder.forAddress(host, port).usePlaintext(true).build(); - } - } - @Test public void defaultBuilder() { // We need to set the project id since in test environment we cannot obtain a default project @@ -54,7 +46,6 @@ public void defaultBuilder() { SpannerOptions options = SpannerOptions.newBuilder() .setProjectId("test-project") - .setRpcChannelFactory(new TestChannelFactory()) .build(); assertThat(options.getHost()).isEqualTo("https://spanner.googleapis.com"); assertThat(options.getPrefetchChunks()).isEqualTo(4); @@ -69,7 +60,6 @@ public void builder() { labels.put("env", "dev"); SpannerOptions options = SpannerOptions.newBuilder() - .setRpcChannelFactory(new TestChannelFactory()) .setHost(host) .setProjectId(projectId) .setPrefetchChunks(2) diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/RequestMetadataTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/RequestMetadataTest.java deleted file mode 100644 index 3e27aa68e042..000000000000 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/RequestMetadataTest.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2017 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.cloud.spanner.spi.v1; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.doNothing; - -import com.google.cloud.spanner.spi.v1.GrpcSpannerRpc.MetadataClientCall; -import io.grpc.ClientCall; -import io.grpc.Metadata; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link GrpcSpannerRpc.MetadataClientCall}. */ -@RunWith(JUnit4.class) -public class RequestMetadataTest { - private static final Metadata.Key HEADER_KEY = - Metadata.Key.of("google-cloud-resource-prefix", Metadata.ASCII_STRING_MARSHALLER); - - private Metadata metadata; - - @Mock - private ClientCall innerCall; - @Mock - private ClientCall.Listener listener; - @Captor - private ArgumentCaptor innerMetadata; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - metadata = new Metadata(); - } - - @Test - public void metadataForwardingTest() { - doNothing() - .when(innerCall) - .start(Mockito.>any(), innerMetadata.capture()); - - Metadata in = new Metadata(); - in.put(HEADER_KEY, "TEST_HEADER"); - MetadataClientCall metadataCall = new MetadataClientCall<>(innerCall, in); - metadataCall.start(listener, metadata); - assertTrue(innerMetadata.getValue().containsKey(HEADER_KEY)); - assertEquals(innerMetadata.getValue().get(HEADER_KEY), "TEST_HEADER"); - } -} diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/SpannerMetadataProviderTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/SpannerMetadataProviderTest.java index 76f1f56c4d5a..12a9d2850ce4 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/SpannerMetadataProviderTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/SpannerMetadataProviderTest.java @@ -15,11 +15,14 @@ */ package com.google.cloud.spanner.spi.v1; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.grpc.Metadata; import io.grpc.Metadata.Key; +import java.util.List; import java.util.Map; import org.junit.Test; @@ -67,6 +70,16 @@ public void testGetResourceHeaderValue() { getResourceHeaderValue(metadataProvider, "projects/p/instances/i/databases/d/operations")); } + @Test + public void testNewExtraHeaders() { + SpannerMetadataProvider metadataProvider = + SpannerMetadataProvider.create(ImmutableMap.of(), "header1"); + Map> extraHeaders = metadataProvider.newExtraHeaders(null, "value1"); + assertThat(extraHeaders) + .containsExactlyEntriesIn( + ImmutableMap.>of("header1", ImmutableList.of("value1"))); + } + private String getResourceHeaderValue( SpannerMetadataProvider headerProvider, String resourceTokenTemplate) { Metadata metadata = headerProvider.newMetadata(resourceTokenTemplate, "projects/p");