Skip to content

Commit

Permalink
Add support for specifying SecureRandom in SSLContext initialization (#…
Browse files Browse the repository at this point in the history
…14058)

Motivation:

Enhance security by supporting specific secure randomness source in
SSLContext initialization

Modification:

Support building SecureRandom in `SslContextBuilder`.
Allow passing SecureRandom as a parameter when creating an instance of
`JdkSslServerContext` through its constructor.

Result:

Enhance security

Fixes #14026

---------

Co-authored-by: Norman Maurer <[email protected]>
  • Loading branch information
thxwelchs and normanmaurer authored May 22, 2024
1 parent 243de91 commit ebf0e41
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import javax.net.ssl.TrustManagerFactory;
import java.io.File;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;

/**
Expand Down Expand Up @@ -175,7 +176,7 @@ public JdkSslClientContext(
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(newSSLContext(provider, toX509CertificatesInternal(trustCertCollectionFile),
trustManagerFactory, null, null,
null, null, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), true,
null, null, sessionCacheSize, sessionTimeout, null, KeyStore.getDefaultType()), true,
ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
}

Expand Down Expand Up @@ -258,7 +259,8 @@ public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory tru
super(newSSLContext(null, toX509CertificatesInternal(
trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), true,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout,
null, KeyStore.getDefaultType()), true,
ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
}

Expand All @@ -267,11 +269,11 @@ public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory tru
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, String[] protocols, long sessionCacheSize, long sessionTimeout,
String keyStoreType)
SecureRandom secureRandom, String keyStoreType)
throws SSLException {
super(newSSLContext(sslContextProvider, trustCertCollection, trustManagerFactory,
keyCertChain, key, keyPassword, keyManagerFactory, sessionCacheSize,
sessionTimeout, keyStoreType),
sessionTimeout, secureRandom, keyStoreType),
true, ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE, protocols, false);
}

Expand All @@ -280,7 +282,7 @@ private static SSLContext newSSLContext(Provider sslContextProvider,
TrustManagerFactory trustManagerFactory, X509Certificate[] keyCertChain,
PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
long sessionCacheSize, long sessionTimeout,
String keyStore) throws SSLException {
SecureRandom secureRandom, String keyStore) throws SSLException {
try {
if (trustCertCollection != null) {
trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory, keyStore);
Expand All @@ -293,7 +295,7 @@ private static SSLContext newSSLContext(Provider sslContextProvider,
: SSLContext.getInstance(PROTOCOL, sslContextProvider);
ctx.init(keyManagerFactory == null ? null : keyManagerFactory.getKeyManagers(),
trustManagerFactory == null ? null : trustManagerFactory.getTrustManagers(),
null);
secureRandom);

SSLSessionContext sessCtx = ctx.getClientSessionContext();
if (sessionCacheSize > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import javax.net.ssl.X509ExtendedTrustManager;
import java.io.File;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
Expand Down Expand Up @@ -207,7 +208,7 @@ public JdkSslServerContext(
long sessionCacheSize, long sessionTimeout, String keyStore) throws SSLException {
super(newSSLContext(provider, null, null,
toX509CertificatesInternal(certChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, null, sessionCacheSize, sessionTimeout, keyStore), false,
keyPassword, null, sessionCacheSize, sessionTimeout, null, keyStore), false,
ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
}

Expand Down Expand Up @@ -247,7 +248,7 @@ public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory tru
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(newSSLContext(null, toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, null), false,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, null, null), false,
ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
}

Expand Down Expand Up @@ -288,7 +289,8 @@ public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory tru
long sessionCacheSize, long sessionTimeout) throws SSLException {
super(newSSLContext(null, toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), false,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout,
null, KeyStore.getDefaultType()), false,
ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
}

Expand All @@ -298,16 +300,17 @@ public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory tru
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout,
ClientAuth clientAuth, String[] protocols, boolean startTls,
String keyStore) throws SSLException {
SecureRandom secureRandom, String keyStore) throws SSLException {
super(newSSLContext(provider, trustCertCollection, trustManagerFactory, keyCertChain, key,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, keyStore), false,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, secureRandom, keyStore), false,
ciphers, cipherFilter, toNegotiator(apn, true), clientAuth, protocols, startTls);
}

private static SSLContext newSSLContext(Provider sslContextProvider, X509Certificate[] trustCertCollection,
TrustManagerFactory trustManagerFactory, X509Certificate[] keyCertChain,
PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
long sessionCacheSize, long sessionTimeout, String keyStore)
TrustManagerFactory trustManagerFactory, X509Certificate[] keyCertChain,
PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
long sessionCacheSize, long sessionTimeout,
SecureRandom secureRandom, String keyStore)
throws SSLException {
if (key == null && keyManagerFactory == null) {
throw new NullPointerException("key, keyManagerFactory");
Expand All @@ -333,7 +336,7 @@ private static SSLContext newSSLContext(Provider sslContextProvider, X509Certifi
: SSLContext.getInstance(PROTOCOL, sslContextProvider);
ctx.init(keyManagerFactory.getKeyManagers(),
wrapTrustManagerIfNeeded(trustManagerFactory.getTrustManagers()),
null);
secureRandom);

SSLSessionContext sessCtx = ctx.getServerSessionContext();
if (sessionCacheSize > 0) {
Expand Down
15 changes: 9 additions & 6 deletions handler/src/main/java/io/netty/handler/ssl/SslContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
Expand Down Expand Up @@ -441,7 +442,7 @@ trustManagerFactory, toX509Certificates(keyCertChainFile),
toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, apn,
sessionCacheSize, sessionTimeout, ClientAuth.NONE, null,
false, false, keyStore);
false, false, null, keyStore);
} catch (Exception e) {
if (e instanceof SSLException) {
throw (SSLException) e;
Expand All @@ -457,7 +458,8 @@ static SslContext newServerContextInternal(
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls,
boolean enableOcsp, String keyStoreType, Map.Entry<SslContextOption<?>, Object>... ctxOptions)
boolean enableOcsp, SecureRandom secureRandom, String keyStoreType,
Map.Entry<SslContextOption<?>, Object>... ctxOptions)
throws SSLException {

if (provider == null) {
Expand All @@ -472,7 +474,7 @@ static SslContext newServerContextInternal(
return new JdkSslServerContext(sslContextProvider,
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth, protocols, startTls, keyStoreType);
clientAuth, protocols, startTls, secureRandom, keyStoreType);
case OPENSSL:
verifyNullSslContextProvider(provider, sslContextProvider);
return new OpenSslServerContext(
Expand Down Expand Up @@ -801,7 +803,7 @@ public static SslContext newClientContext(
toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter,
apn, null, sessionCacheSize, sessionTimeout, false,
KeyStore.getDefaultType());
null, KeyStore.getDefaultType());
} catch (Exception e) {
if (e instanceof SSLException) {
throw (SSLException) e;
Expand All @@ -816,7 +818,8 @@ static SslContext newClientContextInternal(
X509Certificate[] trustCert, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, String[] protocols,
long sessionCacheSize, long sessionTimeout, boolean enableOcsp, String keyStoreType,
long sessionCacheSize, long sessionTimeout, boolean enableOcsp,
SecureRandom secureRandom, String keyStoreType,
Map.Entry<SslContextOption<?>, Object>... options) throws SSLException {
if (provider == null) {
provider = defaultClientProvider();
Expand All @@ -829,7 +832,7 @@ static SslContext newClientContextInternal(
return new JdkSslClientContext(sslContextProvider,
trustCert, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize,
sessionTimeout, keyStoreType);
sessionTimeout, secureRandom, keyStoreType);
case OPENSSL:
verifyNullSslContextProvider(provider, sslContextProvider);
OpenSsl.ensureAvailability();
Expand Down
22 changes: 20 additions & 2 deletions handler/src/main/java/io/netty/handler/ssl/SslContextBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -206,6 +207,7 @@ public static SslContextBuilder forServer(KeyManager keyManager) {
private String[] protocols;
private boolean startTls;
private boolean enableOcsp;
private SecureRandom secureRandom;
private String keyStoreType = KeyStore.getDefaultType();
private final Map<SslContextOption<?>, Object> options = new HashMap<SslContextOption<?>, Object>();

Expand Down Expand Up @@ -600,6 +602,21 @@ public SslContextBuilder enableOcsp(boolean enableOcsp) {
return this;
}

/**
* Specify a non-default source of randomness for the {@link JdkSslContext}
* <p>
* In general, the best practice is to leave this unspecified, or to assign a new random source using the
* default {@code new SecureRandom()} constructor.
* Only assign this something when you have a good reason to.
*
* @param secureRandom the source of randomness for {@link JdkSslContext}
*
*/
public SslContextBuilder secureRandom(SecureRandom secureRandom) {
this.secureRandom = secureRandom;
return this;
}

/**
* Create new {@code SslContext} instance with configured settings.
* <p>If {@link #sslProvider(SslProvider)} is set to {@link SslProvider#OPENSSL_REFCNT} then the caller is
Expand All @@ -610,11 +627,12 @@ public SslContext build() throws SSLException {
return SslContext.newServerContextInternal(provider, sslContextProvider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls,
enableOcsp, keyStoreType, toArray(options.entrySet(), EMPTY_ENTRIES));
enableOcsp, secureRandom, keyStoreType, toArray(options.entrySet(), EMPTY_ENTRIES));
} else {
return SslContext.newClientContextInternal(provider, sslContextProvider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout, enableOcsp, keyStoreType,
ciphers, cipherFilter, apn, protocols, sessionCacheSize,
sessionTimeout, enableOcsp, secureRandom, keyStoreType,
toArray(options.entrySet(), EMPTY_ENTRIES));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.net.Socket;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Collections;
Expand Down Expand Up @@ -244,6 +245,16 @@ public void testInvalidCipherOpenSSL() throws Exception {
}
}

@Test
public void testServerContextWithSecureRandom() throws Exception {
testServerContextWithSecureRandom(SslProvider.JDK, new SpySecureRandom());
}

@Test
public void testClientContextWithSecureRandom() throws Exception {
testClientContextWithSecureRandom(SslProvider.JDK, new SpySecureRandom());
}

private static void testKeyStoreType(SslProvider provider) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forServer(cert.certificate(), cert.privateKey())
Expand Down Expand Up @@ -326,6 +337,41 @@ private static void testServerContext(SslProvider provider) throws Exception {
engine.closeOutbound();
}

private static void testServerContextWithSecureRandom(SslProvider provider,
SpySecureRandom secureRandom) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forServer(cert.key(), cert.cert())
.sslProvider(provider)
.secureRandom(secureRandom)
.trustManager(cert.cert())
.clientAuth(ClientAuth.REQUIRE);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertFalse(engine.getWantClientAuth());
assertTrue(engine.getNeedClientAuth());
assertTrue(secureRandom.getCount() > 0);
engine.closeInbound();
engine.closeOutbound();
}

private static void testClientContextWithSecureRandom(SslProvider provider,
SpySecureRandom secureRandom) throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();
SslContextBuilder builder = SslContextBuilder.forClient()
.sslProvider(provider)
.secureRandom(secureRandom)
.keyManager(cert.key(), cert.cert())
.trustManager(cert.cert())
.clientAuth(ClientAuth.OPTIONAL);
SslContext context = builder.build();
SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT);
assertFalse(engine.getWantClientAuth());
assertFalse(engine.getNeedClientAuth());
assertTrue(secureRandom.getCount() > 0);
engine.closeInbound();
engine.closeOutbound();
}

private static void testContextFromManagers(SslProvider provider) throws Exception {
final SelfSignedCertificate cert = new SelfSignedCertificate();
KeyManager customKeyManager = new X509ExtendedKeyManager() {
Expand Down Expand Up @@ -425,4 +471,54 @@ public X509Certificate[] getAcceptedIssuers() {
server_engine.closeInbound();
server_engine.closeOutbound();
}

private static final class SpySecureRandom extends SecureRandom {
private int count;

@Override
public int nextInt() {
count++;
return super.nextInt();
}

@Override
public int nextInt(int bound) {
count++;
return super.nextInt(bound);
}

@Override
public long nextLong() {
count++;
return super.nextLong();
}

@Override
public boolean nextBoolean() {
count++;
return super.nextBoolean();
}

@Override
public float nextFloat() {
count++;
return super.nextFloat();
}

@Override
public double nextDouble() {
count++;
return super.nextDouble();
}

@Override
public double nextGaussian() {
count++;
return super.nextGaussian();
}

public int getCount() {
return count;
}
}
}

0 comments on commit ebf0e41

Please sign in to comment.