Skip to content

Commit

Permalink
[ELY-2026] Use reflection to pass through the new methods if they are…
Browse files Browse the repository at this point in the history
… now available on the SSLEngine, SSLParameters, and SSLSocket APIs.
  • Loading branch information
darranl committed Oct 7, 2020
1 parent 6e8b51c commit f6598bc
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions ssl/src/main/java/org/wildfly/security/ssl/JDKSpecific.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

package org.wildfly.security.ssl;

import static org.wildfly.security.ssl.ElytronMessages.tls;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.function.BiFunction;

Expand All @@ -27,23 +31,88 @@

final class JDKSpecific {

// SSLEngine Methods

private static final Method SSLENGINE_GET_APPLICATION_PROTOCOL = getMethodOrNull(SSLEngine.class, "getApplicationProtocol");
private static final Method SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL = getMethodOrNull(SSLEngine.class, "getHandshakeApplicationProtocol");
private static final Method SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLEngine.class, "setHandshakeApplicationProtocolSelector", BiFunction.class);
private static final Method SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLEngine.class, "getHandshakeApplicationProtocolSelector");

// SSLParameters Methods

private static final Method SSLPARAMETERS_GET_APPLICATION_PROTOCOLS = getMethodOrNull(SSLParameters.class, "getApplicationProtocols");
private static final Method SSLPARAMETERS_SET_APPLICATION_PROTOCOLS = getMethodOrNull(SSLParameters.class, "setApplicationProtocols", String[].class);

// SSLSocket Methods

private static final Method SSLSOCKET_GET_APPLICATION_PROTOCOL = getMethodOrNull(SSLSocket.class, "getApplicationProtocol");
private static final Method SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL = getMethodOrNull(SSLSocket.class, "getHandshakeApplicationProtocol");
private static final Method SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLSocket.class, "setHandshakeApplicationProtocolSelector", BiFunction.class);
private static final Method SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getMethodOrNull(SSLSocket.class, "getHandshakeApplicationProtocolSelector");

private static Method getMethodOrNull(Class clazz, String methodName, Class... parameterTypes) {
try {
return clazz.getMethod(methodName, parameterTypes);
} catch (Exception e) {
if (tls.isTraceEnabled()) {
tls.tracef(e, "Unable to getMethod %s on class %s", methodName, clazz.getName());
} else if (tls.isDebugEnabled()) {
tls.debugf("Unable to getMethod %s on class %s", methodName, clazz.getName());
}

return null;
}
}

/*
* SSLEngine
*/

static String getApplicationProtocol(SSLEngine sslEngine) {
if (SSLENGINE_GET_APPLICATION_PROTOCOL != null) {
try {
return (String) SSLENGINE_GET_APPLICATION_PROTOCOL.invoke(sslEngine);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static String getHandshakeApplicationProtocol(SSLEngine sslEngine) {
if (SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL != null) {
try {
return (String) SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(sslEngine);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static void setHandshakeApplicationProtocolSelector(SSLEngine sslEngine, BiFunction<SSLEngine, List<String>, String> selector) {
if (SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) {
try {
SSLENGINE_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(sslEngine, selector);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static BiFunction<SSLEngine, List<String>, String> getHandshakeApplicationProtocolSelector(SSLEngine sslEngine) {
if (SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) {
try {
return (BiFunction<SSLEngine, List<String>, String>) SSLENGINE_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(sslEngine);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

Expand All @@ -52,10 +121,26 @@ static BiFunction<SSLEngine, List<String>, String> getHandshakeApplicationProtoc
*/

static String[] getApplicationProtocols(SSLParameters parameters) {
if (SSLPARAMETERS_GET_APPLICATION_PROTOCOLS != null) {
try {
return (String[]) SSLPARAMETERS_GET_APPLICATION_PROTOCOLS.invoke(parameters);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static void setApplicationProtocols(SSLParameters parameters, String[] protocols) {
if (SSLPARAMETERS_SET_APPLICATION_PROTOCOLS != null) {
try {
SSLPARAMETERS_SET_APPLICATION_PROTOCOLS.invoke(parameters, (Object[]) protocols); // TODO Check this cast is correct.
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

Expand Down Expand Up @@ -87,18 +172,50 @@ static SSLParameters setSSLParameters(SSLParameters original) {
*/

static String getApplicationProtocol(SSLSocket socket) {
if (SSLSOCKET_GET_APPLICATION_PROTOCOL != null) {
try {
return (String) SSLSOCKET_GET_APPLICATION_PROTOCOL.invoke(socket);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static String getHandshakeApplicationProtocol(SSLSocket socket) {
if (SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL != null) {
try {
return (String) SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(socket);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static void setHandshakeApplicationProtocolSelector(SSLSocket socket, BiFunction<SSLSocket, List<String>, String> selector) {
if (SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) {
try {
SSLSOCKET_SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(socket, selector);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

static BiFunction<SSLSocket, List<String>, String> getHandshakeApplicationProtocolSelector(SSLSocket socket) {
if (SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR != null) {
try {
return (BiFunction<SSLSocket, List<String>, String>) SSLSOCKET_GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(socket);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new UnsupportedOperationException(e);
}
}

throw new UnsupportedOperationException();
}

Expand Down

0 comments on commit f6598bc

Please sign in to comment.