diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index 6d8cacf76..9579881d0 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -32,7 +32,7 @@ jobs: java-version: 1.12 - name: Build RCA with Gradle working-directory: ./tmp/rca - run: ./gradlew build + run: ./gradlew build --stacktrace - name: Generate Jacoco coverage report working-directory: ./tmp/rca run: ./gradlew jacocoTestReport diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java index 57eee09d6..49d05a2f8 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/CertificateUtils.java @@ -21,7 +21,12 @@ import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; @@ -32,7 +37,7 @@ public class CertificateUtils { - public static final String ALIAS_PRIVATE = "private"; + public static final String ALIAS_IDENTITY = "identity"; public static final String ALIAS_CERT = "cert"; // The password is not used to encrypt keys on disk. public static final String IN_MEMORY_PWD = "opendistro"; @@ -65,14 +70,46 @@ public static PrivateKey getPrivateKey(final FileReader keyReader) throws Except public static KeyStore createKeyStore() throws Exception { String certFilePath = PluginSettings.instance().getSettingValue(CERTIFICATE_FILE_PATH); String keyFilePath = PluginSettings.instance().getSettingValue(PRIVATE_KEY_FILE_PATH); + KeyStore.ProtectionParameter protParam = new KeyStore.PasswordProtection( + CertificateUtils.IN_MEMORY_PWD.toCharArray()); PrivateKey pk = getPrivateKey(new FileReader(keyFilePath)); KeyStore ks = createEmptyStore(); Certificate certificate = getCertificate(new FileReader(certFilePath)); - ks.setCertificateEntry(ALIAS_CERT, certificate); - ks.setKeyEntry(ALIAS_PRIVATE, pk, IN_MEMORY_PWD.toCharArray(), new Certificate[] {certificate}); + ks.setEntry(ALIAS_IDENTITY, new KeyStore.PrivateKeyEntry(pk, new Certificate[]{certificate}), protParam); return ks; } + public static TrustManager[] getTrustManagers(boolean forServer) throws Exception { + // If a certificate authority is specified, create an authenticating trust manager + String certificateAuthority; + if (forServer) { + certificateAuthority = PluginSettings.instance().getSettingValue(TRUSTED_CAS_FILE_PATH); + } else { + certificateAuthority = PluginSettings.instance().getSettingValue(CLIENT_TRUSTED_CAS_FILE_PATH); + } + if (certificateAuthority != null && !certificateAuthority.isEmpty()) { + KeyStore ks = createEmptyStore(); + Certificate certificate = getCertificate(new FileReader(certificateAuthority)); + ks.setCertificateEntry(ALIAS_CERT, certificate); + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + return tmf.getTrustManagers(); + } + // Otherwise, return an all-trusting TrustManager + return new TrustManager[] { + new X509TrustManager() { + + public X509Certificate[] getAcceptedIssuers() { + return null; + } + + public void checkClientTrusted(X509Certificate[] certs, String authType) {} + + public void checkServerTrusted(X509Certificate[] certs, String authType) {} + } + }; + } + public static KeyStore createEmptyStore() throws Exception { KeyStore ks = KeyStore.getInstance("JKS"); ks.load(null, IN_MEMORY_PWD.toCharArray()); diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java index b8c8b5eea..19119d39e 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServer.java @@ -16,22 +16,25 @@ package com.amazon.opendistro.elasticsearch.performanceanalyzer; import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; +import com.google.common.annotations.VisibleForTesting; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsParameters; import com.sun.net.httpserver.HttpsServer; + import java.net.InetAddress; import java.net.InetSocketAddress; import java.security.KeyStore; import java.security.Security; -import java.security.cert.X509Certificate; import java.util.concurrent.Executors; + import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSession; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509TrustManager; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.bouncycastle.jce.provider.BouncyCastleProvider; @@ -40,8 +43,10 @@ public class PerformanceAnalyzerWebServer { private static final Logger LOG = LogManager.getLogger(PerformanceAnalyzerWebServer.class); public static final int WEBSERVICE_DEFAULT_PORT = 9600; - public static final String WEBSERVICE_PORT_CONF_NAME = "webservice-listener-port"; + @VisibleForTesting public static final String WEBSERVICE_BIND_HOST_NAME = "webservice-bind-host"; + @VisibleForTesting + public static final String WEBSERVICE_PORT_CONF_NAME = "webservice-listener-port"; // Use system default for max backlog. private static final int INCOMING_QUEUE_LENGTH = 1; @@ -66,8 +71,34 @@ public static HttpServer createInternalServer(String portFromSetting, String hos return null; } + /** + * ClientAuthConfigurator makes the server perform client authentication if the user has set up a + * certificate authority + */ + private static class ClientAuthConfigurator extends HttpsConfigurator { + public ClientAuthConfigurator(SSLContext sslContext) { + super(sslContext); + } + + @Override + public void configure(HttpsParameters params) { + final SSLParameters sslParams = getSSLContext().getDefaultSSLParameters(); + if (CertificateUtils.getTrustedCasFile() != null) { + LOG.debug("Enabling client auth"); + final SSLEngine sslEngine = getSSLContext().createSSLEngine(); + sslParams.setNeedClientAuth(true); + sslParams.setCipherSuites(sslEngine.getEnabledCipherSuites()); + sslParams.setProtocols(sslEngine.getEnabledProtocols()); + params.setSSLParameters(sslParams); + } else { + LOG.debug("Not enabling client auth"); + super.configure(params); + } + } + } + private static HttpServer createHttpsServer(int readerPort, String bindHost) throws Exception { - HttpsServer server = null; + HttpsServer server; if (bindHost != null && !bindHost.trim().isEmpty()) { LOG.info("Binding to Interface: {}", bindHost); server = @@ -81,38 +112,30 @@ private static HttpServer createHttpsServer(int readerPort, String bindHost) thr server = HttpsServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), readerPort), INCOMING_QUEUE_LENGTH); } - TrustManager[] trustAllCerts = - new TrustManager[] { - new X509TrustManager() { - - public X509Certificate[] getAcceptedIssuers() { - return null; - } - - public void checkClientTrusted(X509Certificate[] certs, String authType) {} - - public void checkServerTrusted(X509Certificate[] certs, String authType) {} - } - }; - - HostnameVerifier allHostsValid = - new HostnameVerifier() { - public boolean verify(String hostname, SSLSession session) { - return true; - } - }; - // Install the all-trusting trust manager SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); KeyStore ks = CertificateUtils.createKeyStore(); KeyManagerFactory kmf = KeyManagerFactory.getInstance("NewSunX509"); kmf.init(ks, CertificateUtils.IN_MEMORY_PWD.toCharArray()); - sslContext.init(kmf.getKeyManagers(), trustAllCerts, null); + sslContext.init(kmf.getKeyManagers(), CertificateUtils.getTrustManagers(true), null); + server.setHttpsConfigurator(new ClientAuthConfigurator(sslContext)); + + + // TODO ask ktkrg why this is necessary + // Try to set HttpsURLConnection defaults, our webserver can still run even if this block fails + try { + LOG.debug("Setting default SSLSocketFactory..."); + HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); + LOG.debug("Default SSLSocketFactory set successfully"); + HostnameVerifier allHostsValid = (hostname, session) -> true; + LOG.debug("Setting default HostnameVerifier..."); + HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid); + LOG.debug("Default HostnameVerifier set successfully"); + } catch (Exception e) { // Usually AccessControlException + LOG.warn("Exception while trying to set URLConnection defaults", e); + } - HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); - HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid); - server.setHttpsConfigurator(new HttpsConfigurator(sslContext)); return server; } diff --git a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java index 87bc842f6..0b5f9d036 100644 --- a/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java +++ b/src/main/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/config/PluginSettings.java @@ -103,6 +103,11 @@ public boolean getHttpsEnabled() { return this.httpsEnabled; } + @VisibleForTesting + public void setHttpsEnabled(boolean httpsEnabled) { + this.httpsEnabled = httpsEnabled; + } + @VisibleForTesting public void overrideProperty(String key, String value) { settings.setProperty(key, value); diff --git a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java new file mode 100644 index 000000000..2d0f818b7 --- /dev/null +++ b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/PerformanceAnalyzerWebServerTest.java @@ -0,0 +1,317 @@ +package com.amazon.opendistro.elasticsearch.performanceanalyzer; + +import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; +import com.sun.net.httpserver.HttpServer; + +import io.grpc.netty.shaded.io.netty.handler.codec.http.HttpResponseStatus; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.security.KeyStore; +import java.util.Objects; +import java.util.concurrent.Executors; + +import javax.annotation.Nullable; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSocketFactory; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class PerformanceAnalyzerWebServerTest { + private static final String BIND_HOST = "localhost"; + private static final String PORT = "11021"; + private static final String MESSAGE = "hello"; + + private String oldBindHost; + private String oldPort; + private String oldCertificateFilePath; + private String oldPrivateKeyFilePath; + private String oldTrustedCasFilePath; + private String oldClientCertificateFilePath; + private String oldClientPrivateKeyFilePath; + private String oldClientTrustedCasFilePath; + private boolean oldHttpsEnabled; + + private HttpServer server; + + @Before + public void setup() { + // Save old PluginSettings values + oldBindHost = PluginSettings.instance().getSettingValue(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME); + oldPort = PluginSettings.instance().getSettingValue(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME); + oldCertificateFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CERTIFICATE_FILE_PATH); + oldPrivateKeyFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.PRIVATE_KEY_FILE_PATH); + oldTrustedCasFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.TRUSTED_CAS_FILE_PATH); + oldClientCertificateFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_CERTIFICATE_FILE_PATH); + oldClientPrivateKeyFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_PRIVATE_KEY_FILE_PATH); + oldClientTrustedCasFilePath = PluginSettings.instance().getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH); + oldHttpsEnabled = PluginSettings.instance().getHttpsEnabled(); + // Update bind host, port, and server certs for the test + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME, BIND_HOST); + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME, PORT); + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/server/localhost.crt")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/server/localhost.key")).getFile()); + } + + @After + public void tearDown() { + // Unset all SSL settings + if (oldBindHost != null) { + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_BIND_HOST_NAME, oldBindHost); + } + if (oldPort != null) { + PluginSettings.instance().overrideProperty(PerformanceAnalyzerWebServer.WEBSERVICE_PORT_CONF_NAME, oldPort); + } + if (oldCertificateFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, oldCertificateFilePath); + } + if (oldPrivateKeyFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, oldPrivateKeyFilePath); + } + if (oldTrustedCasFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, oldTrustedCasFilePath); + } + if (oldClientCertificateFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_CERTIFICATE_FILE_PATH, oldClientCertificateFilePath); + } + if (oldClientPrivateKeyFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_PRIVATE_KEY_FILE_PATH, oldClientPrivateKeyFilePath); + } + if (oldClientTrustedCasFilePath != null) { + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, oldClientTrustedCasFilePath); + } + PluginSettings.instance().setHttpsEnabled(oldHttpsEnabled); + + // Stop the server + if (server != null) { + server.stop(0); + } + } + + public void initializeServer(boolean useHttps) { + PluginSettings.instance().setHttpsEnabled(useHttps); + server = PerformanceAnalyzerWebServer.createInternalServer(PORT, BIND_HOST, useHttps); + Assert.assertNotNull(server); + server.setExecutor(Executors.newFixedThreadPool(1)); + // Setup basic /test endpoint. When the server receives any request on /test, it responds with "hello" + server.createContext("/test", exchange -> { + exchange.getRequestBody().close(); + exchange.sendResponseHeaders(HttpResponseStatus.OK.code(), 0); + OutputStream response = exchange.getResponseBody(); + response.write(MESSAGE.getBytes()); + response.close(); + }); + server.start(); + } + + /** + * Issues a basic HTTP GET request to $BIND_HOST:$PORT/test and verifies that the response says "hello" + */ + public void verifyRequest(String urlString) throws Exception { + // Build the request + URL url = new URL(urlString); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Content-Type", "application/json"); + // Issue the request to the server + int status = connection.getResponseCode(); + BufferedReader streamReader; + if (status > 299) { + streamReader = new BufferedReader(new InputStreamReader(connection.getErrorStream())); + } else { + streamReader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + } + // Read response & verify contents + String inputLine; + StringBuilder content = new StringBuilder(); + while ((inputLine = streamReader.readLine()) != null) { + content.append(inputLine); + } + streamReader.close(); + Assert.assertEquals(MESSAGE, content.toString()); + } + + /** + * Issues a basic HTTPS GET request to $BIND_HOST:$PORT/test and verifies that the response says "hello" + */ + public void verifyHttpsRequest(String urlString, String clientCert, String clientKey, String clientCA) + throws Exception { + // Build the request + URL url = new URL(urlString); + HttpsURLConnection connection = (HttpsURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setSSLSocketFactory(createSSLSocketFactory(clientCert, clientKey, clientCA)); + // Issue the request to the server + int status = connection.getResponseCode(); + BufferedReader streamReader; + if (status > 299) { + streamReader = new BufferedReader(new InputStreamReader(connection.getErrorStream())); + } else { + streamReader = new BufferedReader(new InputStreamReader(connection.getInputStream())); + } + // Read response & verify contents + String inputLine; + StringBuilder content = new StringBuilder(); + while ((inputLine = streamReader.readLine()) != null) { + content.append(inputLine); + } + streamReader.close(); + Assert.assertEquals(MESSAGE, content.toString()); + } + + /** + * testHttpServer verifies that any client can issue HTTP requests to the {@link PerformanceAnalyzerWebServer} + * when TLS is disabled + */ + @Test + public void testHttpServer() throws Exception { + // Start the HTTP server + initializeServer(false); + verifyRequest(String.format("http://%s:%s/test", BIND_HOST, PORT)); + } + + /** + * Verifies that the server accepts any client's requests when HTTPS is enabled but Auth is disabled. + */ + @Test + public void testNoAuthHttps() throws Exception { + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, ""); + initializeServer(true); + verifyRequest(String.format("https://%s:%s/test", BIND_HOST, PORT)); + } + + /** + * Utility method to create an {@link SSLSocketFactory} + * @param clientCert Client identity certificate + * @param clientKey Private key for the client certificate + * @param clientCA Client certificate authority (used to verify server identity) + * Set this to null if you don't want to authenticate the server + * @return An {@link SSLSocketFactory} configured based on the given params + * @throws Exception If something goes wrong with SSL setup + */ + public SSLSocketFactory createSSLSocketFactory(String clientCert, String clientKey, @Nullable String clientCA) throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings instance = PluginSettings.instance(); + // Save previous settings + String certFile = instance.getSettingValue(CertificateUtils.CERTIFICATE_FILE_PATH); + String pKey = instance.getSettingValue(CertificateUtils.PRIVATE_KEY_FILE_PATH); + String rootCA = instance.getSettingValue(CertificateUtils.TRUSTED_CAS_FILE_PATH); + String prevClientCA = instance.getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH); + // Override client identity settings + instance.overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + Objects.requireNonNull(classLoader.getResource(clientCert)).getFile()); + instance.overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + Objects.requireNonNull(classLoader.getResource(clientKey)).getFile()); + instance.overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, ""); + if (clientCA != null) { + instance.getSettingValue(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, clientCA); + } + // Setup SSLContext for the client + SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); + KeyStore ks = CertificateUtils.createKeyStore(); + KeyManagerFactory kmf = KeyManagerFactory.getInstance("NewSunX509"); + kmf.init(ks, CertificateUtils.IN_MEMORY_PWD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), CertificateUtils.getTrustManagers(false), null); + // Restore previous settings + instance.overrideProperty(CertificateUtils.CERTIFICATE_FILE_PATH, + certFile); + instance.overrideProperty(CertificateUtils.PRIVATE_KEY_FILE_PATH, + pKey); + instance.overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, rootCA); + if (prevClientCA == null) { + instance.overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, ""); + } + return sslContext.getSocketFactory(); + } + + /** + * Verifies that the HTTPS server responds to an authenticated client's requests. + */ + @Test + public void testAuthenticatedClientGetsResponse() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/RootCA.pem"); + } + + /** + * Verifies that the HTTPS server doesn't respond to an unauthenticated client's requests. + */ + @Test + public void testUnauthenticatedClientGetsRejected() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + try { + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/attacker/attack_cert.pem", "tls/attacker/attack_key.pem", + "tls/rootca/RootCA.pem"); + throw new AssertionError("An unauthenticated client was able to talk to the server"); + } catch (SSLException e) { // Unauthenticated client is rejected! + assert true; + } catch (Exception e) { // Treat unexpected errors as a failure + throw new AssertionError("Received unexpected error when making unauthed REST call to server", e); + } + } + + /** + * Verifies that a client properly authenticates a trusted server + */ + @Test + public void testClientAuth() throws Exception { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + initializeServer(true); + // Build the request + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/RootCA.pem"); + } + + /** + * Verifies that a client doesn't authenticate a server with unknown credentials + */ + @Test + public void testThatClientRejectsUntrustedServer() { + ClassLoader classLoader = getClass().getClassLoader(); + PluginSettings.instance().overrideProperty(CertificateUtils.TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/RootCA.pem")).getFile()); + // Setup client CA that doesn't trust the server's identity + PluginSettings.instance().overrideProperty(CertificateUtils.CLIENT_TRUSTED_CAS_FILE_PATH, + Objects.requireNonNull(classLoader.getResource("tls/rootca/root2ca.pem")).getFile()); + initializeServer(true); + // Build the request + try { + verifyHttpsRequest(String.format("https://%s:%s/test", BIND_HOST, PORT), + "tls/client/localhost.crt", "tls/client/localhost.key", "tls/rootca/root2ca.pem"); + throw new AssertionError("The client accepted a response from an untrusted server"); + } catch (SSLException e) { // Unauthenticated server is rejected! + assert true; + } catch (Exception e) { // Treat unexpected errors as a failure + throw new AssertionError("Received unexpected error in testThatClientRejectsUntrustedServer", e); + } + } +} diff --git a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java index 002857cf8..7d785065a 100644 --- a/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java +++ b/src/test/java/com/amazon/opendistro/elasticsearch/performanceanalyzer/net/GRPCTest.java @@ -2,19 +2,16 @@ import com.amazon.opendistro.elasticsearch.performanceanalyzer.CertificateUtils; import com.amazon.opendistro.elasticsearch.performanceanalyzer.config.PluginSettings; -import com.amazon.opendistro.elasticsearch.performanceanalyzer.core.Util; import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsRequest; import com.amazon.opendistro.elasticsearch.performanceanalyzer.grpc.MetricsResponse; import com.amazon.opendistro.elasticsearch.performanceanalyzer.rca.GradleTaskForRca; import com.amazon.opendistro.elasticsearch.performanceanalyzer.util.WaitFor; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; - import java.util.Objects; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; - import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.junit.AfterClass;