Skip to content

Commit

Permalink
Ensure key / values are shared between resumed sessions (netty#13819)
Browse files Browse the repository at this point in the history
Motivation:

When a session is resumed we need to also ensure we preserve the values
that were put into the internal storage of the session.

Modifications:

Ensure we preserve the values / keys of the internal storage on
resumption

Result:

Correctly implement session caching and reuse

---------

Co-authored-by: Chris Vest <[email protected]>
  • Loading branch information
2 people authored and franz1981 committed Feb 9, 2024
1 parent d5c72e2 commit 84e4d01
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,25 @@ public List<byte[]> getStatusResponses() {
return Collections.emptyList();
}

@Override
public void prepareHandshake() {
wrapped.prepareHandshake();
}

@Override
public Map<String, Object> keyValueStorage() {
return wrapped.keyValueStorage();
}

@Override
public OpenSslSessionId sessionId() {
return wrapped.sessionId();
}

@Override
public void setSessionDetails(long creationTime, long lastAccessedTime, OpenSslSessionId id) {
wrapped.setSessionDetails(creationTime, lastAccessedTime, id);
public void setSessionDetails(long creationTime, long lastAccessedTime, OpenSslSessionId id,
Map<String, Object> keyValueStorage) {
wrapped.setSessionDetails(creationTime, lastAccessedTime, id, keyValueStorage);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ protected void sessionRemoved(NativeSslSession session) {
}

@Override
void setSession(long ssl, OpenSslSession session, String host, int port) {
boolean setSession(long ssl, OpenSslSession session, String host, int port) {
HostPort hostPort = keyFor(host, port);
if (hostPort == null) {
return;
return false;
}
final NativeSslSession nativeSslSession;
final boolean reused;
boolean singleUsed = false;
synchronized (this) {
nativeSslSession = sessions.get(hostPort);
if (nativeSslSession == null) {
return;
return false;
}
if (!nativeSslSession.isValid()) {
removeSessionWithId(nativeSslSession.sessionId());
return;
return false;
}
// Try to set the session, if true is returned OpenSSL incremented the reference count
// of the underlying SSL_SESSION*.
Expand All @@ -88,8 +88,9 @@ void setSession(long ssl, OpenSslSession session, String host, int port) {
}
nativeSslSession.setLastAccessedTime(System.currentTimeMillis());
session.setSessionDetails(nativeSslSession.getCreationTime(), nativeSslSession.getLastAccessedTime(),
nativeSslSession.sessionId());
nativeSslSession.sessionId(), nativeSslSession.keyValueStorage);
}
return reused;
}

private static HostPort keyFor(String host, int port) {
Expand Down
33 changes: 31 additions & 2 deletions handler/src/main/java/io/netty/handler/ssl/OpenSslSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.security.cert.Certificate;
import java.util.Map;

/**
* {@link SSLSession} that is specific to our native implementation.
*/
interface OpenSslSession extends SSLSession {

/**
* Called on a handshake session before being exposed to a {@link javax.net.ssl.TrustManager}.
* Session data must be cleared by this call.
*/
void prepareHandshake();

/**
* Return the {@link OpenSslSessionId} that can be used to identify this session.
*/
Expand All @@ -36,9 +43,31 @@ interface OpenSslSession extends SSLSession {
void setLocalCertificate(Certificate[] localCertificate);

/**
* Set the {@link OpenSslSessionId} for the {@link OpenSslSession}.
* Set the details for the session which might come from a cache.
*
* @param creationTime the time at which the session was created.
* @param lastAccessedTime the time at which the session was last accessed via the session infrastructure (cache).
* @param id the {@link OpenSslSessionId}
* @param keyValueStorage the key value store. See {@link #keyValueStorage()}.
*/
void setSessionDetails(long creationTime, long lastAccessedTime, OpenSslSessionId id,
Map<String, Object> keyValueStorage);

/**
* Return the underlying {@link Map} that is used by the following methods:
*
* <ul>
* <li>{@link #putValue(String, Object)}</li>
* <li>{@link #removeValue(String)}</li>
* <li>{@link #getValue(String)}</li>
* <li> {@link #getValueNames()}</li>
* </ul>
*
* The {@link Map} must be thread-safe!
*
* @return storage
*/
void setSessionDetails(long creationTime, long lastAccessedTime, OpenSslSessionId id);
Map<String, Object> keyValueStorage();

/**
* Set the last access time which will be returned by {@link #getLastAccessedTime()}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
Expand Down Expand Up @@ -148,10 +147,14 @@ public boolean sessionCreated(long ssl, long sslSession) {
// We couldn't find the engine itself.
return false;
}
OpenSslSession openSslSession = (OpenSslSession) engine.getSession();
// Create the native session that we will put into our cache. We will share the key-value storage
// with the already existing session instance.
NativeSslSession session = new NativeSslSession(sslSession, engine.getPeerHost(), engine.getPeerPort(),
getSessionTimeout() * 1000L);
((OpenSslSession) engine.getSession()).setSessionDetails(
session.creationTime, session.lastAccessedTime, session.sessionId());
getSessionTimeout() * 1000L, openSslSession.keyValueStorage());

openSslSession.setSessionDetails(
session.creationTime, session.lastAccessedTime, session.sessionId(), session.keyValueStorage);
synchronized (this) {
// Mimic what OpenSSL is doing and expunge every 255 new sessions
// See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html
Expand Down Expand Up @@ -209,14 +212,15 @@ public final long getSession(long ssl, byte[] sessionId) {
if (engine != null) {
OpenSslSession sslSession = (OpenSslSession) engine.getSession();
sslSession.setSessionDetails(session.getCreationTime(),
session.getLastAccessedTime(), session.sessionId());
session.getLastAccessedTime(), session.sessionId(), session.keyValueStorage);
}

return session.session();
}

void setSession(long ssl, OpenSslSession session, String host, int port) {
boolean setSession(long ssl, OpenSslSession session, String host, int port) {
// Do nothing by default as this needs special handling for the client side.
return false;
}

/**
Expand Down Expand Up @@ -293,6 +297,9 @@ static final class NativeSslSession implements OpenSslSession {
static final ResourceLeakDetector<NativeSslSession> LEAK_DETECTOR = ResourceLeakDetectorFactory.instance()
.newResourceLeakDetector(NativeSslSession.class);
private final ResourceLeakTracker<NativeSslSession> leakTracker;

final Map<String, Object> keyValueStorage;

private final long session;
private final String peerHost;
private final int peerPort;
Expand All @@ -303,17 +310,30 @@ static final class NativeSslSession implements OpenSslSession {
private volatile boolean valid = true;
private boolean freed;

NativeSslSession(long session, String peerHost, int peerPort, long timeout) {
NativeSslSession(long session, String peerHost, int peerPort, long timeout,
Map<String, Object> keyValueStorage) {
this.session = session;
this.peerHost = peerHost;
this.peerPort = peerPort;
this.timeout = timeout;
this.id = new OpenSslSessionId(io.netty.internal.tcnative.SSLSession.getSessionId(session));
this.keyValueStorage = keyValueStorage;
leakTracker = LEAK_DETECTOR.track(this);
}

@Override
public void setSessionDetails(long creationTime, long lastAccessedTime, OpenSslSessionId id) {
public Map<String, Object> keyValueStorage() {
return keyValueStorage;
}

@Override
public void prepareHandshake() {
throw new UnsupportedOperationException();
}

@Override
public void setSessionDetails(long creationTime, long lastAccessedTime,
OpenSslSessionId id, Map<String, Object> keyValueStorage) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ final boolean isInCache(OpenSslSessionId id) {
return sessionCache.containsSessionWithId(id);
}

void setSessionFromCache(long ssl, OpenSslSession session, String host, int port) {
sessionCache.setSession(ssl, session, host, port);
boolean setSessionFromCache(long ssl, OpenSslSession session, String host, int port) {
return sessionCache.setSession(ssl, session, host, port);
}

final void destroy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;

import javax.crypto.spec.SecretKeySpec;
Expand Down Expand Up @@ -1980,7 +1981,11 @@ private SSLEngineResult.HandshakeStatus handshake() throws SSLException {
engineMap.add(this);

if (!sessionSet) {
parentContext.sessionContext().setSessionFromCache(ssl, session, getPeerHost(), getPeerPort());
if (!parentContext.sessionContext().setSessionFromCache(ssl, session, getPeerHost(), getPeerPort())) {
// The session was not reused via the cache. Call prepareHandshake() to ensure we remove all previous
// stored key-value pairs.
session.prepareHandshake();
}
sessionSet = true;
}

Expand Down Expand Up @@ -2365,7 +2370,7 @@ private final class DefaultOpenSslSession implements OpenSslSession {

private volatile int applicationBufferSize = MAX_PLAINTEXT_LENGTH;
private volatile Certificate[] localCertificateChain;
private Map<String, Object> values;
private volatile Map<String, Object> keyValueStorage = new ConcurrentHashMap<String, Object>();

DefaultOpenSslSession(OpenSslSessionContext sessionContext) {
this.sessionContext = sessionContext;
Expand All @@ -2375,18 +2380,34 @@ private SSLSessionBindingEvent newSSLSessionBindingEvent(String name) {
return new SSLSessionBindingEvent(session, name);
}

@Override
public void prepareHandshake() {
keyValueStorage.clear();
}

@Override
public void setSessionDetails(
long creationTime, long lastAccessedTime, OpenSslSessionId sessionId) {
long creationTime, long lastAccessedTime, OpenSslSessionId sessionId,
Map<String, Object> keyValueStorage) {
synchronized (ReferenceCountedOpenSslEngine.this) {
if (this.id == OpenSslSessionId.NULL_ID) {
this.id = sessionId;
this.creationTime = creationTime;
this.lastAccessed = lastAccessedTime;

// Update the key value storage. It's fine to just drop the previous stored values on the floor
// as the JDK does the same in the sense that it will use a new SSLSessionImpl instance once the
// handshake was done
this.keyValueStorage = keyValueStorage;
}
}
}

@Override
public Map<String, Object> keyValueStorage() {
return keyValueStorage;
}

@Override
public OpenSslSessionId sessionId() {
synchronized (ReferenceCountedOpenSslEngine.this) {
Expand Down Expand Up @@ -2458,16 +2479,7 @@ public void putValue(String name, Object value) {
checkNotNull(name, "name");
checkNotNull(value, "value");

final Object old;
synchronized (this) {
Map<String, Object> values = this.values;
if (values == null) {
// Use size of 2 to keep the memory overhead small
values = this.values = new HashMap<String, Object>(2);
}
old = values.put(name, value);
}

final Object old = keyValueStorage.put(name, value);
if (value instanceof SSLSessionBindingListener) {
// Use newSSLSessionBindingEvent so we always use the wrapper if needed.
((SSLSessionBindingListener) value).valueBound(newSSLSessionBindingEvent(name));
Expand All @@ -2478,39 +2490,19 @@ public void putValue(String name, Object value) {
@Override
public Object getValue(String name) {
checkNotNull(name, "name");
synchronized (this) {
if (values == null) {
return null;
}
return values.get(name);
}
return keyValueStorage.get(name);
}

@Override
public void removeValue(String name) {
checkNotNull(name, "name");

final Object old;
synchronized (this) {
Map<String, Object> values = this.values;
if (values == null) {
return;
}
old = values.remove(name);
}

final Object old = keyValueStorage.remove(name);
notifyUnbound(old, name);
}

@Override
public String[] getValueNames() {
synchronized (this) {
Map<String, Object> values = this.values;
if (values == null || values.isEmpty()) {
return EMPTY_STRINGS;
}
return values.keySet().toArray(EMPTY_STRINGS);
}
return keyValueStorage.keySet().toArray(EMPTY_STRINGS);
}

private void notifyUnbound(Object value, String name) {
Expand All @@ -2532,6 +2524,7 @@ public void handshakeFinished(byte[] id, String cipher, String protocol, byte[]
if (!isDestroyed()) {
if (this.id == OpenSslSessionId.NULL_ID) {
// if the handshake finished and it was not a resumption let ensure we try to set the id

this.id = id == null ? OpenSslSessionId.NULL_ID : new OpenSslSessionId(id);
// Once the handshake was done the lastAccessed and creationTime should be the same if we
// did not set it earlier via setSessionDetails(...)
Expand Down
Loading

0 comments on commit 84e4d01

Please sign in to comment.