Skip to content

Commit

Permalink
feat(spanner): mTLS setup for spanner external host clients
Browse files Browse the repository at this point in the history
  • Loading branch information
sagnghos committed Jan 2, 2025
1 parent 1af8e46 commit e158e0f
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.gax.grpc.GrpcInterceptorProvider;
import com.google.api.gax.grpc.GrpcTransportChannel;
import com.google.api.gax.longrunning.OperationTimedPollAlgorithm;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.FixedTransportChannelProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.tracing.ApiTracerFactory;
import com.google.api.gax.tracing.BaseApiTracerFactory;
Expand Down Expand Up @@ -69,13 +71,19 @@
import io.grpc.CompressorRegistry;
import io.grpc.Context;
import io.grpc.ExperimentalApi;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.common.Attributes;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.time.Duration;
import java.util.ArrayList;
Expand All @@ -90,6 +98,8 @@
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
Expand Down Expand Up @@ -942,6 +952,7 @@ public static class Builder
private CloseableExecutorProvider asyncExecutorProvider;
private String compressorName;
private String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST");
private ManagedChannel managedChannel;
private boolean leaderAwareRoutingEnabled = true;
private boolean attemptDirectPath = true;
private DirectedReadOptions directedReadOptions;
Expand Down Expand Up @@ -1485,6 +1496,28 @@ public Builder setEmulatorHost(String emulatorHost) {
return this;
}

public Builder useClientCert(String host, String clientCertificate, String clientKey) {
try {
URI uri = new URI(host);
managedChannel =
NettyChannelBuilder.forAddress(uri.getHost(), uri.getPort())
.sslContext(
GrpcSslContexts.forClient()
.keyManager(new File(clientCertificate), new File(clientKey))
.build())
.build();

setChannelProvider(
FixedTransportChannelProvider.create(GrpcTransportChannel.create(managedChannel)));
} catch (URISyntaxException e) {
throw new IllegalArgumentException(
"Invalid host format. Expected format: 'protocol://host[:port]'.", e);
} catch (Exception e) {
throw new RuntimeException("Unexpected error during mTLS setup.", e);
}
return this;
}

/**
* Sets OpenTelemetry object to be used for Spanner Metrics and Traces. GlobalOpenTelemetry will
* be used as fallback if this options is not set.
Expand Down Expand Up @@ -1593,6 +1626,23 @@ public SpannerOptions build() {
this.setChannelConfigurator(ManagedChannelBuilder::usePlaintext);
// As we are using plain text, we should never send any credentials.
this.setCredentials(NoCredentials.getInstance());
} else if (managedChannel != null) {
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
final Logger logger = Logger.getLogger(SpannerOptions.class.getName());
try {
managedChannel.shutdown();
logger.log(
Level.INFO, "[SpannerOptions] ManagedChannel shut down successfully.");
} catch (Exception e) {
logger.log(
Level.WARNING,
"[SpannerOptions] Failed to shut down ManagedChannel.",
e);
}
}));
}
if (this.numChannels == null) {
this.numChannels =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import static com.google.cloud.spanner.connection.ConnectionProperties.AUTO_CONFIG_EMULATOR;
import static com.google.cloud.spanner.connection.ConnectionProperties.AUTO_PARTITION_MODE;
import static com.google.cloud.spanner.connection.ConnectionProperties.CHANNEL_PROVIDER;
import static com.google.cloud.spanner.connection.ConnectionProperties.CLIENT_CERTIFICATE;
import static com.google.cloud.spanner.connection.ConnectionProperties.CLIENT_KEY;
import static com.google.cloud.spanner.connection.ConnectionProperties.CREDENTIALS_PROVIDER;
import static com.google.cloud.spanner.connection.ConnectionProperties.CREDENTIALS_URL;
import static com.google.cloud.spanner.connection.ConnectionProperties.DATABASE_ROLE;
Expand Down Expand Up @@ -225,6 +227,8 @@ public String[] getValidValues() {
static final boolean DEFAULT_USE_VIRTUAL_THREADS = false;
static final boolean DEFAULT_USE_VIRTUAL_GRPC_TRANSPORT_THREADS = false;
static final String DEFAULT_CREDENTIALS = null;
static final String DEFAULT_CLIENT_CERTIFICATE = null;
static final String DEFAULT_CLIENT_KEY = null;
static final String DEFAULT_OAUTH_TOKEN = null;
static final Integer DEFAULT_MIN_SESSIONS = null;
static final Integer DEFAULT_MAX_SESSIONS = null;
Expand Down Expand Up @@ -263,6 +267,10 @@ public String[] getValidValues() {
private static final String DEFAULT_EMULATOR_HOST = "http://localhost:9010";
/** Use plain text is only for local testing purposes. */
static final String USE_PLAIN_TEXT_PROPERTY_NAME = "usePlainText";
/** Client certificate path to establish mTLS */
static final String CLIENT_CERTIFICATE_PROPERTY_NAME = "clientCertificate";
/** Client key path to establish mTLS */
static final String CLIENT_KEY_PROPERTY_NAME = "clientKey";
/** Name of the 'autocommit' connection property. */
public static final String AUTOCOMMIT_PROPERTY_NAME = "autocommit";
/** Name of the 'readonly' connection property. */
Expand Down Expand Up @@ -434,6 +442,12 @@ static boolean isEnableTransactionalConnectionStateForPostgreSQL() {
USE_PLAIN_TEXT_PROPERTY_NAME,
"Use a plain text communication channel (i.e. non-TLS) for communicating with the server (true/false). Set this value to true for communication with the Cloud Spanner emulator.",
DEFAULT_USE_PLAIN_TEXT),
ConnectionProperty.createStringProperty(
CLIENT_CERTIFICATE_PROPERTY_NAME,
"Specifies the file path to the client certificate required for establishing an mTLS connection."),
ConnectionProperty.createStringProperty(
CLIENT_KEY_PROPERTY_NAME,
"Specifies the file path to the client private key required for establishing an mTLS connection."),
ConnectionProperty.createStringProperty(
USER_AGENT_PROPERTY_NAME,
"The custom user-agent property name to use when communicating with Cloud Spanner. This property is intended for internal library usage, and should not be set by applications."),
Expand Down Expand Up @@ -828,6 +842,7 @@ public static Builder newBuilder() {
private final Credentials fixedCredentials;

private final String host;
private boolean isExternalHost;
private final String projectId;
private final String instanceId;
private final String databaseName;
Expand All @@ -841,10 +856,10 @@ public static Builder newBuilder() {

private ConnectionOptions(Builder builder) {
Matcher matcher;
boolean isExternalHost = false;
this.isExternalHost = false;
if (builder.isValidExternalHostUri(builder.uri)) {
matcher = Builder.EXTERNAL_HOST_PATTERN.matcher(builder.uri);
isExternalHost = true;
this.isExternalHost = true;
} else {
matcher = Builder.SPANNER_URI_PATTERN.matcher(builder.uri);
}
Expand Down Expand Up @@ -967,7 +982,7 @@ && getInitialConnectionPropertyValue(OAUTH_TOKEN) == null

String projectId = "default";
String instanceId = matcher.group(Builder.INSTANCE_GROUP);
if (!isExternalHost) {
if (!this.isExternalHost) {
projectId = matcher.group(Builder.PROJECT_GROUP);
} else if (instanceId == null) {
instanceId = "default";
Expand Down Expand Up @@ -1291,6 +1306,14 @@ boolean isUsePlainText() {
|| getInitialConnectionPropertyValue(USE_PLAIN_TEXT);
}

String getClientCertificate() {
return getInitialConnectionPropertyValue(CLIENT_CERTIFICATE);
}

String getClientCertificateKey() {
return getInitialConnectionPropertyValue(CLIENT_KEY);
}

/**
* The (custom) user agent string to use for this connection. If <code>null</code>, then the
* default JDBC user agent string will be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import static com.google.cloud.spanner.connection.ConnectionOptions.AUTO_BATCH_DML_UPDATE_COUNT_VERIFICATION_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.AUTO_PARTITION_MODE_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.CHANNEL_PROVIDER_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.CLIENT_CERTIFICATE_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.CLIENT_KEY_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.CREDENTIALS_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.CREDENTIALS_PROVIDER_PROPERTY_NAME;
import static com.google.cloud.spanner.connection.ConnectionOptions.DATABASE_ROLE_PROPERTY_NAME;
Expand All @@ -33,6 +35,8 @@
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_AUTO_BATCH_DML_UPDATE_COUNT_VERIFICATION;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_AUTO_PARTITION_MODE;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_CHANNEL_PROVIDER;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_CLIENT_CERTIFICATE;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_CLIENT_KEY;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_CREDENTIALS;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_DATABASE_ROLE;
import static com.google.cloud.spanner.connection.ConnectionOptions.DEFAULT_DATA_BOOST_ENABLED;
Expand Down Expand Up @@ -192,6 +196,20 @@ public class ConnectionProperties {
BooleanConverter.INSTANCE,
Context.STARTUP);

static final ConnectionProperty<String> CLIENT_CERTIFICATE =
create(
CLIENT_CERTIFICATE_PROPERTY_NAME,
"Specifies the file path to the client certificate required for establishing an mTLS connection.",
DEFAULT_CLIENT_CERTIFICATE,
StringValueConverter.INSTANCE,
Context.STARTUP);
static final ConnectionProperty<String> CLIENT_KEY =
create(
CLIENT_KEY_PROPERTY_NAME,
"Specifies the file path to the client private key required for establishing an mTLS connection.",
DEFAULT_CLIENT_KEY,
StringValueConverter.INSTANCE,
Context.STARTUP);
static final ConnectionProperty<String> CREDENTIALS_URL =
create(
CREDENTIALS_PROPERTY_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) {
// Set a custom channel configurator to allow http instead of https.
builder.setChannelConfigurator(ManagedChannelBuilder::usePlaintext);
}
if (options.getClientCertificate() != null && options.getClientCertificateKey() != null) {
builder.useClientCert(
options.getHost(), options.getClientCertificate(), options.getClientCertificateKey());
}
if (options.getConfigurator() != null) {
options.getConfigurator().configure(builder);
}
Expand Down

0 comments on commit e158e0f

Please sign in to comment.