diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java index 8dc5243b360..91f858cb4f2 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java @@ -83,7 +83,200 @@ public abstract class X509Util { public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000; + /** + * Enum specifying the client auth requirement of server-side TLS sockets created by this X509Util. + *
ClientAuth.NEED
.
+ * @param prop the property string.
+ * @return the ClientAuth.
+ * @throws IllegalArgumentException if the property value is not "NONE", "WANT", "NEED", or empty/null.
+ */
+ public static ClientAuth fromPropertyValue(String prop) {
+ if (prop == null || prop.length() == 0) {
+ return NEED;
+ }
+ return ClientAuth.valueOf(prop.toUpperCase());
+ }
+ }
+
+ /**
+ * Wrapper class for an SSLContext + some config options that can't be set on the context when it is created but
+ * must be set on a secure socket created by the context after the socket creation. By wrapping the options in this
+ * class we avoid reading from global system properties during socket configuration. This makes testing easier
+ * since we can create different X509Util instances with different configurations in a single test process, and
+ * unit test interactions between them.
+ */
+ public class SSLContextAndOptions {
+ private final String[] enabledProtocols;
+ private final String[] cipherSuites;
+ private final ClientAuth clientAuth;
+ private final SSLContext sslContext;
+ private final int handshakeDetectionTimeoutMillis;
+
+ /**
+ * Note: constructor is intentionally private, only the enclosing X509Util should be creating instances of this
+ * class.
+ * @param config
+ * @param sslContext
+ */
+ private SSLContextAndOptions(final ZKConfig config,
+ final SSLContext sslContext) {
+ this.sslContext = sslContext;
+ this.enabledProtocols = getEnabledProtocols(config, sslContext);
+ this.cipherSuites = getCipherSuites(config);
+ this.clientAuth = getClientAuth(config);
+ this.handshakeDetectionTimeoutMillis = getHandshakeDetectionTimeoutMillis(config);
+ }
+
+ public SSLContext getSSLContext() {
+ return sslContext;
+ }
+
+ public SSLSocket createSSLSocket() throws IOException {
+ return configureSSLSocket((SSLSocket) sslContext.getSocketFactory().createSocket(), true);
+ }
+
+ public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws IOException {
+ SSLSocket sslSocket;
+ if (pushbackBytes != null && pushbackBytes.length > 0) {
+ sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(
+ socket, new ByteArrayInputStream(pushbackBytes), true);
+ } else {
+ sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket(
+ socket, null, socket.getPort(), true);
+ }
+ return configureSSLSocket(sslSocket, false);
+ }
+
+ public SSLServerSocket createSSLServerSocket() throws IOException {
+ SSLServerSocket sslServerSocket =
+ (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket();
+ return configureSSLServerSocket(sslServerSocket);
+ }
+
+ public SSLServerSocket createSSLServerSocket(int port) throws IOException {
+ SSLServerSocket sslServerSocket =
+ (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(port);
+ return configureSSLServerSocket(sslServerSocket);
+ }
+
+ private SSLSocket configureSSLSocket(SSLSocket socket, boolean isClientSocket) {
+ SSLParameters sslParameters = socket.getSSLParameters();
+ configureSslParameters(sslParameters, isClientSocket);
+ socket.setSSLParameters(sslParameters);
+ socket.setUseClientMode(isClientSocket);
+ return socket;
+ }
+
+ private SSLServerSocket configureSSLServerSocket(SSLServerSocket socket) {
+ SSLParameters sslParameters = socket.getSSLParameters();
+ configureSslParameters(sslParameters, false);
+ socket.setSSLParameters(sslParameters);
+ socket.setUseClientMode(false);
+ return socket;
+ }
+
+ private void configureSslParameters(SSLParameters sslParameters, boolean isClientSocket) {
+ if (cipherSuites != null) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Setup cipher suites for {} socket: {}",
+ isClientSocket ? "client" : "server",
+ Arrays.toString(cipherSuites));
+ }
+ sslParameters.setCipherSuites(cipherSuites);
+ }
+ if (enabledProtocols != null) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Setup enabled protocols for {} socket: {}",
+ isClientSocket ? "client" : "server",
+ Arrays.toString(enabledProtocols));
+ }
+ sslParameters.setProtocols(enabledProtocols);
+ }
+ if (!isClientSocket) {
+ switch (clientAuth) {
+ case NEED:
+ sslParameters.setNeedClientAuth(true);
+ break;
+ case WANT:
+ sslParameters.setWantClientAuth(true);
+ break;
+ default:
+ sslParameters.setNeedClientAuth(false); // also clears the wantClientAuth flag according to docs
+ break;
+ }
+ }
+ }
+
+ private String[] getEnabledProtocols(final ZKConfig config, final SSLContext sslContext) {
+ String enabledProtocolsInput = config.getProperty(getSslEnabledProtocolsProperty());
+ if (enabledProtocolsInput == null) {
+ return new String[] { sslContext.getProtocol() };
+ }
+ return enabledProtocolsInput.split(",");
+ }
+
+ private String[] getCipherSuites(final ZKConfig config) {
+ String cipherSuitesInput = config.getProperty(getSslCipherSuitesProperty());
+ if (cipherSuitesInput == null) {
+ return getDefaultCipherSuites();
+ } else {
+ return cipherSuitesInput.split(",");
+ }
+ }
+
+ private String[] getDefaultCipherSuites() {
+ String javaVersion = System.getProperty("java.specification.version");
+ if (javaVersion.startsWith("1.")) {
+ // Must be Java 1.8 or earlier
+ LOG.debug("Using Java8-optimized cipher suites for Java version {}", javaVersion);
+ return DEFAULT_CIPHERS_JAVA8;
+ } else {
+ // Must be Java 9 or later
+ LOG.debug("Using Java9-optimized cipher suites for Java version {}", javaVersion);
+ return DEFAULT_CIPHERS_JAVA9;
+ }
+ }
+
+ private ClientAuth getClientAuth(final ZKConfig config) {
+ return ClientAuth.fromPropertyValue(config.getProperty(getSslClientAuthProperty()));
+ }
+
+ private int getHandshakeDetectionTimeoutMillis(final ZKConfig config) {
+ String propertyString = config.getProperty(getSslHandshakeDetectionTimeoutMillisProperty());
+ int result;
+ if (propertyString == null) {
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ } else {
+ result = Integer.parseInt(propertyString);
+ if (result < 1) {
+ // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an
+ // accept() thread.
+ LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result +
+ ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS);
+ result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS;
+ }
+ }
+ return result;
+ }
+ }
+
private String sslProtocolProperty = getConfigPrefix() + "protocol";
+ private String sslEnabledProtocolsProperty = getConfigPrefix() + "enabledProtocols";
private String cipherSuitesProperty = getConfigPrefix() + "ciphersuites";
private String sslKeystoreLocationProperty = getConfigPrefix() + "keyStore.location";
private String sslKeystorePasswdProperty = getConfigPrefix() + "keyStore.password";
@@ -94,30 +287,36 @@ public abstract class X509Util {
private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification";
private String sslCrlEnabledProperty = getConfigPrefix() + "crl";
private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp";
+ private String sslClientAuthProperty = getConfigPrefix() + "clientAuth";
private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis";
- private String[] cipherSuites;
+ private ZKConfig zkConfig;
+ private AtomicReference