Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Add full mutual auth to gRPC client/server and augment tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sid Narayan committed Jul 1, 2020
1 parent 14b22c7 commit 4b13b62
Show file tree
Hide file tree
Showing 26 changed files with 547 additions and 183 deletions.
3 changes: 1 addition & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,4 @@
# permissions and limitations under the License.
#

localPaDir=../performance-analyzer
org.gradle.jvmargs=-Xmx4096m -XX:MaxPermSize=256m
localPaDir=../performance-analyzer
9 changes: 7 additions & 2 deletions pa_config/performance-analyzer.properties
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ metrics-db-file-prefix-path = /tmp/metricsdb_

https-enabled = false

#Setup the correct path for certificates
# Setup the correct path for server certificates
certificate-file-path = specify_path

private-key-file-path = specify_path
trusted-cas-file-path = specify_path

# Setup the correct path for client certificates (by default, the client will just use the server certificates)
#client-certificate-file-path = specify_path
#client-private-key-file-path = specify_path
#client-trusted-cas-file-path = specify_path

# WebService bind host; default only to local interface
#webservice-bind-host =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import javax.annotation.Nullable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
Expand All @@ -37,6 +38,12 @@ public class CertificateUtils {
public static final String IN_MEMORY_PWD = "opendistro";
public static final String CERTIFICATE_FILE_PATH = "certificate-file-path";
public static final String PRIVATE_KEY_FILE_PATH = "private-key-file-path";
public static final String TRUSTED_CAS_FILE_PATH = "trusted-cas-file-path";
public static final String CLIENT_PREFIX = "client-";
public static final String CLIENT_CERTIFICATE_FILE_PATH = CLIENT_PREFIX + CERTIFICATE_FILE_PATH;
public static final String CLIENT_PRIVATE_KEY_FILE_PATH = CLIENT_PREFIX + PRIVATE_KEY_FILE_PATH;
public static final String CLIENT_TRUSTED_CAS_FILE_PATH = CLIENT_PREFIX + TRUSTED_CAS_FILE_PATH;

private static final Logger LOGGER = LogManager.getLogger(CertificateUtils.class);

public static Certificate getCertificate(final FileReader certReader) throws Exception {
Expand Down Expand Up @@ -81,4 +88,39 @@ public static File getPrivateKeyFile() {
String privateKeyPath = PluginSettings.instance().getSettingValue(PRIVATE_KEY_FILE_PATH);
return new File(privateKeyPath);
}

@Nullable
public static File getTrustedCasFile() {
String trustedCasPath = PluginSettings.instance().getSettingValue(TRUSTED_CAS_FILE_PATH);
if (trustedCasPath == null || trustedCasPath.isEmpty()) {
return null;
}
return new File(trustedCasPath);
}

public static File getClientCertificateFile() {
String certFilePath = PluginSettings.instance().getSettingValue(CLIENT_CERTIFICATE_FILE_PATH);
if (certFilePath == null || certFilePath.isEmpty()) {
return getCertificateFile();
}
return new File(certFilePath);
}

public static File getClientPrivateKeyFile() {
String privateKeyPath = PluginSettings.instance().getSettingValue(CLIENT_PRIVATE_KEY_FILE_PATH);
if (privateKeyPath == null || privateKeyPath.isEmpty()) {
return getPrivateKeyFile();
}
return new File(privateKeyPath);
}

@Nullable
public static File getClientTrustedCasFile() {
String trustedCasPath = PluginSettings.instance().getSettingValue(CLIENT_TRUSTED_CAS_FILE_PATH);
// By default, use the same CA as the server
if (trustedCasPath == null || trustedCasPath.isEmpty()) {
return getTrustedCasFile();
}
return new File(trustedCasPath);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;

import java.io.File;
import java.util.Map;
Expand All @@ -33,6 +33,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -46,6 +47,11 @@
public class GRPCConnectionManager {

private static final Logger LOG = LogManager.getLogger(GRPCConnectionManager.class);
private final int port;
// TLS certificate, private key, and trusted root CA files
private File certFile;
private File pkeyFile;
private File trustedCasFile;

/**
* Map of remote host to a Netty channel to that host.
Expand All @@ -67,6 +73,27 @@ public class GRPCConnectionManager {

public GRPCConnectionManager(final boolean shouldUseHttps) {
this.shouldUseHttps = shouldUseHttps;
this.port = Util.RPC_PORT;
if (shouldUseHttps) {
this.certFile = CertificateUtils.getClientCertificateFile();
this.pkeyFile = CertificateUtils.getClientPrivateKeyFile();
this.trustedCasFile = CertificateUtils.getClientTrustedCasFile();
}
}

/**
* Constructor that allows you to specify which port a client should connect to
* @param shouldUseHttps Whether to enable TLS
* @param port The port number that client stubs should attempt to connect to
*/
public GRPCConnectionManager(final boolean shouldUseHttps, int port) {
this.shouldUseHttps = shouldUseHttps;
this.port = port;
if (shouldUseHttps) {
this.certFile = CertificateUtils.getClientCertificateFile();
this.pkeyFile = CertificateUtils.getClientPrivateKeyFile();
this.trustedCasFile = CertificateUtils.getClientTrustedCasFile();
}
}

@VisibleForTesting
Expand Down Expand Up @@ -144,28 +171,26 @@ private ManagedChannel buildChannelForHost(final String remoteHost) {
}

private ManagedChannel buildInsecureChannel(final String remoteHost) {
return ManagedChannelBuilder.forAddress(remoteHost, Util.RPC_PORT).usePlaintext().build();
return ManagedChannelBuilder.forAddress(remoteHost, this.port).usePlaintext().build();
}

private ManagedChannel buildSecureChannel(final String remoteHost) {
try {
File certFile = CertificateUtils.getCertificateFile();
File pkeyFile = CertificateUtils.getPrivateKeyFile();
return NettyChannelBuilder.forAddress(remoteHost, Util.RPC_PORT)
.sslContext(
GrpcSslContexts.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.keyManager(certFile, pkeyFile)
.build())
.build();
} catch (SSLException e) {
LOG.error("Unable to build an SSL gRPC client. Exception: {}", e.getMessage());
e.printStackTrace();

// Wrap the SSL Exception in a generic RTE and re-throw.
throw new RuntimeException(e);
}
}
private ManagedChannel buildSecureChannel(final String remoteHost) {
try {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient().keyManager(certFile, pkeyFile);
if (trustedCasFile != null) {
sslContextBuilder.trustManager(trustedCasFile);
}
return NettyChannelBuilder.forAddress(remoteHost, this.port)
.sslContext(sslContextBuilder.build())
.build();
} catch (SSLException e) {
LOG.error("Unable to build an SSL gRPC client. Exception: {}", e.getMessage());
e.printStackTrace();

// Wrap the SSL Exception in a generic RTE and re-throw.
throw new RuntimeException(e);
}
}

private InterNodeRpcServiceStub buildStubForHost(
final String remoteHost) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ public NetClient(final GRPCConnectionManager connectionManager) {
this.connectionManager = connectionManager;
}

public GRPCConnectionManager getConnectionManager() {
return connectionManager;
}

private ConcurrentMap<String, AtomicReference<StreamObserver<FlowUnitMessage>>> perHostOpenDataStreamMap =
new ConcurrentHashMap<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.grpc.netty.shaded.io.netty.channel.nio.NioEventLoopGroup;
import io.grpc.netty.shaded.io.netty.channel.socket.nio.NioServerSocketChannel;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.stub.StreamObserver;

import java.io.File;
Expand Down Expand Up @@ -86,7 +87,7 @@ public class NetServer extends InterNodeRpcServiceGrpc.InterNodeRpcServiceImplBa
/**
* The server instance.
*/
private Server server;
protected Server server;

public NetServer(final int port, final int numServerThreads, final boolean useHttps) {
this.port = port;
Expand Down Expand Up @@ -118,8 +119,12 @@ public void run() {
numServerThreads,
useHttps);
try {
server = useHttps ? buildHttpsServer(CertificateUtils.getCertificateFile(), CertificateUtils.getPrivateKeyFile())
: buildHttpServer();
if (useHttps) {
server = buildHttpsServer(CertificateUtils.getTrustedCasFile(), CertificateUtils.getCertificateFile(),
CertificateUtils.getPrivateKeyFile());
} else {
server = buildHttpServer();
}
server.start();
LOG.info("gRPC server started successfully!");
postStartHook();
Expand All @@ -144,13 +149,15 @@ private Server buildHttpServer() {
return buildBaseServer().executor(Executors.newSingleThreadExecutor()).build();
}

protected Server buildHttpsServer(File certFile, File pkeyFile) throws SSLException {
protected Server buildHttpsServer(File trustedCasFile, File certFile, File pkeyFile) throws SSLException {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(certFile, pkeyFile);
// If an authority is specified, authenticate clients
if (trustedCasFile != null) {
sslContextBuilder.trustManager(trustedCasFile).clientAuth(ClientAuth.REQUIRE);
}
return buildBaseServer()
.sslContext(GrpcSslContexts.forServer(certFile, pkeyFile)
.trustManager(certFile)
.clientAuth(ClientAuth.REQUIRE)
.build())
.useTransportSecurity(certFile, pkeyFile).build();
.sslContext(sslContextBuilder.build())
.build();
}

/**
Expand Down Expand Up @@ -246,7 +253,11 @@ public void stop() {
// Remove handlers.
sendDataHandler = null;
subscribeHandler = null;
}

public void shutdown() {
stop();
// Actually stop the server
if (server != null) {
server.shutdown();
try {
Expand Down
Loading

0 comments on commit 4b13b62

Please sign in to comment.