Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Add required mutual auth to gRPC Server/Client
Browse files Browse the repository at this point in the history
Previously we had 1 sided TLS on the server side. Data between the
client and server was send over an encrypted channel, but any client
could make requests to the server.

This commit changes the behavior so that only clients with the matching
certificates can make requests to the server when TLS is enabled. This
commit does NOT add support for installing a trust manager. That must be
added in the future.
  • Loading branch information
Sid Narayan committed Jun 22, 2020
1 parent 2b4394f commit 14b22c7
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 50 deletions.
3 changes: 2 additions & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# permissions and limitations under the License.
#

localPaDir=../performance-analyzer
localPaDir=../performance-analyzer
org.gradle.jvmargs=-Xmx4096m -XX:MaxPermSize=256m
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ public class CertificateUtils {
public static final String ALIAS_CERT = "cert";
// The password is not used to encrypt keys on disk.
public static final String IN_MEMORY_PWD = "opendistro";
private static final String CERTIFICATE_FILE_PATH = "certificate-file-path";
private static final String PRIVATE_KEY_FILE_PATH = "private-key-file-path";
public static final String CERTIFICATE_FILE_PATH = "certificate-file-path";
public static final String PRIVATE_KEY_FILE_PATH = "private-key-file-path";
private static final Logger LOGGER = LogManager.getLogger(CertificateUtils.class);

public static Certificate getCertificate(final FileReader certReader) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

import com.amazon.opendistro.elasticsearch.performanceanalyzer.ConfigStatus;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.core.Util;
import com.google.common.annotations.VisibleForTesting;

import java.io.File;
import java.util.Properties;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
Expand Down Expand Up @@ -100,6 +103,11 @@ public boolean getHttpsEnabled() {
return this.httpsEnabled;
}

@VisibleForTesting
public void overrideProperty(String key, String value) {
settings.setProperty(key, value);
}

public boolean shouldCleanupMetricsDBFiles() {
return shouldCleanupMetricsDBFiles;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.amazon.opendistro.elasticsearch.performanceanalyzer.net;

import com.amazon.opendistro.elasticsearch.performanceanalyzer.CertificateUtils;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.core.Util;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc.InterNodeRpcServiceStub;
Expand All @@ -24,9 +25,12 @@
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory;

import java.io.File;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLException;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -145,12 +149,14 @@ private ManagedChannel buildInsecureChannel(final String remoteHost) {

private ManagedChannel buildSecureChannel(final String remoteHost) {
try {
File certFile = CertificateUtils.getCertificateFile();
File pkeyFile = CertificateUtils.getPrivateKeyFile();
return NettyChannelBuilder.forAddress(remoteHost, Util.RPC_PORT)
.sslContext(
GrpcSslContexts.forClient()
.trustManager(
InsecureTrustManagerFactory.INSTANCE)
.build())
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.keyManager(certFile, pkeyFile)
.build())
.build();
} catch (SSLException e) {
LOG.error("Unable to build an SSL gRPC client. Exception: {}", e.getMessage());
Expand All @@ -177,7 +183,15 @@ private void removeAllStubs() {
private void terminateAllConnections() {
for (Map.Entry<String, AtomicReference<ManagedChannel>> entry : perHostChannelMap.entrySet()) {
LOG.debug("shutting down connection to host: {}", entry.getKey());
entry.getValue().get().shutdownNow();
ManagedChannel channel = entry.getValue().get();
channel.shutdownNow();
try {
channel.awaitTermination(1, TimeUnit.MINUTES);
} catch (InterruptedException e) {
LOG.warn("Channel interrupted while shutting down", e);
channel.shutdownNow();
}

perHostChannelMap.remove(entry.getKey());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,22 @@
import com.amazon.opendistro.elasticsearch.performanceanalyzer.metrics.handler.MetricsServerHandler;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.rca.net.handler.PublishRequestHandler;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.rca.net.handler.SubscribeServerHandler;
import com.google.common.annotations.VisibleForTesting;

import io.grpc.Server;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.channel.nio.NioEventLoopGroup;
import io.grpc.netty.shaded.io.netty.channel.socket.nio.NioServerSocketChannel;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.stub.StreamObserver;

import java.io.File;
import java.io.IOException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -108,40 +117,40 @@ public void run() {
port,
numServerThreads,
useHttps);
server = useHttps ? buildHttpsServer() : buildHttpServer();
try {
server = useHttps ? buildHttpsServer(CertificateUtils.getCertificateFile(), CertificateUtils.getPrivateKeyFile())
: buildHttpServer();
server.start();
LOG.info("gRPC server started successfully!");
postStartHook();
server.awaitTermination();
LOG.info(" gRPC server terminating..");
LOG.info("gRPC server terminating..");
} catch (InterruptedException | IOException e) {
e.printStackTrace();
LOG.error("gRPC server failed to start", e);
server.shutdownNow();
shutdownHook();
}
}

private Server buildHttpServer() {
private NettyServerBuilder buildBaseServer() {
return NettyServerBuilder.forPort(port)
.addService(this)
.bossEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.workerEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.channelType(NioServerSocketChannel.class)
.executor(Executors.newSingleThreadExecutor())
.build();
.addService(this)
.bossEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.workerEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.channelType(NioServerSocketChannel.class);
}

private Server buildHttpsServer() {
return NettyServerBuilder.forPort(port)
.addService(this)
.bossEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.workerEventLoopGroup(new NioEventLoopGroup(numServerThreads))
.channelType(NioServerSocketChannel.class)
.useTransportSecurity(
CertificateUtils.getCertificateFile(),
CertificateUtils.getPrivateKeyFile())
.build();
private Server buildHttpServer() {
return buildBaseServer().executor(Executors.newSingleThreadExecutor()).build();
}

protected Server buildHttpsServer(File certFile, File pkeyFile) throws SSLException {
return buildBaseServer()
.sslContext(GrpcSslContexts.forServer(certFile, pkeyFile)
.trustManager(certFile)
.clientAuth(ClientAuth.REQUIRE)
.build())
.useTransportSecurity(certFile, pkeyFile).build();
}

/**
Expand Down Expand Up @@ -205,6 +214,7 @@ public void setMetricsHandler(MetricsServerHandler metricsServerHandler) {
* Unit test usage only.
* @return Current handler for /metrics rpc.
*/
@VisibleForTesting
public MetricsServerHandler getMetricsServerHandler() {
return metricsServerHandler;
}
Expand All @@ -213,6 +223,7 @@ public MetricsServerHandler getMetricsServerHandler() {
* Unit test usage only.
* @return Current handler for /publish rpc.
*/
@VisibleForTesting
public PublishRequestHandler getSendDataHandler() {
return sendDataHandler;
}
Expand All @@ -221,6 +232,7 @@ public PublishRequestHandler getSendDataHandler() {
* Unit test usage only.
* @return Current handler for /subscribe rpc.
*/
@VisibleForTesting
public SubscribeServerHandler getSubscribeHandler() {
return subscribeHandler;
}
Expand All @@ -234,5 +246,14 @@ public void stop() {
// Remove handlers.
sendDataHandler = null;
subscribeHandler = null;

if (server != null) {
server.shutdown();
try {
server.awaitTermination(1, TimeUnit.MINUTES);
} catch (InterruptedException e) {
server.shutdownNow();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package com.amazon.opendistro.elasticsearch.performanceanalyzer.net;

import com.amazon.opendistro.elasticsearch.performanceanalyzer.core.Util;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsRequest;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsResponse;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.rca.GradleTaskForRca;
import com.amazon.opendistro.elasticsearch.performanceanalyzer.util.WaitFor;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;

@Category(GradleTaskForRca.class)
public class GRPCTest {
private static final Logger LOG = LogManager.getLogger(GRPCTest.class);

private static NetClient netClient;
private static NetClient insecureClient;
private static TestNetServer netServer;
private static ExecutorService executorService;
private static ExecutorService netServerExecutor;
private static AtomicReference<ExecutorService> clientExecutor;
private static AtomicReference<ExecutorService> serverExecutor;
private static GRPCConnectionManager connectionManager;
private static GRPCConnectionManager insecureConnectionManager;

@BeforeClass
public static void setup() throws Exception {
try {
connectionManager = new GRPCConnectionManager(true);
netClient = new NetClient(connectionManager);
insecureConnectionManager = new GRPCConnectionManager(false);
insecureClient = new NetClient(insecureConnectionManager);
executorService = Executors.newSingleThreadExecutor();
clientExecutor = new AtomicReference<>(null);
serverExecutor = new AtomicReference<>(Executors.newSingleThreadExecutor());
netServer = new TestNetServer(Util.RPC_PORT, 1, true);
netServerExecutor = Executors.newSingleThreadExecutor();
netServerExecutor.execute(netServer);
// Wait for the TestNetServer to start
WaitFor.waitFor(() -> netServer.isRunning.get(), 10, TimeUnit.SECONDS);
if (!netServer.isRunning.get()) {
throw new RuntimeException("Unable to start TestNetServer");
}
} catch (Exception e) {
LOG.error("Failed to initialize NetTest", e);
throw e;
}
}

@AfterClass
public static void tearDown() {
executorService.shutdown();
netServerExecutor.shutdown();
netServer.stop();
netClient.stop();
insecureClient.stop();
connectionManager.shutdown();
insecureConnectionManager.shutdown();
}

@Test
public void testSecureGetMetrics() throws Exception {
MetricsRequest request = MetricsRequest.newBuilder()
.addMetricList("CPU_UTILIZATION")
.addAggList("avg")
.addDimList("ShardId")
.build();
final MetricsResponse[] response = new MetricsResponse[1];
StreamObserver<MetricsResponse> observer = new StreamObserver<MetricsResponse>() {
@Override
public void onNext(MetricsResponse value) {
LOG.info("onNext called!");
response[0] = value;
}

@Override
public void onError(Throwable t) {
LOG.error("GetMetrics observer received error from server", t);
}

@Override
public void onCompleted() {
LOG.info("GetMetrics stream completed successfully");
}
};
netClient.getMetrics("127.0.0.1", request, observer);
WaitFor.waitFor(() -> {
return response[0] != null && response[0].getMetricsResult().equals("metrics");
}, 30, TimeUnit.SECONDS);
}

/**
* Verifies that an unauthorized client should not be able to communicate with a TLS secured server
*/
@Test
public void testUnauthorizedGetMetricsFails() throws Exception {
MetricsRequest request = MetricsRequest.newBuilder()
.addMetricList("CPU_UTILIZATION")
.addAggList("avg")
.addDimList("ShardId")
.build();
final Throwable[] errors = new Throwable[1];
StreamObserver<MetricsResponse> observer = new StreamObserver<MetricsResponse>() {
@Override
public void onNext(MetricsResponse value) {
LOG.error("onNext called successfully with insecure connection");
}

@Override
public void onError(Throwable t) {
errors[0] = t;
}

@Override
public void onCompleted() {
LOG.info("GetMetrics stream completed successfully");
}
};

try {
insecureClient.getMetrics("localhost", request, observer);
} catch (RuntimeException e) {
return;
}
WaitFor.waitFor(() -> {
if (errors[0] != null) {
if (errors[0] instanceof StatusRuntimeException) {
return true;
}
throw new Exception("Wanted StatusRuntimeException, but got unexpected error: {}", errors[0]);
}
return false;
}, 30, TimeUnit.SECONDS);
}
}
Loading

0 comments on commit 14b22c7

Please sign in to comment.