diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
index 4ea105b3e13..0aea4c344ac 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java
@@ -34,6 +34,7 @@
import java.security.cert.PKIXBuilderParameters;
import java.security.cert.X509CertSelector;
import java.util.Arrays;
+import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.CertPathTrustManagerParameters;
@@ -82,7 +83,187 @@ public abstract class X509Util implements Closeable, AutoCloseable {
public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000;
+ /**
+ * Enum specifying the client auth requirement of server-side TLS sockets created by this X509Util.
+ *
+ * - NONE - do not request a client certificate.
+ * - WANT - request a client certificate, but allow anonymous clients to connect.
+ * - NEED - require a client certificate, disconnect anonymous clients.
+ *
+ *
+ * If the config property is not set, the default value is NEED.
+ */
+ public enum ClientAuth {
+ NONE,
+ WANT,
+ NEED;
+
+ /**
+ * Converts a property value to a ClientAuth enum. If the input string is empty or null, returns
+ * ClientAuth.NEED
.
+ * @param prop the property string.
+ * @return the ClientAuth.
+ * @throws IllegalArgumentException if the property value is not "NONE", "WANT", "NEED", or empty/null.
+ */
+ public static ClientAuth fromPropertyValue(String prop) {
+ if (prop == null || prop.length() == 0) {
+ return NEED;
+ }
+ return ClientAuth.valueOf(prop.toUpperCase());
+ }
+ }
+
+ /**
+ * Wrapper class for an SSLContext + some config options that can't be set on the context when it is created but
+ * must be set on a secure socket created by the context after the socket creation. By wrapping the options in this
+ * class we avoid reading from global system properties during socket configuration. This makes testing easier
+ * since we can create different X509Util instances with different configurations in a single test process, and
+ * unit test interactions between them.
+ */
+ public class SSLContextAndOptions {
+ private final String[] enabledProtocols;
+ private final String[] cipherSuites;
+ private final ClientAuth clientAuth;
+ private final SSLContext sslContext;
+ private final int handshakeDetectionTimeoutMillis;
+
+ /**
+ * Note: constructor is intentionally private, only the enclosing X509Util should be creating instances of this
+ * class.
+ * @param config
+ * @param sslContext
+ */
+ private SSLContextAndOptions(final ZKConfig config,
+ final SSLContext sslContext) {
+ this.sslContext = sslContext;
+ this.enabledProtocols = getEnabledProtocols(config, sslContext);
+ this.cipherSuites = getCipherSuites(config);
+ this.clientAuth = getClientAuth(config);
+ this.handshakeDetectionTimeoutMillis = getHandshakeDetectionTimeoutMillis(config);
+ }
+
+ public SSLContext getSSLContext() {
+ return sslContext;
+ }
+
+ public SSLSocket createSSLSocket() throws IOException {
+ return configureSSLSocket((SSLSocket) sslContext.getSocketFactory().createSocket(), true);
+ }
+
+ public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws IOException {
+ SSLSocket sslSocket;
+ if (pushbackBytes != null && pushbackBytes.length > 0) {
+ sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(
+ socket, new ByteArrayInputStream(pushbackBytes), true);
+ } else {
+ sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(
+ socket, null, socket.getPort(), true);
+ }
+ return configureSSLSocket(sslSocket, false);
+ }
+
+ public SSLServerSocket createSSLServerSocket() throws IOException {
+ SSLServerSocket sslServerSocket =
+ (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket();
+ return configureSSLServerSocket(sslServerSocket);
+ }
+
+ public SSLServerSocket createSSLServerSocket(int port) throws IOException {
+ SSLServerSocket sslServerSocket =
+ (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(port);
+ return configureSSLServerSocket(sslServerSocket);
+ }
+
+ private SSLSocket configureSSLSocket(SSLSocket socket, boolean isClientSocket) {
+ SSLParameters sslParameters = socket.getSSLParameters();
+ configureSslParameters(sslParameters, isClientSocket);
+ socket.setSSLParameters(sslParameters);
+ socket.setUseClientMode(isClientSocket);
+ return socket;
+ }
+
+ private SSLServerSocket configureSSLServerSocket(SSLServerSocket socket) {
+ SSLParameters sslParameters = socket.getSSLParameters();
+ configureSslParameters(sslParameters, false);
+ socket.setSSLParameters(sslParameters);
+ socket.setUseClientMode(false);
+ return socket;
+ }
+
+ private void configureSslParameters(SSLParameters sslParameters, boolean isClientSocket) {
+ if (cipherSuites != null) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Setup cipher suites for {} socket: {}",
+ isClientSocket ? "client" : "server",
+ Arrays.toString(cipherSuites));
+ }
+ sslParameters.setCipherSuites(cipherSuites);
+ }
+ if (enabledProtocols != null) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Setup enabled protocols for {} socket: {}",
+ isClientSocket ? "client" : "server",
+ Arrays.toString(enabledProtocols));
+ }
+ sslParameters.setProtocols(enabledProtocols);
+ }
+ if (!isClientSocket) {
+ switch (clientAuth) {
+ case NEED:
+ sslParameters.setNeedClientAuth(true);
+ break;
+ case WANT:
+ sslParameters.setWantClientAuth(true);
+ break;
+ default:
+ sslParameters.setNeedClientAuth(false); // also clears the wantClientAuth flag according to docs
+ break;
+ }
+ }
+ }
+
+ private String[] getEnabledProtocols(final ZKConfig config, final SSLContext sslContext) {
+ String enabledProtocolsInput = config.getProperty(getSslEnabledProtocolsProperty());
+ if (enabledProtocolsInput == null) {
+ return new String[] { sslContext.getProtocol() };
+ }
+ return enabledProtocolsInput.split(",");
+ }
+
+ private String[] getCipherSuites(final ZKConfig config) {
+ String cipherSuitesInput = config.getProperty(getSslCipherSuitesProperty());
+ if (cipherSuitesInput == null) {
+ return X509Util.getDefaultCipherSuites();
+ } else {
+ return cipherSuitesInput.split(",");
+ }
+ }
+
+ private ClientAuth getClientAuth(final ZKConfig config) {
+ return ClientAuth.fromPropertyValue(config.getProperty(getSslClientAuthProperty()));
+ }
+
+ private int getHandshakeDetectionTimeoutMillis(final ZKConfig config) {
+ String propertyString = config.getProperty(getSslHandshakeDetectionTimeoutMillisProperty());
+ int result;
+ if (propertyString == null) {
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ } else {
+ result = Integer.parseInt(propertyString);
+ if (result < 1) {
+ // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an
+ // accept() thread.
+ LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result +
+ ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS);
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ }
+ }
+ return result;
+ }
+ }
+
private String sslProtocolProperty = getConfigPrefix() + "protocol";
+ private String sslEnabledProtocolsProperty = getConfigPrefix() + "enabledProtocols";
private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites";
private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location";
private String sslKeystorePasswdProperty = getConfigPrefix() + "keyStore.password";
@@ -93,30 +274,36 @@ public abstract class X509Util implements Closeable, AutoCloseable {
private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification";
private String sslCrlEnabledProperty = getConfigPrefix() + "crl";
private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp";
+ private String sslClientAuthProperty = getConfigPrefix() + "clientAuth";
private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis";
- private String[] cipherSuites;
+ private ZKConfig zkConfig;
+ private AtomicReference defaultSSLContextAndOptions = new AtomicReference<>(null);
- private AtomicReference defaultSSLContext = new AtomicReference<>(null);
private FileChangeWatcher keyStoreFileWatcher;
private FileChangeWatcher trustStoreFileWatcher;
public X509Util() {
- String cipherSuitesInput = System.getProperty(cipherSuitesProperty);
- if (cipherSuitesInput == null) {
- cipherSuites = getDefaultCipherSuites();
- } else {
- cipherSuites = cipherSuitesInput.split(",");
- }
+ this(null);
+ }
+
+ public X509Util(ZKConfig zkConfig) {
+ this.zkConfig = zkConfig;
+ keyStoreFileWatcher = trustStoreFileWatcher = null;
}
protected abstract String getConfigPrefix();
+
protected abstract boolean shouldVerifyClientHostname();
public String getSslProtocolProperty() {
return sslProtocolProperty;
}
+ public String getSslEnabledProtocolsProperty() {
+ return sslEnabledProtocolsProperty;
+ }
+
public String getCipherSuitesProperty() {
return cipherSuitesProperty;
}
@@ -125,6 +312,10 @@ public String getSslKeystoreLocationProperty() {
return sslKeystoreLocationProperty;
}
+ public String getSslCipherSuitesProperty() {
+ return cipherSuitesProperty;
+ }
+
public String getSslKeystorePasswdProperty() {
return sslKeystorePasswdProperty;
}
@@ -157,6 +348,10 @@ public String getSslOcspEnabledProperty() {
return sslOcspEnabledProperty;
}
+ public String getSslClientAuthProperty() {
+ return sslClientAuthProperty;
+ }
+
/**
* Returns the config property key that controls the amount of time, in milliseconds, that the first
* UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT).
@@ -168,30 +363,37 @@ public String getSslHandshakeDetectionTimeoutMillisProperty() {
}
public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException {
- SSLContext result = defaultSSLContext.get();
+ return getDefaultSSLContextAndOptions().getSSLContext();
+ }
+
+ public SSLContext createSSLContext(ZKConfig config) throws SSLContextException {
+ return createSSLContextAndOptions(config).getSSLContext();
+ }
+
+ public SSLContextAndOptions getDefaultSSLContextAndOptions() throws X509Exception.SSLContextException {
+ SSLContextAndOptions result = defaultSSLContextAndOptions.get();
if (result == null) {
- result = createSSLContext();
- if (!defaultSSLContext.compareAndSet(null, result)) {
+ result = createSSLContextAndOptions();
+ if (!defaultSSLContextAndOptions.compareAndSet(null, result)) {
// lost the race, another thread already set the value
- result = defaultSSLContext.get();
+ result = defaultSSLContextAndOptions.get();
}
}
return result;
}
- private void resetDefaultSSLContext() throws X509Exception.SSLContextException {
- SSLContext newContext = createSSLContext();
- defaultSSLContext.set(newContext);
+ private void resetDefaultSSLContextAndOptions() throws X509Exception.SSLContextException {
+ SSLContextAndOptions newContext = createSSLContextAndOptions();
+ defaultSSLContextAndOptions.set(newContext);
}
- private SSLContext createSSLContext() throws SSLContextException {
+ private SSLContextAndOptions createSSLContextAndOptions() throws SSLContextException {
/*
* Since Configuration initializes the key store and trust store related
* configuration from system property. Reading property from
* configuration will be same reading from system property
*/
- ZKConfig config=new ZKConfig();
- return createSSLContext(config);
+ return createSSLContextAndOptions(zkConfig == null ? new ZKConfig() : zkConfig);
}
/**
@@ -202,24 +404,19 @@ private SSLContext createSSLContext() throws SSLContextException {
* @return the handshake detection timeout, in milliseconds.
*/
public int getSslHandshakeTimeoutMillis() {
- String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty());
- int result;
- if (propertyString == null) {
- result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
- } else {
- result = Integer.parseInt(propertyString);
- if (result < 1) {
- // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an
- // accept() thread.
- LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result +
- ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS);
- result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
- }
+ try {
+ SSLContextAndOptions ctx = getDefaultSSLContextAndOptions();
+ return ctx.handshakeDetectionTimeoutMillis;
+ } catch (SSLContextException e) {
+ LOG.error("Error creating SSL context and options", e);
+ return DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ } catch (Exception e) {
+ LOG.error("Error parsing config property " + getSslHandshakeDetectionTimeoutMillisProperty(), e);
+ return DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
}
- return result;
}
- public SSLContext createSSLContext(ZKConfig config) throws SSLContextException {
+ public SSLContextAndOptions createSSLContextAndOptions(ZKConfig config) throws SSLContextException {
KeyManager[] keyManagers = null;
TrustManager[] trustManagers = null;
@@ -269,12 +466,12 @@ public SSLContext createSSLContext(ZKConfig config) throws SSLContextException {
}
}
- String protocol = System.getProperty(sslProtocolProperty, DEFAULT_PROTOCOL);
+ String protocol = config.getProperty(sslProtocolProperty, DEFAULT_PROTOCOL);
try {
SSLContext sslContext = SSLContext.getInstance(protocol);
sslContext.init(keyManagers, trustManagers, null);
- return sslContext;
- } catch (NoSuchAlgorithmException|KeyManagementException sslContextInitException) {
+ return new SSLContextAndOptions(config, sslContext);
+ } catch (NoSuchAlgorithmException | KeyManagementException sslContextInitException) {
throw new SSLContextException(sslContextInitException);
}
}
@@ -399,64 +596,40 @@ public static X509TrustManager createTrustManager(
}
public SSLSocket createSSLSocket() throws X509Exception, IOException {
- SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket();
- configureSSLSocket(sslSocket);
- sslSocket.setUseClientMode(true);
- return sslSocket;
+ return getDefaultSSLContextAndOptions().createSSLSocket();
}
public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException {
- SSLSocket sslSocket;
- if (pushbackBytes != null && pushbackBytes.length > 0) {
- sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
- socket, new ByteArrayInputStream(pushbackBytes), true);
- } else {
- sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(
- socket, null, socket.getPort(), true);
- }
- configureSSLSocket(sslSocket);
- sslSocket.setUseClientMode(false);
- sslSocket.setNeedClientAuth(true);
- return sslSocket;
- }
-
- private void configureSSLSocket(SSLSocket sslSocket) {
- SSLParameters sslParameters = sslSocket.getSSLParameters();
- LOG.debug("Setup cipher suites for client socket: {}", Arrays.toString(cipherSuites));
- sslParameters.setCipherSuites(cipherSuites);
- sslSocket.setSSLParameters(sslParameters);
+ return getDefaultSSLContextAndOptions().createSSLSocket(socket, pushbackBytes);
}
public SSLServerSocket createSSLServerSocket() throws X509Exception, IOException {
- SSLServerSocket sslServerSocket = (SSLServerSocket) getDefaultSSLContext().getServerSocketFactory().createServerSocket();
- configureSSLServerSocket(sslServerSocket);
-
- return sslServerSocket;
+ return getDefaultSSLContextAndOptions().createSSLServerSocket();
}
public SSLServerSocket createSSLServerSocket(int port) throws X509Exception, IOException {
- SSLServerSocket sslServerSocket = (SSLServerSocket) getDefaultSSLContext().getServerSocketFactory().createServerSocket(port);
- configureSSLServerSocket(sslServerSocket);
-
- return sslServerSocket;
+ return getDefaultSSLContextAndOptions().createSSLServerSocket(port);
}
- private void configureSSLServerSocket(SSLServerSocket sslServerSocket) {
- SSLParameters sslParameters = sslServerSocket.getSSLParameters();
- sslParameters.setNeedClientAuth(true);
- LOG.debug("Setup cipher suites for server socket: {}", Arrays.toString(cipherSuites));
- sslParameters.setCipherSuites(cipherSuites);
- sslServerSocket.setSSLParameters(sslParameters);
+ static String[] getDefaultCipherSuites() {
+ return getDefaultCipherSuitesForJavaVersion(System.getProperty("java.specification.version"));
}
- private String[] getDefaultCipherSuites() {
- String javaVersion = System.getProperty("java.specification.version");
- if ("9".equals(javaVersion)) {
- LOG.debug("Using Java9-optimized cipher suites for Java version {}", javaVersion);
+ static String[] getDefaultCipherSuitesForJavaVersion(String javaVersion) {
+ Objects.requireNonNull(javaVersion);
+ if (javaVersion.matches("\\d+")) {
+ // Must be Java 9 or later
+ LOG.debug("Using Java9+ optimized cipher suites for Java version {}", javaVersion);
return DEFAULT_CIPHERS_JAVA9;
+ } else if (javaVersion.startsWith("1.")) {
+ // Must be Java 1.8 or earlier
+ LOG.debug("Using Java8 optimized cipher suites for Java version {}", javaVersion);
+ return DEFAULT_CIPHERS_JAVA8;
+ } else {
+ LOG.debug("Could not parse java version {}, using Java8 optimized cipher suites",
+ javaVersion);
+ return DEFAULT_CIPHERS_JAVA8;
}
- LOG.debug("Using Java8-optimized cipher suites for Java version {}", javaVersion);
- return DEFAULT_CIPHERS_JAVA8;
}
private FileChangeWatcher newFileChangeWatcher(String fileLocation) throws IOException {
@@ -483,7 +656,7 @@ private FileChangeWatcher newFileChangeWatcher(String fileLocation) throws IOExc
*/
public void enableCertFileReloading() throws IOException {
LOG.info("enabling cert file reloading");
- ZKConfig config = new ZKConfig();
+ ZKConfig config = zkConfig == null ? new ZKConfig() : zkConfig;
FileChangeWatcher newKeyStoreFileWatcher =
newFileChangeWatcher(config.getProperty(sslKeystoreLocationProperty));
if (newKeyStoreFileWatcher != null) {
@@ -548,7 +721,7 @@ private void handleWatchEvent(Path filePath, WatchEvent> event) {
event.kind() + " with context: " + event.context());
}
try {
- this.resetDefaultSSLContext();
+ this.resetDefaultSSLContextAndOptions();
} catch (SSLContextException e) {
throw new RuntimeException(e);
}
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
index 086c07ee812..43bc2d8e95c 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java
@@ -115,6 +115,12 @@ protected void handleBackwardCompatibility() {
}
private void putSSLProperties(X509Util x509Util) {
+ properties.put(x509Util.getSslProtocolProperty(),
+ System.getProperty(x509Util.getSslProtocolProperty()));
+ properties.put(x509Util.getSslEnabledProtocolsProperty(),
+ System.getProperty(x509Util.getSslEnabledProtocolsProperty()));
+ properties.put(x509Util.getSslCipherSuitesProperty(),
+ System.getProperty(x509Util.getSslCipherSuitesProperty()));
properties.put(x509Util.getSslKeystoreLocationProperty(),
System.getProperty(x509Util.getSslKeystoreLocationProperty()));
properties.put(x509Util.getSslKeystorePasswdProperty(),
@@ -133,6 +139,8 @@ private void putSSLProperties(X509Util x509Util) {
System.getProperty(x509Util.getSslCrlEnabledProperty()));
properties.put(x509Util.getSslOcspEnabledProperty(),
System.getProperty(x509Util.getSslOcspEnabledProperty()));
+ properties.put(x509Util.getSslClientAuthProperty(),
+ System.getProperty(x509Util.getSslClientAuthProperty()));
properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(),
System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()));
}
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
index 1058010febb..8e33c108781 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java
@@ -389,6 +389,46 @@ public void testGetSslHandshakeDetectionTimeoutMillisProperty() {
}
}
+ @Test
+ public void testGetDefaultCipherSuitesJava8() {
+ String[] cipherSuites = X509Util.getDefaultCipherSuitesForJavaVersion("1.8");
+ // Java 8 default should have the CBC suites first
+ Assert.assertTrue(cipherSuites[0].contains("CBC"));
+ }
+
+ @Test
+ public void testGetDefaultCipherSuitesJava9() {
+ String[] cipherSuites = X509Util.getDefaultCipherSuitesForJavaVersion("9");
+ // Java 9+ default should have the GCM suites first
+ Assert.assertTrue(cipherSuites[0].contains("GCM"));
+ }
+
+ @Test
+ public void testGetDefaultCipherSuitesJava10() {
+ String[] cipherSuites = X509Util.getDefaultCipherSuitesForJavaVersion("10");
+ // Java 9+ default should have the GCM suites first
+ Assert.assertTrue(cipherSuites[0].contains("GCM"));
+ }
+
+ @Test
+ public void testGetDefaultCipherSuitesJava11() {
+ String[] cipherSuites = X509Util.getDefaultCipherSuitesForJavaVersion("11");
+ // Java 9+ default should have the GCM suites first
+ Assert.assertTrue(cipherSuites[0].contains("GCM"));
+ }
+
+ @Test
+ public void testGetDefaultCipherSuitesUnknownVersion() {
+ String[] cipherSuites = X509Util.getDefaultCipherSuitesForJavaVersion("notaversion");
+ // If version can't be parsed, use the more conservative Java 8 default
+ Assert.assertTrue(cipherSuites[0].contains("CBC"));
+ }
+
+ @Test(expected = NullPointerException.class)
+ public void testGetDefaultCipherSuitesNullVersion() {
+ X509Util.getDefaultCipherSuitesForJavaVersion(null);
+ }
+
// Warning: this will reset the x509Util
private void setCustomCipherSuites() {
System.setProperty(x509Util.getCipherSuitesProperty(), customCipherSuites[0] + "," + customCipherSuites[1]);