diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java index 3ca787fc34d..f8b275a0186 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java @@ -17,10 +17,9 @@ */ package org.apache.zookeeper.common; - -import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.net.Socket; import java.nio.file.Path; import java.nio.file.Paths; @@ -33,15 +32,14 @@ import java.security.Security; import java.security.cert.PKIXBuilderParameters; import java.security.cert.X509CertSelector; -import java.util.Arrays; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.net.ssl.CertPathTrustManagerParameters; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLSocket; import javax.net.ssl.TrustManager; @@ -137,6 +135,7 @@ public static ClientAuth fromPropertyValue(String prop) { private String sslTruststoreLocationProperty = getConfigPrefix() + "trustStore.location"; private String sslTruststorePasswdProperty = getConfigPrefix() + "trustStore.password"; private String sslTruststoreTypeProperty = getConfigPrefix() + "trustStore.type"; + private String sslContextSupplierClassProperty = getConfigPrefix() + "context.supplier.class"; private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification"; private String sslCrlEnabledProperty = getConfigPrefix() + "crl"; private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp"; @@ -202,6 +201,10 @@ public String getSslTruststoreTypeProperty() { return sslTruststoreTypeProperty; } + public String getSslContextSupplierClassProperty() { + return sslContextSupplierClassProperty; + } + public String getSslHostnameVerificationEnabledProperty() { return sslHostnameVerificationEnabledProperty; } @@ -282,7 +285,28 @@ public int getSslHandshakeTimeoutMillis() { } } + @SuppressWarnings("unchecked") public SSLContextAndOptions createSSLContextAndOptions(ZKConfig config) throws SSLContextException { + final String supplierContextClassName = config.getProperty(sslContextSupplierClassProperty); + if (supplierContextClassName != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Loading SSLContext supplier from property '{}'", sslContextSupplierClassProperty); + } + try { + Class sslContextClass = Class.forName(supplierContextClassName); + Supplier sslContextSupplier = (Supplier) sslContextClass.getConstructor().newInstance(); + return new SSLContextAndOptions(this, config, sslContextSupplier.get()); + } catch (ClassNotFoundException | ClassCastException | NoSuchMethodException | InvocationTargetException | + InstantiationException | IllegalAccessException e) { + throw new SSLContextException("Could not retrieve the SSLContext from supplier source '" + supplierContextClassName + + "' provided in the property '" + sslContextSupplierClassProperty + "'", e); + } + } else { + return createSSLContextAndOptionsFromConfig(config); + } + } + + public SSLContextAndOptions createSSLContextAndOptionsFromConfig(ZKConfig config) throws SSLContextException { KeyManager[] keyManagers = null; TrustManager[] trustManagers = null; diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java index 43bc2d8e95c..76bdd2e2072 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java @@ -133,6 +133,8 @@ private void putSSLProperties(X509Util x509Util) { System.getProperty(x509Util.getSslTruststorePasswdProperty())); properties.put(x509Util.getSslTruststoreTypeProperty(), System.getProperty(x509Util.getSslTruststoreTypeProperty())); + properties.put(x509Util.getSslContextSupplierClassProperty(), + System.getProperty(x509Util.getSslContextSupplierClassProperty())); properties.put(x509Util.getSslHostnameVerificationEnabledProperty(), System.getProperty(x509Util.getSslHostnameVerificationEnabledProperty())); properties.put(x509Util.getSslCrlEnabledProperty(), diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java index 2a6bb3246f5..1fecd808de8 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java @@ -22,6 +22,7 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; +import java.security.NoSuchAlgorithmException; import java.security.Security; import java.util.Collection; import java.util.concurrent.Callable; @@ -30,6 +31,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.net.ssl.HandshakeCompletedEvent; import javax.net.ssl.HandshakeCompletedListener; @@ -403,6 +405,23 @@ public void testGetSslHandshakeDetectionTimeoutMillisProperty() { } } + @Test(expected = X509Exception.SSLContextException.class) + public void testCreateSSLContext_invalidCustomSSLContextClass() throws Exception { + ZKConfig zkConfig = new ZKConfig(); + ClientX509Util clientX509Util = new ClientX509Util(); + zkConfig.setProperty(clientX509Util.getSslContextSupplierClassProperty(), String.class.getCanonicalName()); + clientX509Util.createSSLContext(zkConfig); + } + + @Test + public void testCreateSSLContext_validCustomSSLContextClass() throws Exception { + ZKConfig zkConfig = new ZKConfig(); + ClientX509Util clientX509Util = new ClientX509Util(); + zkConfig.setProperty(clientX509Util.getSslContextSupplierClassProperty(), SslContextSupplier.class.getName()); + final SSLContext sslContext = clientX509Util.createSSLContext(zkConfig); + Assert.assertEquals(SSLContext.getDefault(), sslContext); + } + private static void forceClose(Socket s) { if (s == null || s.isClosed()) { return; @@ -528,4 +547,18 @@ private void setCustomCipherSuites() { x509Util.close(); // remember to close old instance before replacing it x509Util = new ClientX509Util(); } + + public static class SslContextSupplier implements Supplier { + + @Override + public SSLContext get() { + try { + return SSLContext.getDefault(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + } + }