diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java index 8dc5243b360..91f858cb4f2 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java @@ -83,7 +83,200 @@ public abstract class X509Util { public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000; + /** + * Enum specifying the client auth requirement of server-side TLS sockets created by this X509Util. + * + * + * If the config property is not set, the default value is NEED. + */ + public enum ClientAuth { + NONE, + WANT, + NEED; + + /** + * Converts a property value to a ClientAuth enum. If the input string is empty or null, returns + * ClientAuth.NEED. + * @param prop the property string. + * @return the ClientAuth. + * @throws IllegalArgumentException if the property value is not "NONE", "WANT", "NEED", or empty/null. + */ + public static ClientAuth fromPropertyValue(String prop) { + if (prop == null || prop.length() == 0) { + return NEED; + } + return ClientAuth.valueOf(prop.toUpperCase()); + } + } + + /** + * Wrapper class for an SSLContext + some config options that can't be set on the context when it is created but + * must be set on a secure socket created by the context after the socket creation. By wrapping the options in this + * class we avoid reading from global system properties during socket configuration. This makes testing easier + * since we can create different X509Util instances with different configurations in a single test process, and + * unit test interactions between them. + */ + public class SSLContextAndOptions { + private final String[] enabledProtocols; + private final String[] cipherSuites; + private final ClientAuth clientAuth; + private final SSLContext sslContext; + private final int handshakeDetectionTimeoutMillis; + + /** + * Note: constructor is intentionally private, only the enclosing X509Util should be creating instances of this + * class. + * @param config + * @param sslContext + */ + private SSLContextAndOptions(final ZKConfig config, + final SSLContext sslContext) { + this.sslContext = sslContext; + this.enabledProtocols = getEnabledProtocols(config, sslContext); + this.cipherSuites = getCipherSuites(config); + this.clientAuth = getClientAuth(config); + this.handshakeDetectionTimeoutMillis = getHandshakeDetectionTimeoutMillis(config); + } + + public SSLContext getSSLContext() { + return sslContext; + } + + public SSLSocket createSSLSocket() throws IOException { + return configureSSLSocket((SSLSocket) sslContext.getSocketFactory().createSocket(), true); + } + + public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws IOException { + SSLSocket sslSocket; + if (pushbackBytes != null && pushbackBytes.length > 0) { + sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket( + socket, new ByteArrayInputStream(pushbackBytes), true); + } else { + sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket( + socket, null, socket.getPort(), true); + } + return configureSSLSocket(sslSocket, false); + } + + public SSLServerSocket createSSLServerSocket() throws IOException { + SSLServerSocket sslServerSocket = + (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(); + return configureSSLServerSocket(sslServerSocket); + } + + public SSLServerSocket createSSLServerSocket(int port) throws IOException { + SSLServerSocket sslServerSocket = + (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(port); + return configureSSLServerSocket(sslServerSocket); + } + + private SSLSocket configureSSLSocket(SSLSocket socket, boolean isClientSocket) { + SSLParameters sslParameters = socket.getSSLParameters(); + configureSslParameters(sslParameters, isClientSocket); + socket.setSSLParameters(sslParameters); + socket.setUseClientMode(isClientSocket); + return socket; + } + + private SSLServerSocket configureSSLServerSocket(SSLServerSocket socket) { + SSLParameters sslParameters = socket.getSSLParameters(); + configureSslParameters(sslParameters, false); + socket.setSSLParameters(sslParameters); + socket.setUseClientMode(false); + return socket; + } + + private void configureSslParameters(SSLParameters sslParameters, boolean isClientSocket) { + if (cipherSuites != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Setup cipher suites for {} socket: {}", + isClientSocket ? "client" : "server", + Arrays.toString(cipherSuites)); + } + sslParameters.setCipherSuites(cipherSuites); + } + if (enabledProtocols != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Setup enabled protocols for {} socket: {}", + isClientSocket ? "client" : "server", + Arrays.toString(enabledProtocols)); + } + sslParameters.setProtocols(enabledProtocols); + } + if (!isClientSocket) { + switch (clientAuth) { + case NEED: + sslParameters.setNeedClientAuth(true); + break; + case WANT: + sslParameters.setWantClientAuth(true); + break; + default: + sslParameters.setNeedClientAuth(false); // also clears the wantClientAuth flag according to docs + break; + } + } + } + + private String[] getEnabledProtocols(final ZKConfig config, final SSLContext sslContext) { + String enabledProtocolsInput = config.getProperty(getSslEnabledProtocolsProperty()); + if (enabledProtocolsInput == null) { + return new String[] { sslContext.getProtocol() }; + } + return enabledProtocolsInput.split(","); + } + + private String[] getCipherSuites(final ZKConfig config) { + String cipherSuitesInput = config.getProperty(getSslCipherSuitesProperty()); + if (cipherSuitesInput == null) { + return getDefaultCipherSuites(); + } else { + return cipherSuitesInput.split(","); + } + } + + private String[] getDefaultCipherSuites() { + String javaVersion = System.getProperty("java.specification.version"); + if (javaVersion.startsWith("1.")) { + // Must be Java 1.8 or earlier + LOG.debug("Using Java8-optimized cipher suites for Java version {}", javaVersion); + return DEFAULT_CIPHERS_JAVA8; + } else { + // Must be Java 9 or later + LOG.debug("Using Java9-optimized cipher suites for Java version {}", javaVersion); + return DEFAULT_CIPHERS_JAVA9; + } + } + + private ClientAuth getClientAuth(final ZKConfig config) { + return ClientAuth.fromPropertyValue(config.getProperty(getSslClientAuthProperty())); + } + + private int getHandshakeDetectionTimeoutMillis(final ZKConfig config) { + String propertyString = config.getProperty(getSslHandshakeDetectionTimeoutMillisProperty()); + int result; + if (propertyString == null) { + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } else { + result = Integer.parseInt(propertyString); + if (result < 1) { + // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an + // accept() thread. + LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result + + ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS); + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } + } + return result; + } + } + private String sslProtocolProperty = getConfigPrefix() + "protocol"; + private String sslEnabledProtocolsProperty = getConfigPrefix() + "enabledProtocols"; private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites"; private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location"; private String sslKeystorePasswdProperty = getConfigPrefix() + "keyStore.password"; @@ -94,30 +287,36 @@ public abstract class X509Util { private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification"; private String sslCrlEnabledProperty = getConfigPrefix() + "crl"; private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp"; + private String sslClientAuthProperty = getConfigPrefix() + "clientAuth"; private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis"; - private String[] cipherSuites; + private ZKConfig zkConfig; + private AtomicReference defaultSSLContextAndOptions = new AtomicReference<>(null); - private AtomicReference defaultSSLContext = new AtomicReference<>(null); private FileChangeWatcher keyStoreFileWatcher; private FileChangeWatcher trustStoreFileWatcher; public X509Util() { - String cipherSuitesInput = System.getProperty(cipherSuitesProperty); - if (cipherSuitesInput == null) { - cipherSuites = getDefaultCipherSuites(); - } else { - cipherSuites = cipherSuitesInput.split(","); - } + this(null); + } + + public X509Util(ZKConfig zkConfig) { + this.zkConfig = zkConfig; + keyStoreFileWatcher = trustStoreFileWatcher = null; } protected abstract String getConfigPrefix(); + protected abstract boolean shouldVerifyClientHostname(); public String getSslProtocolProperty() { return sslProtocolProperty; } + public String getSslEnabledProtocolsProperty() { + return sslEnabledProtocolsProperty; + } + public String getCipherSuitesProperty() { return cipherSuitesProperty; } @@ -126,6 +325,10 @@ public String getSslKeystoreLocationProperty() { return sslKeystoreLocationProperty; } + public String getSslCipherSuitesProperty() { + return cipherSuitesProperty; + } + public String getSslKeystorePasswdProperty() { return sslKeystorePasswdProperty; } @@ -158,6 +361,10 @@ public String getSslOcspEnabledProperty() { return sslOcspEnabledProperty; } + public String getSslClientAuthProperty() { + return sslClientAuthProperty; + } + /** * Returns the config property key that controls the amount of time, in milliseconds, that the first * UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT). @@ -169,30 +376,37 @@ public String getSslHandshakeDetectionTimeoutMillisProperty() { } public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException { - SSLContext result = defaultSSLContext.get(); + return getDefaultSSLContextAndOptions().getSSLContext(); + } + + public SSLContext createSSLContext(ZKConfig config) throws SSLContextException { + return createSSLContextAndOptions(config).getSSLContext(); + } + + public SSLContextAndOptions getDefaultSSLContextAndOptions() throws X509Exception.SSLContextException { + SSLContextAndOptions result = defaultSSLContextAndOptions.get(); if (result == null) { - result = createSSLContext(); - if (!defaultSSLContext.compareAndSet(null, result)) { + result = createSSLContextAndOptions(); + if (!defaultSSLContextAndOptions.compareAndSet(null, result)) { // lost the race, another thread already set the value - result = defaultSSLContext.get(); + result = defaultSSLContextAndOptions.get(); } } return result; } - private void resetDefaultSSLContext() throws X509Exception.SSLContextException { - SSLContext newContext = createSSLContext(); - defaultSSLContext.set(newContext); + private void resetDefaultSSLContextAndOptions() throws X509Exception.SSLContextException { + SSLContextAndOptions newContext = createSSLContextAndOptions(); + defaultSSLContextAndOptions.set(newContext); } - private SSLContext createSSLContext() throws SSLContextException { + private SSLContextAndOptions createSSLContextAndOptions() throws SSLContextException { /* * Since Configuration initializes the key store and trust store related * configuration from system property. Reading property from * configuration will be same reading from system property */ - ZKConfig config=new ZKConfig(); - return createSSLContext(config); + return createSSLContextAndOptions(zkConfig == null ? new ZKConfig() : zkConfig); } /** @@ -203,24 +417,19 @@ private SSLContext createSSLContext() throws SSLContextException { * @return the handshake detection timeout, in milliseconds. */ public int getSslHandshakeTimeoutMillis() { - String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty()); - int result; - if (propertyString == null) { - result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; - } else { - result = Integer.parseInt(propertyString); - if (result < 1) { - // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an - // accept() thread. - LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result + - ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS); - result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; - } + try { + SSLContextAndOptions ctx = getDefaultSSLContextAndOptions(); + return ctx.handshakeDetectionTimeoutMillis; + } catch (SSLContextException e) { + LOG.error("Error creating SSL context and options", e); + return DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } catch (Exception e) { + LOG.error("Error parsing config property " + getSslHandshakeDetectionTimeoutMillisProperty(), e); + return DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; } - return result; } - public SSLContext createSSLContext(ZKConfig config) throws SSLContextException { + public SSLContextAndOptions createSSLContextAndOptions(ZKConfig config) throws SSLContextException { KeyManager[] keyManagers = null; TrustManager[] trustManagers = null; @@ -269,12 +478,12 @@ public SSLContext createSSLContext(ZKConfig config) throws SSLContextException { } } - String protocol = System.getProperty(sslProtocolProperty, DEFAULT_PROTOCOL); + String protocol = config.getProperty(sslProtocolProperty, DEFAULT_PROTOCOL); try { SSLContext sslContext = SSLContext.getInstance(protocol); sslContext.init(keyManagers, trustManagers, null); - return sslContext; - } catch (NoSuchAlgorithmException|KeyManagementException sslContextInitException) { + return new SSLContextAndOptions(config, sslContext); + } catch (NoSuchAlgorithmException | KeyManagementException sslContextInitException) { throw new SSLContextException(sslContextInitException); } } @@ -416,64 +625,19 @@ public static X509TrustManager createTrustManager( } public SSLSocket createSSLSocket() throws X509Exception, IOException { - SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(); - configureSSLSocket(sslSocket); - sslSocket.setUseClientMode(true); - return sslSocket; + return getDefaultSSLContextAndOptions().createSSLSocket(); } public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException { - SSLSocket sslSocket; - if (pushbackBytes != null && pushbackBytes.length > 0) { - sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( - socket, new ByteArrayInputStream(pushbackBytes), true); - } else { - sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( - socket, null, socket.getPort(), true); - } - configureSSLSocket(sslSocket); - sslSocket.setUseClientMode(false); - sslSocket.setNeedClientAuth(true); - return sslSocket; - } - - private void configureSSLSocket(SSLSocket sslSocket) { - SSLParameters sslParameters = sslSocket.getSSLParameters(); - LOG.debug("Setup cipher suites for client socket: {}", Arrays.toString(cipherSuites)); - sslParameters.setCipherSuites(cipherSuites); - sslSocket.setSSLParameters(sslParameters); + return getDefaultSSLContextAndOptions().createSSLSocket(socket, pushbackBytes); } public SSLServerSocket createSSLServerSocket() throws X509Exception, IOException { - SSLServerSocket sslServerSocket = (SSLServerSocket) getDefaultSSLContext().getServerSocketFactory().createServerSocket(); - configureSSLServerSocket(sslServerSocket); - - return sslServerSocket; + return getDefaultSSLContextAndOptions().createSSLServerSocket(); } public SSLServerSocket createSSLServerSocket(int port) throws X509Exception, IOException { - SSLServerSocket sslServerSocket = (SSLServerSocket) getDefaultSSLContext().getServerSocketFactory().createServerSocket(port); - configureSSLServerSocket(sslServerSocket); - - return sslServerSocket; - } - - private void configureSSLServerSocket(SSLServerSocket sslServerSocket) { - SSLParameters sslParameters = sslServerSocket.getSSLParameters(); - sslParameters.setNeedClientAuth(true); - LOG.debug("Setup cipher suites for server socket: {}", Arrays.toString(cipherSuites)); - sslParameters.setCipherSuites(cipherSuites); - sslServerSocket.setSSLParameters(sslParameters); - } - - private String[] getDefaultCipherSuites() { - String javaVersion = System.getProperty("java.specification.version"); - if ("9".equals(javaVersion)) { - LOG.debug("Using Java9-optimized cipher suites for Java version {}", javaVersion); - return DEFAULT_CIPHERS_JAVA9; - } - LOG.debug("Using Java8-optimized cipher suites for Java version {}", javaVersion); - return DEFAULT_CIPHERS_JAVA8; + return getDefaultSSLContextAndOptions().createSSLServerSocket(port); } /** @@ -482,7 +646,7 @@ private String[] getDefaultCipherSuites() { * @throws IOException if creating the FileChangeWatcher objects fails. */ public void enableCertFileReloading() throws IOException { - ZKConfig config = new ZKConfig(); + ZKConfig config = zkConfig == null ? new ZKConfig() : zkConfig; String keyStoreLocation = config.getProperty(sslKeystoreLocationProperty); if (keyStoreLocation != null && !keyStoreLocation.isEmpty()) { final Path filePath = Paths.get(keyStoreLocation).toAbsolutePath(); @@ -569,7 +733,7 @@ private void handleWatchEvent(Path filePath, WatchEvent event) { event.kind() + " with context: " + event.context()); } try { - this.resetDefaultSSLContext(); + this.resetDefaultSSLContextAndOptions(); } catch (SSLContextException e) { throw new RuntimeException(e); } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java index effc0d52e52..93bdffbe27c 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java @@ -112,6 +112,12 @@ protected void handleBackwardCompatibility() { } private void putSSLProperties(X509Util x509Util) { + properties.put(x509Util.getSslProtocolProperty(), + System.getProperty(x509Util.getSslProtocolProperty())); + properties.put(x509Util.getSslEnabledProtocolsProperty(), + System.getProperty(x509Util.getSslEnabledProtocolsProperty())); + properties.put(x509Util.getSslCipherSuitesProperty(), + System.getProperty(x509Util.getSslCipherSuitesProperty())); properties.put(x509Util.getSslKeystoreLocationProperty(), System.getProperty(x509Util.getSslKeystoreLocationProperty())); properties.put(x509Util.getSslKeystorePasswdProperty(), @@ -130,6 +136,8 @@ private void putSSLProperties(X509Util x509Util) { System.getProperty(x509Util.getSslCrlEnabledProperty())); properties.put(x509Util.getSslOcspEnabledProperty(), System.getProperty(x509Util.getSslOcspEnabledProperty())); + properties.put(x509Util.getSslClientAuthProperty(), + System.getProperty(x509Util.getSslClientAuthProperty())); properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty())); }