diff --git a/ssl/src/main/java/org/wildfly/security/ssl/JDKSpecific.java b/ssl/src/main/java/org/wildfly/security/ssl/JDKSpecific.java index ba8f628fdcb..b3b0ababdb0 100644 --- a/ssl/src/main/java/org/wildfly/security/ssl/JDKSpecific.java +++ b/ssl/src/main/java/org/wildfly/security/ssl/JDKSpecific.java @@ -18,6 +18,10 @@ package org.wildfly.security.ssl; +import static org.wildfly.security.ssl.ElytronMessages.tls; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; @@ -27,23 +31,88 @@ final class JDKSpecific { + // SSLEngine Methods + + private static final Method SSLENGINE_GET_APPLICATION_PROTOCOL = getMethodOrNull(SSLEngine.class, "getApplicationProtocol"); + private static final Method SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL = getMethodOrNull(SSLEngine.class, "getHandshakeApplicationProtocol"); + private static final Method SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLEngine.class, "setHandshakeApplicationProtocolSelector", BiFunction.class); + private static final Method SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLEngine.class, "getHandshakeApplicationProtocolSelector"); + + // SSLParameters Methods + + private static final Method SSLPARAMETERS_GET_APPLICATION_PROTOCOLS = getMethodOrNull(SSLParameters.class, "getApplicationProtocols"); + private static final Method SSLPARAMETERS_SET_APPLICATION_PROTOCOLS = getMethodOrNull(SSLParameters.class, "setApplicationProtocols", String[].class); + + // SSLSocket Methods + + private static final Method SSLSOCKET_GET_APPLICATION_PROTOCOL = getMethodOrNull(SSLSocket.class, "getApplicationProtocol"); + private static final Method SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL = getMethodOrNull(SSLSocket.class, "getHandshakeApplicationProtocol"); + private static final Method SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLSocket.class, "setHandshakeApplicationProtocolSelector", BiFunction.class); + private static final Method SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLSocket.class, "getHandshakeApplicationProtocolSelector"); + + private static Method getMethodOrNull(Class clazz, String methodName, Class... parameterTypes) { + try { + return clazz.getMethod(methodName, parameterTypes); + } catch (Exception e) { + if (tls.isTraceEnabled()) { + tls.tracef(e, "Unable to getMethod %s on class %s", methodName, clazz.getName()); + } else if (tls.isDebugEnabled()) { + tls.debugf("Unable to getMethod %s on class %s", methodName, clazz.getName()); + } + + return null; + } + } + /* * SSLEngine */ static String getApplicationProtocol(SSLEngine sslEngine) { + if (SSLENGINE_GET_APPLICATION_PROTOCOL != null) { + try { + return (String) SSLENGINE_GET_APPLICATION_PROTOCOL.invoke(sslEngine); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static String getHandshakeApplicationProtocol(SSLEngine sslEngine) { + if (SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL != null) { + try { + return (String) SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(sslEngine); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static void setHandshakeApplicationProtocolSelector(SSLEngine sslEngine, BiFunction, String> selector) { + if (SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) { + try { + SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(sslEngine, selector); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static BiFunction, String> getHandshakeApplicationProtocolSelector(SSLEngine sslEngine) { + if (SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) { + try { + return (BiFunction, String>) SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(sslEngine); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } @@ -52,10 +121,26 @@ static BiFunction, String> getHandshakeApplicationProtoc */ static String[] getApplicationProtocols(SSLParameters parameters) { + if (SSLPARAMETERS_GET_APPLICATION_PROTOCOLS != null) { + try { + return (String[]) SSLPARAMETERS_GET_APPLICATION_PROTOCOLS.invoke(parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static void setApplicationProtocols(SSLParameters parameters, String[] protocols) { + if (SSLPARAMETERS_SET_APPLICATION_PROTOCOLS != null) { + try { + SSLPARAMETERS_SET_APPLICATION_PROTOCOLS.invoke(parameters, (Object) protocols); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } @@ -79,6 +164,13 @@ static SSLParameters setSSLParameters(SSLParameters original) { } else if (original.getNeedClientAuth()) { params.setNeedClientAuth(original.getNeedClientAuth()); } + + try { + if (SSLPARAMETERS_GET_APPLICATION_PROTOCOLS != null && SSLPARAMETERS_SET_APPLICATION_PROTOCOLS != null) { + setApplicationProtocols(params, getApplicationProtocols(original)); + } + } catch (Exception ignored) {} + return params; } @@ -87,18 +179,50 @@ static SSLParameters setSSLParameters(SSLParameters original) { */ static String getApplicationProtocol(SSLSocket socket) { + if (SSLSOCKET_GET_APPLICATION_PROTOCOL != null) { + try { + return (String) SSLSOCKET_GET_APPLICATION_PROTOCOL.invoke(socket); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static String getHandshakeApplicationProtocol(SSLSocket socket) { + if (SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL != null) { + try { + return (String) SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(socket); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static void setHandshakeApplicationProtocolSelector(SSLSocket socket, BiFunction, String> selector) { + if (SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) { + try { + SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(socket, selector); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } static BiFunction, String> getHandshakeApplicationProtocolSelector(SSLSocket socket) { + if (SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) { + try { + return (BiFunction, String>) SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(socket); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new UnsupportedOperationException(e); + } + } + throw new UnsupportedOperationException(); } diff --git a/ssl/src/main/java/org/wildfly/security/ssl/SSLConfiguratorImpl.java b/ssl/src/main/java/org/wildfly/security/ssl/SSLConfiguratorImpl.java index d23930ddad0..d3bff2dc58e 100644 --- a/ssl/src/main/java/org/wildfly/security/ssl/SSLConfiguratorImpl.java +++ b/ssl/src/main/java/org/wildfly/security/ssl/SSLConfiguratorImpl.java @@ -133,27 +133,27 @@ public void setNeedClientAuth(final SSLContext sslContext, final SSLServerSocket } public void setEnabledCipherSuites(final SSLContext sslContext, final SSLSocket sslSocket, final String[] cipherSuites) { - // ignored + sslSocket.setEnabledCipherSuites(cipherSuiteSelector.evaluate(cipherSuites)); } public void setEnabledCipherSuites(final SSLContext sslContext, final SSLEngine sslEngine, final String[] cipherSuites) { - // ignored + sslEngine.setEnabledCipherSuites(cipherSuiteSelector.evaluate(cipherSuites)); } - public void setEnabledCipherSuites(final SSLContext sslContext, final SSLServerSocket sslServerSocket, final String[] suites) { - // ignored + public void setEnabledCipherSuites(final SSLContext sslContext, final SSLServerSocket sslServerSocket, final String[] cipherSuites) { + sslServerSocket.setEnabledCipherSuites(cipherSuiteSelector.evaluate(cipherSuites)); } public void setEnabledProtocols(final SSLContext sslContext, final SSLSocket sslSocket, final String[] protocols) { - // ignored + sslSocket.setEnabledProtocols(protocolSelector.evaluate(protocols)); } public void setEnabledProtocols(final SSLContext sslContext, final SSLEngine sslEngine, final String[] protocols) { - // ignored + sslEngine.setEnabledProtocols(protocolSelector.evaluate(protocols)); } public void setEnabledProtocols(final SSLContext sslContext, final SSLServerSocket sslServerSocket, final String[] protocols) { - // ignored + sslServerSocket.setEnabledProtocols(protocolSelector.evaluate(protocols)); } private SSLParameters redefine(SSLParameters original) {