From 9152a23af091670ddf35e767fbbd3d53b1dab9c7 Mon Sep 17 00:00:00 2001 From: Sid Narayan Date: Mon, 27 Jul 2020 10:37:17 -0700 Subject: [PATCH] Rest mutual auth fix (#279) * Add required mutual auth to gRPC Server/Client Previously we had 1 sided TLS on the server side. Data between the client and server was send over an encrypted channel, but any client could make requests to the server. This commit changes the behavior so that only clients with the matching certificates can make requests to the server when TLS is enabled. This commit does NOT add support for installing a trust manager. That must be added in the future. * Add full mutual auth to gRPC client/server and augment tests * Implement mutual auth for the REST endpoints This commit makes the PerformanceAnalyzerWebServer authenticate clients if the user specifies a certificate authority. It also properly sets up the server's identity, so that any clients can authenticate the server. * Fix merge issue with WireHopperTest * Fixing up PerformanceAnalyzerWebServerTest * Modify gradle.yml for testing * Remove info log flooding from testing gradle.yml --- .github/workflows/gradle.yml | 2 +- .../performanceanalyzer/CertificateUtils.java | 43 ++- .../PerformanceAnalyzerWebServer.java | 85 +++-- .../config/PluginSettings.java | 5 + .../PerformanceAnalyzerWebServerTest.java | 317 ++++++++++++++++++ .../performanceanalyzer/net/GRPCTest.java | 3 - 6 files changed, 417 insertions(+), 38 deletions(-) create mode 100644 src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index 6d8cacf76..9579881d0 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -32,7 +32,7 @@ jobs: java-version: 1.12 - name: Build RCA with Gradle working-directory: ./tmp/rca - run: ./gradlew build + run: ./gradlew build --stacktrace - name: Generate Jacoco coverage report working-directory: ./tmp/rca run: ./gradlew jacocoTestReport diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java index 57eee09d6..49d05a2f8 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java @@ -21,7 +21,12 @@ import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; @@ -32,7 +37,7 @@ public class CertificateUtils { - public static final String ALIAS_PRIVATE = "private"; + public static final String ALIAS_IDENTITY = "identity"; public static final String ALIAS_CERT = "cert"; // The password is not used to encrypt keys on disk. public static final String IN_MEMORY_PWD = "opendistro"; @@ -65,14 +70,46 @@ public static PrivateKey getPrivateKey(final FileReader keyReader) throws Except public static KeyStore createKeyStore() throws Exception { String certFilePath = PluginSettings.instance().getSettingValue(CERTIFICATE_FILE_PATH); String keyFilePath = PluginSettings.instance().getSettingValue(PRIVATE_KEY_FILE_PATH); + KeyStore.ProtectionParameter protParam = new KeyStore.PasswordProtection( + CertificateUtils.IN_MEMORY_PWD.toCharArray()); PrivateKey pk = getPrivateKey(new FileReader(keyFilePath)); KeyStore ks = createEmptyStore(); Certificate certificate = getCertificate(new FileReader(certFilePath)); - ks.setCertificateEntry(ALIAS_CERT, certificate); - ks.setKeyEntry(ALIAS_PRIVATE, pk, IN_MEMORY_PWD.toCharArray(), new Certificate[] {certificate}); + ks.setEntry(ALIAS_IDENTITY, new KeyStore.PrivateKeyEntry(pk, new Certificate[]{certificate}), protParam); return ks; } + public static TrustManager[] getTrustManagers(boolean forServer) throws Exception { + // If a certificate authority is specified, create an authenticating trust manager + String certificateAuthority; + if (forServer) { + certificateAuthority = PluginSettings.instance().getSettingValue(TRUSTED_CAS_FILE_PATH); + } else { + certificateAuthority = PluginSettings.instance().getSettingValue(CLIENT_TRUSTED_CAS_FILE_PATH); + } + if (certificateAuthority != null && !certificateAuthority.isEmpty()) { + KeyStore ks = createEmptyStore(); + Certificate certificate = getCertificate(new FileReader(certificateAuthority)); + ks.setCertificateEntry(ALIAS_CERT, certificate); + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + return tmf.getTrustManagers(); + } + // Otherwise, return an all-trusting TrustManager + return new TrustManager[] { + new X509TrustManager() { + + public X509Certificate[] getAcceptedIssuers() { + return null; + } + + public void checkClientTrusted(X509Certificate[] certs, String authType) {} + + public void checkServerTrusted(X509Certificate[] certs, String authType) {} + } + }; + } + public static KeyStore createEmptyStore() throws Exception { KeyStore ks = KeyStore.getInstance("JKS"); ks.load(null, IN_MEMORY_PWD.toCharArray()); diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java index b8c8b5eea..19119d39e 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java @@ -16,22 +16,25 @@ package com.amazon.opendistro.elasticsearch.performanceanalyzer; import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; +import com.google.common.annotations.VisibleForTesting; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsParameters; import com.sun.net.httpserver.HttpsServer; + import java.net.InetAddress; import java.net.InetSocketAddress; import java.security.KeyStore; import java.security.Security; -import java.security.cert.X509Certificate; import java.util.concurrent.Executors; + import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSession; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509TrustManager; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.bouncycastle.jce.provider.BouncyCastleProvider; @@ -40,8 +43,10 @@ public class PerformanceAnalyzerWebServer { private static final Logger LOG = LogManager.getLogger(PerformanceAnalyzerWebServer.class); public static final int WEBSERVICE_DEFAULT_PORT = 9600; - public static final String WEBSERVICE_PORT_CONF_NAME = "webservice-listener-port"; + @VisibleForTesting public static final String WEBSERVICE_BIND_HOST_NAME = "webservice-bind-host"; + @VisibleForTesting + public static final String WEBSERVICE_PORT_CONF_NAME = "webservice-listener-port"; // Use system default for max backlog. private static final int INCOMING_QUEUE_LENGTH = 1; @@ -66,8 +71,34 @@ public static HttpServer createInternalServer(String portFromSetting, String hos return null; } + /** + * ClientAuthConfigurator makes the server perform client authentication if the user has set up a + * certificate authority + */ + private static class ClientAuthConfigurator extends HttpsConfigurator { + public ClientAuthConfigurator(SSLContext sslContext) { + super(sslContext); + } + + @Override + public void configure(HttpsParameters params) { + final SSLParameters sslParams = getSSLContext().getDefaultSSLParameters(); + if (CertificateUtils.getTrustedCasFile() != null) { + LOG.debug("Enabling client auth"); + final SSLEngine sslEngine = getSSLContext().createSSLEngine(); + sslParams.setNeedClientAuth(true); + sslParams.setCipherSuites(sslEngine.getEnabledCipherSuites()); + sslParams.setProtocols(sslEngine.getEnabledProtocols()); + params.setSSLParameters(sslParams); + } else { + LOG.debug("Not enabling client auth"); + super.configure(params); + } + } + } + private static HttpServer createHttpsServer(int readerPort, String bindHost) throws Exception { - HttpsServer server = null; + HttpsServer server; if (bindHost != null && !bindHost.trim().isEmpty()) { LOG.info("Binding to Interface: {}", bindHost); server = @@ -81,38 +112,30 @@ private static HttpServer createHttpsServer(int readerPort, String bindHost) thr server = HttpsServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), readerPort), INCOMING_QUEUE_LENGTH); } - TrustManager[] trustAllCerts = - new TrustManager[] { - new X509TrustManager() { - - public X509Certificate[] getAcceptedIssuers() { - return null; - } - - public void checkClientTrusted(X509Certificate[] certs, String authType) {} - - public void checkServerTrusted(X509Certificate[] certs, String authType) {} - } - }; - - HostnameVerifier allHostsValid = - new HostnameVerifier() { - public boolean verify(String hostname, SSLSession session) { - return true; - } - }; - // Install the all-trusting trust manager SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); KeyStore ks = CertificateUtils.createKeyStore(); KeyManagerFactory kmf = KeyManagerFactory.getInstance("NewSunX509"); kmf.init(ks, CertificateUtils.IN_MEMORY_PWD.toCharArray()); - sslContext.init(kmf.getKeyManagers(), trustAllCerts, null); + sslContext.init(kmf.getKeyManagers(), CertificateUtils.getTrustManagers(true), null); + server.setHttpsConfigurator(new ClientAuthConfigurator(sslContext)); + + + // TODO ask ktkrg why this is necessary + // Try to set HttpsURLConnection defaults, our webserver can still run even if this block fails + try { + LOG.debug("Setting default SSLSocketFactory..."); + HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); + LOG.debug("Default SSLSocketFactory set successfully"); + HostnameVerifier allHostsValid = (hostname, session) -> true; + LOG.debug("Setting default HostnameVerifier..."); + HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid); + LOG.debug("Default HostnameVerifier set successfully"); + } catch (Exception e) { // Usually AccessControlException + LOG.warn("Exception while trying to set URLConnection defaults", e); + } - HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); - HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid); - server.setHttpsConfigurator(new HttpsConfigurator(sslContext)); return server; } diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java index 87bc842f6..0b5f9d036 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java @@ -103,6 +103,11 @@ public boolean getHttpsEnabled() { return this.httpsEnabled; } + @VisibleForTesting + public void setHttpsEnabled(boolean httpsEnabled) { + this.httpsEnabled = httpsEnabled; + } + @VisibleForTesting public void overrideProperty(String key, String value) { settings.setProperty(key, value); diff --git a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java new file mode 100644 index 000000000..2d0f818b7 --- /dev/null +++ b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java @@ -0,0 +1,317 @@ +package com.amazon.opendistro.elasticsearch.performanceanalyzer; + +import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; +import com.sun.net.httpserver.HttpServer; + +import io.grpc.netty.shaded.io.netty.handler.codec.http.HttpResponseStatus; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.security.KeyStore; +import java.util.Objects; +import java.util.concurrent.Executors; + +import javax.annotation.Nullable; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSocketFactory; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class PerformanceAnalyzerWebServerTest { + private static final String BIND_HOST = "localhost"; + private static final String PORT = "11021"; + private static final String MESSAGE = "hello"; + + private String oldBindHost; + private String oldPort; + private String oldCertificateFilePath; + private String oldPrivateKeyFilePath; + private String oldTrustedCasFilePath; + private String oldClientCertificateFilePath; + private String oldClientPrivateKeyFilePath; + private String oldClientTrustedCasFilePath; + private boolean oldHttpsEnabled; + + private HttpServer server; + + @Before + public void setup() { + // Save old PluginSettings values + oldBindHost = PluginSettings.instance().getSettingValue(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME); + oldPort = PluginSettings.instance().getSettingValue(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME); + oldCertificateFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CERTIFICATE_FILE_PATH); + oldPrivateKeyFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.PRIVATE_KEY_FILE_PATH); + oldTrustedCasFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.TRUSTED_CAS_FILE_PATH); + oldClientCertificateFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_CERTIFICATE_FILE_PATH); + oldClientPrivateKeyFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_PRIVATE_KEY_FILE_PATH); + oldClientTrustedCasFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH); + oldHttpsEnabled = PluginSettings.instance().getHttpsEnabled(); + // Update bind host, port, and server certs for the test + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME, BIND_HOST); + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME, PORT); + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/server/localhost.crt")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/server/localhost.key")).getFile()); + } + + @After + public void tearDown() { + // Unset all SSL settings + if (oldBindHost != null) { + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME, oldBindHost); + } + if (oldPort != null) { + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME, oldPort); + } + if (oldCertificateFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, oldCertificateFilePath); + } + if (oldPrivateKeyFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, oldPrivateKeyFilePath); + } + if (oldTrustedCasFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, oldTrustedCasFilePath); + } + if (oldClientCertificateFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_CERTIFICATE_FILE_PATH, oldClientCertificateFilePath); + } + if (oldClientPrivateKeyFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_PRIVATE_KEY_FILE_PATH, oldClientPrivateKeyFilePath); + } + if (oldClientTrustedCasFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, oldClientTrustedCasFilePath); + } + PluginSettings.instance().setHttpsEnabled(oldHttpsEnabled); + + // Stop the server + if (server != null) { + server.stop(0); + } + } + + public void initializeServer(boolean useHttps) { + PluginSettings.instance().setHttpsEnabled(useHttps); + server = PerformanceAnalyzerWebServer.createInternalServer(PORT, BIND_HOST, useHttps); + Assert.assertNotNull(server); + server.setExecutor(Executors.newFixedThreadPool(1)); + // Setup basic /test endpoint. When the server receives any request on /test, it responds with "hello" + server.createContext("/test", exchange -> { + exchange.getRequestBody().close(); + exchange.sendResponseHeaders(HttpResponseStatus.OK.code(), 0); + OutputStream response = exchange.getResponseBody(); + response.write(MESSAGE.getBytes()); + response.close(); + }); + server.start(); + } + + /** + * Issues a basic HTTP GET request to $BIND_HOST:$PORT/test and verifies that the response says "hello" + */ + public void verifyRequest(String urlString) throws Exception { + // Build the request + URL url = new URL(urlString); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Content-Type", "application/json"); + // Issue the request to the server + int status = connection.getResponseCode(); + BufferedReader streamReader; + if (status > 299) { + streamReader = new BufferedReader(new InputStreamReader(connection.getErrorStream())); + } else { + streamReader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + } + // Read response & verify contents + String inputLine; + StringBuilder content = new StringBuilder(); + while ((inputLine = streamReader.readLine()) != null) { + content.append(inputLine); + } + streamReader.close(); + Assert.assertEquals(MESSAGE, content.toString()); + } + + /** + * Issues a basic HTTPS GET request to $BIND_HOST:$PORT/test and verifies that the response says "hello" + */ + public void verifyHttpsRequest(String urlString, String clientCert, String clientKey, String clientCA) + throws Exception { + // Build the request + URL url = new URL(urlString); + HttpsURLConnection connection = (HttpsURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setSSLSocketFactory(createSSLSocketFactory(clientCert, clientKey, clientCA)); + // Issue the request to the server + int status = connection.getResponseCode(); + BufferedReader streamReader; + if (status > 299) { + streamReader = new BufferedReader(new InputStreamReader(connection.getErrorStream())); + } else { + streamReader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + } + // Read response & verify contents + String inputLine; + StringBuilder content = new StringBuilder(); + while ((inputLine = streamReader.readLine()) != null) { + content.append(inputLine); + } + streamReader.close(); + Assert.assertEquals(MESSAGE, content.toString()); + } + + /** + * testHttpServer verifies that any client can issue HTTP requests to the {@link PerformanceAnalyzerWebServer} + * when TLS is disabled + */ + @Test + public void testHttpServer() throws Exception { + // Start the HTTP server + initializeServer(false); + verifyRequest(String.format("http://%s:%s/test", BIND_HOST, PORT)); + } + + /** + * Verifies that the server accepts any client's requests when HTTPS is enabled but Auth is disabled. + */ + @Test + public void testNoAuthHttps() throws Exception { + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, ""); + initializeServer(true); + verifyRequest(String.format("https://%s:%s/test", BIND_HOST, PORT)); + } + + /** + * Utility method to create an {@link SSLSocketFactory} + * @param clientCert Client identity certificate + * @param clientKey Private key for the client certificate + * @param clientCA Client certificate authority (used to verify server identity) + * Set this to null if you don't want to authenticate the server + * @return An {@link SSLSocketFactory} configured based on the given params + * @throws Exception If something goes wrong with SSL setup + */ + public SSLSocketFactory createSSLSocketFactory(String clientCert, String clientKey, @Nullable String clientCA) throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings instance = PluginSettings.instance(); + // Save previous settings + String certFile = instance.getSettingValue(CertificateUtils.CERTIFICATE_FILE_PATH); + String pKey = instance.getSettingValue(CertificateUtils.PRIVATE_KEY_FILE_PATH); + String rootCA = instance.getSettingValue(CertificateUtils.TRUSTED_CAS_FILE_PATH); + String prevClientCA = instance.getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH); + // Override client identity settings + instance.overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + Objects.requireNonNull(classLoader.getResource(clientCert)).getFile()); + instance.overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + Objects.requireNonNull(classLoader.getResource(clientKey)).getFile()); + instance.overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, ""); + if (clientCA != null) { + instance.getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, clientCA); + } + // Setup SSLContext for the client + SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); + KeyStore ks = CertificateUtils.createKeyStore(); + KeyManagerFactory kmf = KeyManagerFactory.getInstance("NewSunX509"); + kmf.init(ks, CertificateUtils.IN_MEMORY_PWD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), CertificateUtils.getTrustManagers(false), null); + // Restore previous settings + instance.overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + certFile); + instance.overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + pKey); + instance.overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, rootCA); + if (prevClientCA == null) { + instance.overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, ""); + } + return sslContext.getSocketFactory(); + } + + /** + * Verifies that the HTTPS server responds to an authenticated client's requests. + */ + @Test + public void testAuthenticatedClientGetsResponse() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/RootCA.pem"); + } + + /** + * Verifies that the HTTPS server doesn't respond to an unauthenticated client's requests. + */ + @Test + public void testUnauthenticatedClientGetsRejected() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + try { + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/attacker/attack_cert.pem", "tls/attacker/attack_key.pem", + "tls/rootca/RootCA.pem"); + throw new AssertionError("An unauthenticated client was able to talk to the server"); + } catch (SSLException e) { // Unauthenticated client is rejected! + assert true; + } catch (Exception e) { // Treat unexpected errors as a failure + throw new AssertionError("Received unexpected error when making unauthed REST call to server", e); + } + } + + /** + * Verifies that a client properly authenticates a trusted server + */ + @Test + public void testClientAuth() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/RootCA.pem"); + } + + /** + * Verifies that a client doesn't authenticate a server with unknown credentials + */ + @Test + public void testThatClientRejectsUntrustedServer() { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + // Setup client CA that doesn't trust the server's identity + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/root2ca.pem")).getFile()); + initializeServer(true); + // Build the request + try { + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/root2ca.pem"); + throw new AssertionError("The client accepted a response from an untrusted server"); + } catch (SSLException e) { // Unauthenticated server is rejected! + assert true; + } catch (Exception e) { // Treat unexpected errors as a failure + throw new AssertionError("Received unexpected error in testThatClientRejectsUntrustedServer", e); + } + } +} diff --git a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java index 002857cf8..7d785065a 100644 --- a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java +++ b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java @@ -2,19 +2,16 @@ import com.amazon.opendistro.elasticsearch.performanceanalyzer.CertificateUtils; import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; -import com.amazon.opendistro.elasticsearch.performanceanalyzer.core.Util; import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsRequest; import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsResponse; import com.amazon.opendistro.elasticsearch.performanceanalyzer.rca.GradleTaskForRca; import com.amazon.opendistro.elasticsearch.performanceanalyzer.util.WaitFor; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; - import java.util.Objects; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; - import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.AfterClass;