diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java index 9fb6dd91323..229d20c3eda 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/X509Util.java @@ -34,6 +34,8 @@ import javax.net.ssl.X509ExtendedTrustManager; import javax.net.ssl.X509KeyManager; import javax.net.ssl.X509TrustManager; + +import java.io.ByteArrayInputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -82,6 +84,8 @@ public abstract class X509Util { "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256" }; + public static final int DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS = 5000; + /** * This enum represents the file type of a KeyStore or TrustStore. Currently, JKS (java keystore) and PEM types * are supported. @@ -135,6 +139,7 @@ public static StoreFileType fromPropertyValue(String prop) { private String sslHostnameVerificationEnabledProperty = getConfigPrefix() + "hostnameVerification"; private String sslCrlEnabledProperty = getConfigPrefix() + "crl"; private String sslOcspEnabledProperty = getConfigPrefix() + "ocsp"; + private String sslHandshakeDetectionTimeoutMillisProperty = getConfigPrefix() + "handshakeDetectionTimeoutMillis"; private String[] cipherSuites; @@ -196,6 +201,16 @@ public String getSslOcspEnabledProperty() { return sslOcspEnabledProperty; } + /** + * Returns the config property key that controls the amount of time, in milliseconds, that the first + * UnifiedServerSocket read operation will block for when trying to detect the client mode (TLS or PLAINTEXT). + * + * @return the config property key. + */ + public String getSslHandshakeDetectionTimeoutMillisProperty() { + return sslHandshakeDetectionTimeoutMillisProperty; + } + public SSLContext getDefaultSSLContext() throws X509Exception.SSLContextException { SSLContext result = defaultSSLContext.get(); if (result == null) { @@ -218,6 +233,31 @@ private SSLContext createSSLContext() throws SSLContextException { return createSSLContext(config); } + /** + * Returns the max amount of time, in milliseconds, that the first UnifiedServerSocket read() operation should + * block for when trying to detect the client mode (TLS or PLAINTEXT). + * Defaults to {@link X509Util#DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS}. + * + * @return the handshake detection timeout, in milliseconds. + */ + public int getSslHandshakeTimeoutMillis() { + String propertyString = System.getProperty(getSslHandshakeDetectionTimeoutMillisProperty()); + int result; + if (propertyString == null) { + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } else { + result = Integer.parseInt(propertyString); + if (result < 1) { + // Timeout of 0 is not allowed, since an infinite timeout can permanently lock up an + // accept() thread. + LOG.warn("Invalid value for " + getSslHandshakeDetectionTimeoutMillisProperty() + ": " + result + + ", using the default value of " + DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS); + result = DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS; + } + } + return result; + } + public SSLContext createSSLContext(ZKConfig config) throws SSLContextException { KeyManager[] keyManagers = null; TrustManager[] trustManagers = null; @@ -427,14 +467,22 @@ public static X509TrustManager createTrustManager(String trustStoreLocation, Str public SSLSocket createSSLSocket() throws X509Exception, IOException { SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(); configureSSLSocket(sslSocket); - + sslSocket.setUseClientMode(true); return sslSocket; } - public SSLSocket createSSLSocket(Socket socket) throws X509Exception, IOException { - SSLSocket sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket(socket, null, socket.getPort(), true); + public SSLSocket createSSLSocket(Socket socket, byte[] pushbackBytes) throws X509Exception, IOException { + SSLSocket sslSocket; + if (pushbackBytes != null && pushbackBytes.length > 0) { + sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( + socket, new ByteArrayInputStream(pushbackBytes), true); + } else { + sslSocket = (SSLSocket) getDefaultSSLContext().getSocketFactory().createSocket( + socket, null, socket.getPort(), true); + } configureSSLSocket(sslSocket); - + sslSocket.setUseClientMode(false); + sslSocket.setNeedClientAuth(true); return sslSocket; } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java index 01bac691d94..effc0d52e52 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/common/ZKConfig.java @@ -130,6 +130,8 @@ private void putSSLProperties(X509Util x509Util) { System.getProperty(x509Util.getSslCrlEnabledProperty())); properties.put(x509Util.getSslOcspEnabledProperty(), System.getProperty(x509Util.getSslOcspEnabledProperty())); + properties.put(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), + System.getProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty())); } /** diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java index 9270548a7f0..abd4f1d537d 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Leader.java @@ -42,7 +42,6 @@ import javax.security.sasl.SaslException; import org.apache.zookeeper.ZooDefs.OpCode; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.Time; import org.apache.zookeeper.common.X509Exception; import org.apache.zookeeper.server.FinalRequestProcessor; @@ -234,21 +233,16 @@ public boolean isQuorumSynced(QuorumVerifier qv) { private final ServerSocket ss; - Leader(QuorumPeer self,LeaderZooKeeperServer zk) throws IOException, X509Exception { + Leader(QuorumPeer self,LeaderZooKeeperServer zk) throws IOException { this.self = self; this.proposalStats = new BufferStats(); try { - if (self.shouldUsePortUnification()) { + if (self.shouldUsePortUnification() || self.isSslQuorum()) { + boolean allowInsecureConnection = self.shouldUsePortUnification(); if (self.getQuorumListenOnAllIPs()) { - ss = new UnifiedServerSocket(new QuorumX509Util(), self.getQuorumAddress().getPort()); + ss = new UnifiedServerSocket(self.getX509Util(), allowInsecureConnection, self.getQuorumAddress().getPort()); } else { - ss = new UnifiedServerSocket(new QuorumX509Util()); - } - } else if (self.isSslQuorum()) { - if (self.getQuorumListenOnAllIPs()) { - ss = new QuorumX509Util().createSSLServerSocket(self.getQuorumAddress().getPort()); - } else { - ss = new QuorumX509Util().createSSLServerSocket(); + ss = new UnifiedServerSocket(self.getX509Util(), allowInsecureConnection); } } else { if (self.getQuorumListenOnAllIPs()) { @@ -261,9 +255,6 @@ public boolean isQuorumSynced(QuorumVerifier qv) { if (!self.getQuorumListenOnAllIPs()) { ss.bind(self.getQuorumAddress()); } - } catch (X509Exception e) { - LOG.error("Failed to setup ssl server socket", e); - throw e; } catch (BindException e) { if (self.getQuorumListenOnAllIPs()) { LOG.error("Couldn't bind to port " + self.getQuorumAddress().getPort(), e); @@ -399,8 +390,10 @@ public LearnerCnxAcceptor() { public void run() { try { while (!stop) { - try{ - Socket s = ss.accept(); + Socket s = null; + boolean error = false; + try { + s = ss.accept(); // start with the initLimit, once the ack is processed // in LearnerHandler switch to the syncLimit @@ -412,6 +405,7 @@ public void run() { LearnerHandler fh = new LearnerHandler(s, is, Leader.this); fh.start(); } catch (SocketException e) { + error = true; if (stop) { LOG.info("exception while shutting down acceptor: " + e); @@ -425,6 +419,19 @@ public void run() { } } catch (SaslException e){ LOG.error("Exception while connecting to quorum learner", e); + error = true; + } catch (Exception e) { + error = true; + throw e; + } finally { + // Don't leak sockets on errors + if (error && s != null && !s.isClosed()) { + try { + s.close(); + } catch (IOException e) { + LOG.warn("Error closing socket", e); + } + } } } } catch (Exception e) { diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java index c740d5348f4..faaa844ada2 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/Learner.java @@ -38,9 +38,7 @@ import org.apache.jute.InputArchive; import org.apache.jute.OutputArchive; import org.apache.jute.Record; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.X509Exception; -import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ExitCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -74,8 +72,6 @@ static class PacketInFlight { protected Socket sock; - protected X509Util x509Util; - /** * Socket getter * @return @@ -304,10 +300,7 @@ protected void connectToLeader(InetSocketAddress addr, String hostname) private Socket createSocket() throws X509Exception, IOException { Socket sock; if (self.isSslQuorum()) { - if (x509Util == null) { - x509Util = new QuorumX509Util(); - } - sock = x509Util.createSSLSocket(); + sock = self.getX509Util().createSSLSocket(); } else { sock = new Socket(); } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java index a86608ff2f2..94a526e94e1 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/PrependableSocket.java @@ -18,16 +18,15 @@ package org.apache.zookeeper.server.quorum; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.io.SequenceInputStream; +import java.io.PushbackInputStream; import java.net.Socket; import java.net.SocketImpl; public class PrependableSocket extends Socket { - private SequenceInputStream sequenceInputStream; + private PushbackInputStream pushbackInputStream; public PrependableSocket(SocketImpl base) throws IOException { super(base); @@ -35,15 +34,31 @@ public PrependableSocket(SocketImpl base) throws IOException { @Override public InputStream getInputStream() throws IOException { - if (sequenceInputStream == null) { + if (pushbackInputStream == null) { return super.getInputStream(); } - return sequenceInputStream; + return pushbackInputStream; } - public void prependToInputStream(byte[] bytes) throws IOException { - sequenceInputStream = new SequenceInputStream(new ByteArrayInputStream(bytes), getInputStream()); + /** + * Prepend some bytes that have already been read back to the socket's input stream. Note that this method can be + * called at most once with a non-0 length per socket instance. + * @param bytes the bytes to prepend. + * @param offset offset in the byte array to start at. + * @param length number of bytes to prepend. + * @throws IOException if this method was already called on the socket instance, or if super.getInputStream() throws. + */ + public void prependToInputStream(byte[] bytes, int offset, int length) throws IOException { + if (length == 0) { + return; // nothing to prepend + } + if (pushbackInputStream != null) { + throw new IOException("prependToInputStream() called more than once"); + } + PushbackInputStream pushbackInputStream = new PushbackInputStream(getInputStream(), length); + pushbackInputStream.unread(bytes, offset, length); + this.pushbackInputStream = pushbackInputStream; } } \ No newline at end of file diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java index 8b9102390bb..704580d9386 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumCnxManager.java @@ -47,9 +47,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.X509Exception; -import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ExitCode; import org.apache.zookeeper.server.quorum.QuorumPeerConfig.ConfigException; import org.apache.zookeeper.server.util.ConfigUtils; @@ -175,9 +173,6 @@ public class QuorumCnxManager { */ private final boolean tcpKeepAlive = Boolean.getBoolean("zookeeper.tcpKeepAlive"); - - private X509Util x509Util; - static public class Message { Message(ByteBuffer buffer, long sid) { this.buffer = buffer; @@ -291,8 +286,6 @@ public QuorumCnxManager(QuorumPeer self, // Starts listener thread that waits for connection requests listener = new Listener(); listener.setName("QuorumPeerListener"); - - x509Util = new QuorumX509Util(); } private void initializeAuth(final long mySid, @@ -655,17 +648,18 @@ synchronized private boolean connectOne(long sid, InetSocketAddress electionAddr try { LOG.debug("Opening channel to server " + sid); if (self.isSslQuorum()) { - SSLSocket sslSock = x509Util.createSSLSocket(); - setSockOpts(sslSock); - sslSock.connect(electionAddr, cnxTO); - sslSock.startHandshake(); - sock = sslSock; - } else { - sock = new Socket(); - setSockOpts(sock); - sock.connect(electionAddr, cnxTO); - } - LOG.debug("Connected to server " + sid); + SSLSocket sslSock = self.getX509Util().createSSLSocket(); + setSockOpts(sslSock); + sslSock.connect(electionAddr, cnxTO); + sslSock.startHandshake(); + sock = sslSock; + } else { + sock = new Socket(); + setSockOpts(sock); + sock.connect(electionAddr, cnxTO); + + } + LOG.debug("Connected to server " + sid); // Sends connection request asynchronously if the quorum // sasl authentication is enabled. This is required because // sasl server authentication process may take few seconds to @@ -876,9 +870,9 @@ public void run() { while((!shutdown) && (numRetries < 3)){ try { if (self.shouldUsePortUnification()) { - ss = new UnifiedServerSocket(x509Util); + ss = new UnifiedServerSocket(self.getX509Util(), true); } else if (self.isSslQuorum()) { - ss = x509Util.createSSLServerSocket(); + ss = new UnifiedServerSocket(self.getX509Util(), false); } else { ss = new ServerSocket(); } @@ -920,7 +914,7 @@ public void run() { + "see ZOOKEEPER-2836"); } } - } catch (IOException|X509Exception e) { + } catch (IOException e) { if (shutdown) { break; } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java index 136a538d4b3..7abde4ba991 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeer.java @@ -47,6 +47,7 @@ import org.apache.zookeeper.KeeperException.BadArgumentsException; import org.apache.zookeeper.common.AtomicFileWritingIdiom; import org.apache.zookeeper.common.AtomicFileWritingIdiom.WriterStatement; +import org.apache.zookeeper.common.QuorumX509Util; import org.apache.zookeeper.common.Time; import org.apache.zookeeper.common.X509Exception; import org.apache.zookeeper.jmx.MBeanRegistry; @@ -479,6 +480,12 @@ public boolean shouldUsePortUnification() { return shouldUsePortUnification; } + private final QuorumX509Util x509Util; + + QuorumX509Util getX509Util() { + return x509Util; + } + /** * This is who I think the leader currently is. */ @@ -801,6 +808,7 @@ public QuorumPeer() throws SaslException { quorumStats = new QuorumStats(this); jmxRemotePeerBean = new HashMap(); adminServer = AdminServerFactory.createAdminServer(); + x509Util = new QuorumX509Util(); initialize(); } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java index 45463b19622..aee5efcd68b 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/QuorumPeerConfig.java @@ -315,9 +315,8 @@ public void parseProperties(Properties zkProp) } } else if (key.equals("sslQuorum")){ sslQuorum = Boolean.parseBoolean(value); -// TODO: UnifiedServerSocket is currently buggy, will be fixed when @ivmaykov's PRs are merged. Disable port unification until then. -// } else if (key.equals("portUnification")){ -// shouldUsePortUnification = Boolean.parseBoolean(value); + } else if (key.equals("portUnification")){ + shouldUsePortUnification = Boolean.parseBoolean(value); } else if ((key.startsWith("server.") || key.startsWith("group") || key.startsWith("weight")) && zkProp.containsKey("dynamicConfigFile")) { throw new ConfigException("parameter: " + key + " must be in a separate dynamic config file"); } else if (key.equals(QuorumAuth.QUORUM_SASL_AUTH_ENABLED)) { diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java index 4802ecf308c..efce40cfa74 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/quorum/UnifiedServerSocket.java @@ -27,23 +27,111 @@ import javax.net.ssl.SSLSocket; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketAddress; import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.nio.channels.SocketChannel; +/** + * A ServerSocket that can act either as a regular ServerSocket, as a SSLServerSocket, or as both, depending on + * the constructor parameters and on the type of client (TLS or plaintext) that connects to it. + * The constructors have the same signature as constructors of ServerSocket, with the addition of two parameters + * at the beginning: + * + * The !allowInsecureConnection mode is needed so we can update the SSLContext (in particular, the + * key store and/or trust store) without having to re-create the server socket. By starting with a plaintext socket + * and delaying the upgrade to TLS until after a client has connected and begins a handshake, we can keep the same + * UnifiedServerSocket instance around, and replace the default SSLContext in the provided X509Util when the key store + * and/or trust store file changes on disk. + */ public class UnifiedServerSocket extends ServerSocket { private static final Logger LOG = LoggerFactory.getLogger(UnifiedServerSocket.class); private X509Util x509Util; + private final boolean allowInsecureConnection; - public UnifiedServerSocket(X509Util x509Util) throws IOException { + /** + * Creates an unbound unified server socket by calling {@link ServerSocket#ServerSocket()}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the allowInsecureConnection parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @throws IOException if {@link ServerSocket#ServerSocket()} throws. + */ + public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection) throws IOException { super(); this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; } - public UnifiedServerSocket(X509Util x509Util, int port) throws IOException { + /** + * Creates a unified server socket bound to the specified port by calling {@link ServerSocket#ServerSocket(int)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the allowInsecureConnection parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. + * @throws IOException if {@link ServerSocket#ServerSocket(int)} throws. + */ + public UnifiedServerSocket(X509Util x509Util, boolean allowInsecureConnection, int port) throws IOException { super(port); this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + } + + /** + * Creates a unified server socket bound to the specified port, with the specified backlog, by calling + * {@link ServerSocket#ServerSocket(int, int)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the allowInsecureConnection parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. + * @param backlog requested maximum length of the queue of incoming connections. + * @throws IOException if {@link ServerSocket#ServerSocket(int, int)} throws. + */ + public UnifiedServerSocket(X509Util x509Util, + boolean allowInsecureConnection, + int port, + int backlog) throws IOException { + super(port, backlog); + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + } + + /** + * Creates a unified server socket bound to the specified port, with the specified backlog, and local IP address + * to bind to, by calling {@link ServerSocket#ServerSocket(int, int, InetAddress)}. + * Secure client connections will be upgraded to TLS once this socket detects the ClientHello message (start of a + * TLS handshake). Plaintext client connections will either be accepted or rejected depending on the value of + * the allowInsecureConnection parameter. + * @param x509Util the X509Util that provides the SSLContext to use for secure connections. + * @param allowInsecureConnection if true, accept plaintext connections, otherwise close them. + * @param port the port number, or {@code 0} to use a port number that is automatically allocated. + * @param backlog requested maximum length of the queue of incoming connections. + * @param bindAddr the local InetAddress the server will bind to. + * @throws IOException if {@link ServerSocket#ServerSocket(int, int, InetAddress)} throws. + */ + public UnifiedServerSocket(X509Util x509Util, + boolean allowInsecureConnection, + int port, + int backlog, + InetAddress bindAddr) throws IOException { + super(port, backlog, bindAddr); + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; } @Override @@ -56,24 +144,642 @@ public Socket accept() throws IOException { } final PrependableSocket prependableSocket = new PrependableSocket(null); implAccept(prependableSocket); + return new UnifiedSocket(x509Util, allowInsecureConnection, prependableSocket); + } + + /** + * The result of calling accept() on a UnifiedServerSocket. This is a Socket that doesn't know if it's + * using plaintext or SSL/TLS at the time when it is created. Calling a method that indicates a desire to + * read or write from the socket will cause the socket to detect if the connected client is attempting + * to establish a TLS or plaintext connection. This is done by doing a blocking read of 5 bytes off the + * socket and checking if the bytes look like the start of a TLS ClientHello message. If it looks like + * the client is attempting to connect with TLS, the internal socket is upgraded to a SSLSocket. If not, + * any bytes read from the socket are pushed back to the input stream, and the socket continues + * to be treated as a plaintext socket. + * + * The methods that trigger this behavior are: + * + * + * Calling other socket methods (i.e option setters such as {@link Socket#setTcpNoDelay(boolean)}) does + * not trigger mode detection. + * + * Because detecting the mode is a potentially blocking operation, it should not be done in the + * accepting thread. Attempting to read from or write to the socket in the accepting thread opens the + * caller up to a denial-of-service attack, in which a client connects and then does nothing. This would + * prevent any other clients from connecting. Passing the socket returned by accept() to a separate + * thread which handles all read and write operations protects against this DoS attack. + * + * Callers can check if the socket has been upgraded to TLS by calling {@link UnifiedSocket#isSecureSocket()}, + * and can get the underlying SSLSocket by calling {@link UnifiedSocket#getSslSocket()}. + */ + public static class UnifiedSocket extends Socket { + private enum Mode { + UNKNOWN, + PLAINTEXT, + TLS + } - byte[] litmus = new byte[5]; - int bytesRead = prependableSocket.getInputStream().read(litmus, 0, 5); - prependableSocket.prependToInputStream(litmus); + private final X509Util x509Util; + private final boolean allowInsecureConnection; + private PrependableSocket prependableSocket; + private SSLSocket sslSocket; + private Mode mode; - if (bytesRead == 5 && SslHandler.isEncrypted(ChannelBuffers.wrappedBuffer(litmus))) { - LOG.info(getInetAddress() + " attempting to connect over ssl"); - SSLSocket sslSocket; + /** + * Note: this constructor is intentionally private. The only intended caller is + * {@link UnifiedServerSocket#accept()}. + * + * @param x509Util + * @param allowInsecureConnection + * @param prependableSocket + */ + private UnifiedSocket(X509Util x509Util, boolean allowInsecureConnection, PrependableSocket prependableSocket) { + this.x509Util = x509Util; + this.allowInsecureConnection = allowInsecureConnection; + this.prependableSocket = prependableSocket; + this.sslSocket = null; + this.mode = Mode.UNKNOWN; + } + + /** + * Returns true if the socket mode has been determined to be TLS. + * @return true if the mode is TLS, false if it is UNKNOWN or PLAINTEXT. + */ + public boolean isSecureSocket() { + return mode == Mode.TLS; + } + + /** + * Returns true if the socket mode has been determined to be PLAINTEXT. + * @return true if the mode is PLAINTEXT, false if it is UNKNOWN or TLS. + */ + public boolean isPlaintextSocket() { + return mode == Mode.PLAINTEXT; + } + + /** + * Returns true if the socket mode is not yet known. + * @return true if the mode is UNKNOWN, false if it is PLAINTEXT or TLS. + */ + public boolean isModeKnown() { + return mode != Mode.UNKNOWN; + } + + /** + * Detects the socket mode, see comments at the top of the class for more details. This operation will block + * for up to {@link X509Util#getSslHandshakeTimeoutMillis()} milliseconds and should not be called in the + * accept() thread if possible. + * @throws IOException + */ + private void detectMode() throws IOException { + byte[] litmus = new byte[5]; + int oldTimeout = -1; + int bytesRead = 0; + int newTimeout = x509Util.getSslHandshakeTimeoutMillis(); try { - sslSocket = x509Util.createSSLSocket(prependableSocket); - } catch (X509Exception e) { - throw new IOException("failed to create SSL context", e); + oldTimeout = prependableSocket.getSoTimeout(); + prependableSocket.setSoTimeout(newTimeout); + bytesRead = prependableSocket.getInputStream().read(litmus, 0, litmus.length); + } catch (SocketTimeoutException e) { + // Didn't read anything within the timeout, fallthrough and assume the connection is plaintext. + LOG.warn("Socket mode detection timed out after " + newTimeout + " ms, assuming PLAINTEXT"); + } finally { + // restore socket timeout to the old value + try { + if (oldTimeout != -1) { + prependableSocket.setSoTimeout(oldTimeout); + } + } catch (Exception e) { + LOG.warn("Failed to restore old socket timeout value of " + oldTimeout + " ms", e); + } + } + if (bytesRead < 0) { // Got a EOF right away, definitely not using TLS. Fallthrough. + bytesRead = 0; + } + + if (bytesRead == litmus.length && SslHandler.isEncrypted(ChannelBuffers.wrappedBuffer(litmus))) { + try { + sslSocket = x509Util.createSSLSocket(prependableSocket, litmus); + } catch (X509Exception e) { + throw new IOException("failed to create SSL context", e); + } + prependableSocket = null; + mode = Mode.TLS; + } else if (allowInsecureConnection) { + prependableSocket.prependToInputStream(litmus, 0, bytesRead); + mode = Mode.PLAINTEXT; + } else { + prependableSocket.close(); + mode = Mode.PLAINTEXT; + throw new IOException("Blocked insecure connection attempt"); + } + } + + private Socket getSocketAllowUnknownMode() { + if (isSecureSocket()) { + return sslSocket; + } else { // Note: mode is UNKNOWN or PLAINTEXT + return prependableSocket; + } + } + + /** + * Returns the underlying socket, detecting the socket mode if it is not yet known. This is a potentially + * blocking operation and should not be called in the accept() thread. + * @return the underlying socket, after the socket mode has been determined. + * @throws IOException + */ + private Socket getSocket() throws IOException { + if (!isModeKnown()) { + detectMode(); + } + if (mode == Mode.TLS) { + return sslSocket; + } else { + return prependableSocket; + } + } + + /** + * Returns the underlying SSLSocket if the mode is TLS. If the mode is UNKNOWN, causes mode detection which is a + * potentially blocking operation. If the mode ends up being PLAINTEXT, this will throw a SocketException, so + * callers are advised to only call this method after checking that {@link UnifiedSocket#isSecureSocket()} + * returned true. + * @return the underlying SSLSocket if the mode is known to be TLS. + * @throws IOException if detecting the socket mode fails + * @throws SocketException if the mode is PLAINTEXT. + */ + public SSLSocket getSslSocket() throws IOException { + if (!isModeKnown()) { + detectMode(); + } + if (!isSecureSocket()) { + throw new SocketException("Socket mode is not TLS"); } - sslSocket.setUseClientMode(false); return sslSocket; - } else { - LOG.info(getInetAddress() + " attempting to connect without ssl"); - return prependableSocket; } + + /** + * See {@link Socket#connect(SocketAddress)}. Calling this method does not trigger mode detection. + */ + @Override + public void connect(SocketAddress endpoint) throws IOException { + getSocketAllowUnknownMode().connect(endpoint); + } + + /** + * See {@link Socket#connect(SocketAddress, int)}. Calling this method does not trigger mode detection. + */ + @Override + public void connect(SocketAddress endpoint, int timeout) throws IOException { + getSocketAllowUnknownMode().connect(endpoint, timeout); + } + + /** + * See {@link Socket#bind(SocketAddress)}. Calling this method does not trigger mode detection. + */ + @Override + public void bind(SocketAddress bindpoint) throws IOException { + getSocketAllowUnknownMode().bind(bindpoint); + } + + /** + * See {@link Socket#getInetAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public InetAddress getInetAddress() { + return getSocketAllowUnknownMode().getInetAddress(); + } + + /** + * See {@link Socket#getLocalAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public InetAddress getLocalAddress() { + return getSocketAllowUnknownMode().getLocalAddress(); + } + + /** + * See {@link Socket#getPort()}. Calling this method does not trigger mode detection. + */ + @Override + public int getPort() { + return getSocketAllowUnknownMode().getPort(); + } + + /** + * See {@link Socket#getLocalPort()}. Calling this method does not trigger mode detection. + */ + @Override + public int getLocalPort() { + return getSocketAllowUnknownMode().getLocalPort(); + } + + /** + * See {@link Socket#getRemoteSocketAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public SocketAddress getRemoteSocketAddress() { + return getSocketAllowUnknownMode().getRemoteSocketAddress(); + } + + /** + * See {@link Socket#getLocalSocketAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public SocketAddress getLocalSocketAddress() { + return getSocketAllowUnknownMode().getLocalSocketAddress(); + } + + /** + * See {@link Socket#getChannel()}. Calling this method does not trigger mode detection. + */ + @Override + public SocketChannel getChannel() { + return getSocketAllowUnknownMode().getChannel(); + } + + /** + * See {@link Socket#getInputStream()}. If the socket mode has not yet been detected, the first read from the + * returned input stream will trigger mode detection, which is a potentially blocking operation. This means + * the accept() thread should avoid reading from this input stream if possible. + */ + @Override + public InputStream getInputStream() throws IOException { + return new UnifiedInputStream(this); + } + + /** + * See {@link Socket#getOutputStream()}. If the socket mode has not yet been detected, the first read from the + * returned input stream will trigger mode detection, which is a potentially blocking operation. This means + * the accept() thread should avoid reading from this input stream if possible. + */ + @Override + public OutputStream getOutputStream() throws IOException { + return new UnifiedOutputStream(this); + } + + /** + * See {@link Socket#setTcpNoDelay(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setTcpNoDelay(boolean on) throws SocketException { + getSocketAllowUnknownMode().setTcpNoDelay(on); + } + + /** + * See {@link Socket#getTcpNoDelay()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getTcpNoDelay() throws SocketException { + return getSocketAllowUnknownMode().getTcpNoDelay(); + } + + /** + * See {@link Socket#setSoLinger(boolean, int)}. Calling this method does not trigger mode detection. + */ + @Override + public void setSoLinger(boolean on, int linger) throws SocketException { + getSocketAllowUnknownMode().setSoLinger(on, linger); + } + + /** + * See {@link Socket#getSoLinger()}. Calling this method does not trigger mode detection. + */ + @Override + public int getSoLinger() throws SocketException { + return getSocketAllowUnknownMode().getSoLinger(); + } + + /** + * See {@link Socket#sendUrgentData(int)}. Calling this method triggers mode detection, which is a potentially + * blocking operation, so it should not be done in the accept() thread. + */ + @Override + public void sendUrgentData(int data) throws IOException { + getSocket().sendUrgentData(data); + } + + /** + * See {@link Socket#setOOBInline(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setOOBInline(boolean on) throws SocketException { + getSocketAllowUnknownMode().setOOBInline(on); + } + + /** + * See {@link Socket#getOOBInline()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getOOBInline() throws SocketException { + return getSocketAllowUnknownMode().getOOBInline(); + } + + /** + * See {@link Socket#setSoTimeout(int)}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setSoTimeout(int timeout) throws SocketException { + getSocketAllowUnknownMode().setSoTimeout(timeout); + } + + /** + * See {@link Socket#getSoTimeout()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getSoTimeout() throws SocketException { + return getSocketAllowUnknownMode().getSoTimeout(); + } + + /** + * See {@link Socket#setSendBufferSize(int)}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setSendBufferSize(int size) throws SocketException { + getSocketAllowUnknownMode().setSendBufferSize(size); + } + + /** + * See {@link Socket#getSendBufferSize()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getSendBufferSize() throws SocketException { + return getSocketAllowUnknownMode().getSendBufferSize(); + } + + /** + * See {@link Socket#setReceiveBufferSize(int)}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void setReceiveBufferSize(int size) throws SocketException { + getSocketAllowUnknownMode().setReceiveBufferSize(size); + } + + /** + * See {@link Socket#getReceiveBufferSize()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized int getReceiveBufferSize() throws SocketException { + return getSocketAllowUnknownMode().getReceiveBufferSize(); + } + + /** + * See {@link Socket#setKeepAlive(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setKeepAlive(boolean on) throws SocketException { + getSocketAllowUnknownMode().setKeepAlive(on); + } + + /** + * See {@link Socket#getKeepAlive()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getKeepAlive() throws SocketException { + return getSocketAllowUnknownMode().getKeepAlive(); + } + + /** + * See {@link Socket#setTrafficClass(int)}. Calling this method does not trigger mode detection. + */ + @Override + public void setTrafficClass(int tc) throws SocketException { + getSocketAllowUnknownMode().setTrafficClass(tc); + } + + /** + * See {@link Socket#getTrafficClass()}. Calling this method does not trigger mode detection. + */ + @Override + public int getTrafficClass() throws SocketException { + return getSocketAllowUnknownMode().getTrafficClass(); + } + + /** + * See {@link Socket#setReuseAddress(boolean)}. Calling this method does not trigger mode detection. + */ + @Override + public void setReuseAddress(boolean on) throws SocketException { + getSocketAllowUnknownMode().setReuseAddress(on); + } + + /** + * See {@link Socket#getReuseAddress()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean getReuseAddress() throws SocketException { + return getSocketAllowUnknownMode().getReuseAddress(); + } + + /** + * See {@link Socket#close()}. Calling this method does not trigger mode detection. + */ + @Override + public synchronized void close() throws IOException { + getSocketAllowUnknownMode().close(); + } + + /** + * See {@link Socket#shutdownInput()}. Calling this method does not trigger mode detection. + */ + @Override + public void shutdownInput() throws IOException { + getSocketAllowUnknownMode().shutdownInput(); + } + + /** + * See {@link Socket#shutdownOutput()}. Calling this method does not trigger mode detection. + */ + @Override + public void shutdownOutput() throws IOException { + getSocketAllowUnknownMode().shutdownOutput(); + } + + /** + * See {@link Socket#toString()}. Calling this method does not trigger mode detection. + */ + @Override + public String toString() { + return "UnifiedSocket[mode=" + mode.toString() + "socket=" + getSocketAllowUnknownMode().toString() + "]"; + } + + /** + * See {@link Socket#isConnected()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isConnected() { + return getSocketAllowUnknownMode().isConnected(); + } + + /** + * See {@link Socket#isBound()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isBound() { + return getSocketAllowUnknownMode().isBound(); + } + + /** + * See {@link Socket#isClosed()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isClosed() { + return getSocketAllowUnknownMode().isClosed(); + } + + /** + * See {@link Socket#isInputShutdown()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isInputShutdown() { + return getSocketAllowUnknownMode().isInputShutdown(); + } + + /** + * See {@link Socket#isOutputShutdown()}. Calling this method does not trigger mode detection. + */ + @Override + public boolean isOutputShutdown() { + return getSocketAllowUnknownMode().isOutputShutdown(); + } + + /** + * See {@link Socket#setPerformancePreferences(int, int, int)}. Calling this method does not trigger + * mode detection. + */ + @Override + public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + getSocketAllowUnknownMode().setPerformancePreferences(connectionTime, latency, bandwidth); + } + } + + /** + * An input stream for a UnifiedSocket. The first read from this stream will trigger mode detection on the + * underlying UnifiedSocket. + */ + private static class UnifiedInputStream extends InputStream { + private final UnifiedSocket unifiedSocket; + private InputStream realInputStream; + + private UnifiedInputStream(UnifiedSocket unifiedSocket) { + this.unifiedSocket = unifiedSocket; + this.realInputStream = null; + } + + @Override + public int read() throws IOException { + return getRealInputStream().read(); + } + + /** + * Note: SocketInputStream has optimized implementations of bulk-read operations, so we need to call them + * directly instead of relying on the base-class implementation which just calls the single-byte read() over + * and over. Not implementing these results in awful performance. + */ + @Override + public int read(byte[] b) throws IOException { + return getRealInputStream().read(b); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return getRealInputStream().read(b, off, len); + } + + private InputStream getRealInputStream() throws IOException { + if (realInputStream == null) { + // Note: The first call to getSocket() triggers mode detection which can block + realInputStream = unifiedSocket.getSocket().getInputStream(); + } + return realInputStream; + } + + @Override + public long skip(long n) throws IOException { + return getRealInputStream().skip(n); + } + + @Override + public int available() throws IOException { + return getRealInputStream().available(); + } + + @Override + public void close() throws IOException { + getRealInputStream().close(); + } + + @Override + public synchronized void mark(int readlimit) { + try { + getRealInputStream().mark(readlimit); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public synchronized void reset() throws IOException { + getRealInputStream().reset(); + } + + @Override + public boolean markSupported() { + try { + return getRealInputStream().markSupported(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + } + + private static class UnifiedOutputStream extends OutputStream { + private final UnifiedSocket unifiedSocket; + private OutputStream realOutputStream; + + private UnifiedOutputStream(UnifiedSocket unifiedSocket) { + this.unifiedSocket = unifiedSocket; + this.realOutputStream = null; + } + + @Override + public void write(int b) throws IOException { + getRealOutputStream().write(b); + } + + @Override + public void write(byte[] b) throws IOException { + getRealOutputStream().write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + getRealOutputStream().write(b, off, len); + } + + @Override + public void flush() throws IOException { + getRealOutputStream().flush(); + } + + @Override + public void close() throws IOException { + getRealOutputStream().close(); + } + + private OutputStream getRealOutputStream() throws IOException { + if (realOutputStream == null) { + // Note: The first call to getSocket() triggers mode detection which can block + realOutputStream = unifiedSocket.getSocket().getOutputStream(); + } + return realOutputStream; + } + } -} \ No newline at end of file +} diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java index a9350434e44..ec0e6a955c6 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/X509UtilTest.java @@ -402,6 +402,34 @@ public void testLoadJKSTrustStoreWithWrongPassword() throws Exception { true); } + @Test + public void testGetSslHandshakeDetectionTimeoutMillisProperty() { + X509Util x509Util = new ClientX509Util(); + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + x509Util.getSslHandshakeTimeoutMillis()); + try { + String newPropertyString = Integer.toString(X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1); + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), newPropertyString); + // Note: need to create a new ClientX509Util to pick up modified property value + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS + 1, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + // 0 value not allowed, will return the default + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "0"); + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + // Negative value not allowed, will return the default + System.setProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty(), "-1"); + Assert.assertEquals( + X509Util.DEFAULT_HANDSHAKE_DETECTION_TIMEOUT_MILLIS, + new ClientX509Util().getSslHandshakeTimeoutMillis()); + } finally { + System.clearProperty(x509Util.getSslHandshakeDetectionTimeoutMillisProperty()); + } + } + // Warning: this will reset the x509Util private void setCustomCipherSuites() { System.setProperty(x509Util.getCipherSuitesProperty(), customCipherSuites[0] + "," + customCipherSuites[1]); diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java index b088f47b16c..67c15ade2c3 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/QuorumSSLTest.java @@ -80,7 +80,6 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -442,7 +441,6 @@ public void testQuorumSSL() throws Exception { Assert.assertFalse(ClientBase.waitForServerUp("127.0.0.1:" + clientPortQp3, CONNECTION_TIMEOUT)); } - @Ignore("portUnification is currently broken and disabled") @Test public void testRollingUpgrade() throws Exception { // Form a quorum without ssl diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java index 09a5d41260a..4f0244edb75 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java @@ -20,153 +20,604 @@ import org.apache.zookeeper.PortAssignment; import org.apache.zookeeper.client.ZKClientConfig; import org.apache.zookeeper.common.ClientX509Util; -import org.apache.zookeeper.common.Time; +import org.apache.zookeeper.common.X509Exception; +import org.apache.zookeeper.common.X509KeyType; +import org.apache.zookeeper.common.X509TestContext; import org.apache.zookeeper.common.X509Util; import org.apache.zookeeper.server.ServerCnxnFactory; +import org.apache.zookeeper.test.ClientBase; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import javax.net.ssl.HandshakeCompletedEvent; import javax.net.ssl.HandshakeCompletedListener; import javax.net.ssl.SSLSocket; +import java.io.BufferedInputStream; +import java.io.File; import java.io.IOException; import java.net.ConnectException; import java.net.InetSocketAddress; import java.net.Socket; +import java.net.SocketException; +import java.security.Security; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.junit.Assert.assertThat; - +@RunWith(Parameterized.class) public class UnifiedServerSocketTest { + @Parameterized.Parameters + public static Collection params() { + ArrayList result = new ArrayList<>(); + int paramIndex = 0; + for (X509KeyType caKeyType : X509KeyType.values()) { + for (X509KeyType certKeyType : X509KeyType.values()) { + for (Boolean hostnameVerification : new Boolean[] { true, false }) { + result.add(new Object[]{ + caKeyType, + certKeyType, + hostnameVerification, + paramIndex++ + }); + } + } + } + return result; + } + + /** + * Because key generation and writing / deleting files is kind of expensive, we cache the certs and on-disk files + * between test cases. None of the test cases modify any of this data so it's safe to reuse between tests. This + * caching makes all test cases after the first one for a given parameter combination complete almost instantly. + */ + private static Map cachedTestContexts; + private static File tempDir; + private static final int MAX_RETRIES = 5; private static final int TIMEOUT = 1000; + private static final byte[] DATA_TO_CLIENT = "hello client".getBytes(); + private static final byte[] DATA_FROM_CLIENT = "hello server".getBytes(); private X509Util x509Util; private int port; + private InetSocketAddress localServerAddress; private volatile boolean handshakeCompleted; + private X509TestContext x509TestContext; + + @BeforeClass + public static void setUpClass() throws Exception { + Security.addProvider(new BouncyCastleProvider()); + cachedTestContexts = new HashMap<>(); + tempDir = ClientBase.createEmptyTestDir(); + } + + @AfterClass + public static void cleanUpClass() { + Security.removeProvider(BouncyCastleProvider.PROVIDER_NAME); + cachedTestContexts.clear(); + cachedTestContexts = null; + } + + public UnifiedServerSocketTest( + X509KeyType caKeyType, + X509KeyType certKeyType, + Boolean hostnameVerification, + Integer paramIndex) throws Exception { + if (cachedTestContexts.containsKey(paramIndex)) { + x509TestContext = cachedTestContexts.get(paramIndex); + } else { + x509TestContext = X509TestContext.newBuilder() + .setTempDir(tempDir) + .setKeyStoreKeyType(certKeyType) + .setTrustStoreKeyType(caKeyType) + .setHostnameVerification(hostnameVerification) + .build(); + cachedTestContexts.put(paramIndex, x509TestContext); + } + } @Before public void setUp() throws Exception { handshakeCompleted = false; port = PortAssignment.unique(); + localServerAddress = new InetSocketAddress("localhost", port); - String testDataPath = System.getProperty("test.data.dir", "build/test/data"); System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, "org.apache.zookeeper.server.NettyServerCnxnFactory"); System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, "org.apache.zookeeper.ClientCnxnSocketNetty"); System.setProperty(ZKClientConfig.SECURE_CLIENT, "true"); x509Util = new ClientX509Util(); - System.setProperty(x509Util.getSslKeystoreLocationProperty(), testDataPath + "/ssl/testKeyStore.jks"); - System.setProperty(x509Util.getSslKeystorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslTruststoreLocationProperty(), testDataPath + "/ssl/testTrustStore.jks"); - System.setProperty(x509Util.getSslTruststorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(), "false"); + x509TestContext.setSystemProperties(x509Util, X509Util.StoreFileType.JKS, X509Util.StoreFileType.JKS); } - @Test - public void testConnectWithSSL() throws Exception { - class ServerThread extends Thread { - public void run() { - try { - Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept(); - ((SSLSocket)unifiedSocket).getSession(); // block until handshake completes - } catch (IOException e) { - e.printStackTrace(); + private static void forceClose(java.io.Closeable s) { + if (s == null) { + return; + } + try { + s.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private static final class UnifiedServerThread extends Thread { + private final byte[] dataToClient; + private List dataFromClients; + private List workerThreads; + private UnifiedServerSocket serverSocket; + + UnifiedServerThread(X509Util x509Util, + InetSocketAddress bindAddress, + boolean allowInsecureConnection, + byte[] dataToClient) throws IOException { + this.dataToClient = dataToClient; + dataFromClients = new ArrayList<>(); + workerThreads = new ArrayList<>(); + serverSocket = new UnifiedServerSocket(x509Util, allowInsecureConnection); + serverSocket.bind(bindAddress); + } + + @Override + public void run() { + try { + Random rnd = new Random(); + while (true) { + final Socket unifiedSocket = serverSocket.accept(); + final boolean tcpNoDelay = rnd.nextBoolean(); + unifiedSocket.setTcpNoDelay(tcpNoDelay); + unifiedSocket.setSoTimeout(TIMEOUT); + final boolean keepAlive = rnd.nextBoolean(); + unifiedSocket.setKeepAlive(keepAlive); + // Note: getting the input stream should not block the thread or trigger mode detection. + BufferedInputStream bis = new BufferedInputStream(unifiedSocket.getInputStream()); + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + byte[] buf = new byte[1024]; + int bytesRead = unifiedSocket.getInputStream().read(buf, 0, 1024); + // Make sure the settings applied above before the socket was potentially upgraded to + // TLS still apply. + Assert.assertEquals(tcpNoDelay, unifiedSocket.getTcpNoDelay()); + Assert.assertEquals(TIMEOUT, unifiedSocket.getSoTimeout()); + Assert.assertEquals(keepAlive, unifiedSocket.getKeepAlive()); + if (bytesRead > 0) { + byte[] dataFromClient = new byte[bytesRead]; + System.arraycopy(buf, 0, dataFromClient, 0, bytesRead); + synchronized (dataFromClients) { + dataFromClients.add(dataFromClient); + } + } + unifiedSocket.getOutputStream().write(dataToClient); + unifiedSocket.getOutputStream().flush(); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } finally { + forceClose(unifiedSocket); + } + } + }); + workerThreads.add(t); + t.start(); } + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } finally { + forceClose(serverSocket); } } - ServerThread serverThread = new ServerThread(); - serverThread.start(); + public void shutdown(long millis) throws InterruptedException { + forceClose(serverSocket); // this should break the run() loop + for (Thread t : workerThreads) { + t.join(millis); + } + this.join(millis); + } + + synchronized byte[] getDataFromClient(int index) { + return dataFromClients.get(index); + } + } + + private SSLSocket connectWithSSL() throws IOException, X509Exception, InterruptedException { SSLSocket sslSocket = null; int retries = 0; while (retries < MAX_RETRIES) { try { sslSocket = x509Util.createSSLSocket(); + sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { + @Override + public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { + handshakeCompleted = true; + } + }); sslSocket.setSoTimeout(TIMEOUT); - sslSocket.connect(new InetSocketAddress(port), TIMEOUT); + sslSocket.connect(localServerAddress, TIMEOUT); break; } catch (ConnectException connectException) { connectException.printStackTrace(); + forceClose(sslSocket); + sslSocket = null; Thread.sleep(TIMEOUT); } retries++; } - sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { - @Override - public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { - completeHandshake(); + Assert.assertNotNull("Failed to connect to server with SSL", sslSocket); + return sslSocket; + } + + private Socket connectWithoutSSL() throws IOException, InterruptedException { + Socket socket = null; + int retries = 0; + while (retries < MAX_RETRIES) { + try { + socket = new Socket(); + socket.setSoTimeout(TIMEOUT); + socket.connect(localServerAddress, TIMEOUT); + break; + } catch (ConnectException connectException) { + connectException.printStackTrace(); + forceClose(socket); + socket = null; + Thread.sleep(TIMEOUT); } - }); - sslSocket.startHandshake(); + retries++; + } + Assert.assertNotNull("Failed to connect to server without SSL", socket); + return socket; + } + + @Test + public void testConnectWithSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); - serverThread.join(TIMEOUT); + Assert.assertTrue(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } + + @Test + public void testConnectWithSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); + + Assert.assertTrue(handshakeCompleted); + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } - long start = Time.currentElapsedTime(); - while (Time.currentElapsedTime() < start + TIMEOUT) { - if (handshakeCompleted) { - return; + @Test + public void testConnectWithoutSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(socket); + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } + + @Test + public void testConnectWithoutSSLToNonStrictServerPartialWrite() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + // Write only 2 bytes of the message, wait a bit, then write the rest. + // This makes sure that writes smaller than 5 bytes don't break the plaintext mode on the server + // once it decides that the input doesn't look like a TLS handshake. + socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2); + socket.getOutputStream().flush(); + Thread.sleep(TIMEOUT / 2); + socket.getOutputStream().write(DATA_FROM_CLIENT, 2, DATA_FROM_CLIENT.length - 2); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(socket); + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } + + @Test + public void testConnectWithoutSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + try { + socket.getInputStream().read(buf, 0, buf.length); + } catch (SocketException e) { + // We expect the other end to hang up the connection + serverThread.shutdown(TIMEOUT); + forceClose(socket); + return; + } + Assert.fail("Expected server to hang up the connection. Read from server succeeded unexpectedly."); + } + + /** + * This test makes sure that UnifiedServerSocket used properly (a single thread accept()-ing connections and + * handing the resulting sockets to other threads for processing) is not vulnerable to a simple denial-of-service + * attack in which a client connects and never writes any bytes. This should not block the accepting thread, since + * the read to determine if the client is sending a TLS handshake or not happens in the processing thread. + * + * This version of the test uses a non-strict server socket (i.e. it accepts both TLS and plaintext connections). + */ + @Test + public void testDenialOfServiceResistanceNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + final boolean[] dosThreadConnected = new boolean[] { false }; + + Thread dosThread = new Thread(new Runnable() { + @Override + public void run() { + try { + Socket socket = connectWithoutSSL(); + synchronized (dosThreadConnected) { + dosThreadConnected[0] = true; + dosThreadConnected.notifyAll(); + } + Thread.sleep(100000L); + } catch (Exception e) { + // ... + } + } + }); + dosThread.start(); + // make sure the denial-of-service thread connects first + synchronized (dosThreadConnected) { + while (!dosThreadConnected[0]) { + dosThreadConnected.wait(); } } - Assert.fail("failed to complete handshake"); - } + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertFalse(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + forceClose(socket); + + socket = connectWithSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + buf = new byte[DATA_TO_CLIENT.length]; + bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertTrue(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1)); + forceClose(socket); - private void completeHandshake() { - handshakeCompleted = true; + serverThread.shutdown(TIMEOUT); + dosThread.interrupt(); + dosThread.join(TIMEOUT); } + /** + * Like the above test, but with a strict server socket (closes non-TLS connections after seeing that there is + * no handshake). + */ @Test - public void testConnectWithoutSSL() throws Exception { - final byte[] testData = "hello there".getBytes(); - final String[] dataReadFromClient = {null}; + public void testDenialOfServiceResistanceStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + final boolean[] dosThreadConnected = new boolean[] { false }; - class ServerThread extends Thread { + Thread dosThread = new Thread(new Runnable() { + @Override public void run() { try { - Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept(); - unifiedSocket.getOutputStream().write(testData); - unifiedSocket.getOutputStream().flush(); - byte[] inputbuff = new byte[5]; - unifiedSocket.getInputStream().read(inputbuff, 0, 5); - dataReadFromClient[0] = new String(inputbuff); - } catch (IOException e) { - e.printStackTrace(); + Socket socket = connectWithoutSSL(); + synchronized (dosThreadConnected) { + dosThreadConnected[0] = true; + dosThreadConnected.notifyAll(); + } + Thread.sleep(100000L); + } catch (Exception e) { + // ... } } + }); + dosThread.start(); + // make sure the denial-of-service thread connects first + synchronized (dosThreadConnected) { + while (!dosThreadConnected[0]) { + dosThreadConnected.wait(); + } } - ServerThread serverThread = new ServerThread(); + + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); + + Assert.assertTrue(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + dosThread.interrupt(); + dosThread.join(TIMEOUT); + } + + /** + * Similar to the DoS resistance tests above, but the bad client disconnects immediately without sending any data. + * @throws Exception + */ + @Test + public void testImmediateDisconnectResistanceNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); serverThread.start(); + final boolean[] disconnectThreadConnected = new boolean[] { false }; - Socket socket = null; - int retries = 0; - while (retries < MAX_RETRIES) { - try { - socket = new Socket(); - socket.setSoTimeout(TIMEOUT); - socket.connect(new InetSocketAddress(port), TIMEOUT); - break; - } catch (ConnectException connectException) { - connectException.printStackTrace(); - Thread.sleep(TIMEOUT); + Thread disconnectThread = new Thread(new Runnable() { + @Override + public void run() { + try { + Socket socket = connectWithoutSSL(); + socket.close(); + synchronized (disconnectThreadConnected) { + disconnectThreadConnected[0] = true; + disconnectThreadConnected.notifyAll(); + } + } catch (Exception e) { + // ... + } + } + }); + disconnectThread.start(); + // make sure the disconnect thread connects first + synchronized (disconnectThreadConnected) { + while (!disconnectThreadConnected[0]) { + disconnectThreadConnected.wait(); } - retries++; } - socket.getOutputStream().write("hellobello".getBytes()); + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertFalse(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + forceClose(socket); + + socket = connectWithSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); socket.getOutputStream().flush(); + buf = new byte[DATA_TO_CLIENT.length]; + bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertTrue(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1)); + forceClose(socket); + + serverThread.shutdown(TIMEOUT); + disconnectThread.join(TIMEOUT); + } + + /** + * Like the above test, but with a strict server socket (closes non-TLS connections after seeing that there is + * no handshake). + */ + @Test + public void testImmediateDisconnectResistanceStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + final boolean[] disconnectThreadConnected = new boolean[] { false }; + + Thread disconnectThread = new Thread(new Runnable() { + @Override + public void run() { + try { + Socket socket = connectWithoutSSL(); + socket.close(); + synchronized (disconnectThreadConnected) { + disconnectThreadConnected[0] = true; + disconnectThreadConnected.notifyAll(); + } + } catch (Exception e) { + // ... + } + } + }); + disconnectThread.start(); + // make sure the disconnect thread connects first + synchronized (disconnectThreadConnected) { + while (!disconnectThreadConnected[0]) { + disconnectThreadConnected.wait(); + } + } - byte[] readBytes = new byte[testData.length]; - socket.getInputStream().read(readBytes, 0, testData.length); + Socket sslSocket = connectWithSSL(); + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); - serverThread.join(TIMEOUT); + serverThread.shutdown(TIMEOUT); + forceClose(sslSocket); - Assert.assertArrayEquals(testData, readBytes); - assertThat("Data sent by the client is invalid", dataReadFromClient[0], equalTo("hello")); + Assert.assertTrue(handshakeCompleted); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + disconnectThread.interrupt(); + disconnectThread.join(TIMEOUT); } }