Skip to content

Commit

Permalink
Provide a way to dynamically update TLS certificates (#5228)
Browse files Browse the repository at this point in the history
Motivation:

#5033 

API design note:

- `TlsKeyPair` represents a pair of `PrivateKey` and `X509Certificate`
chain.
  - This API is used as an official key pair type in Armeria.
- All APIs for specifying key pairs in `TlsSetters` have been deprecated
in favor of `TlsSetters tls(TlsKeyPair)`.
    ```java
    TlsKeyPair.of(privateKey, certificate);
    TlsKeyPair.of(privateKey, keyPassword, certificate);
    ```
- `TlsProvider` dynamically resolves a `TlsKeyPair` for the given
hostname when a connection is established.
  - DNS wildcard format is supported as a hostname.
- `*` is used as a special hostname to get the `TlsKeyPair` for the
default virtual host.
    ```java
    TlsProvider
       .builder()
        // Set the default key pair.
       .keyPair(TlsKeyPair.of(...))
       // Set the key pair for "*.example.com".
       .keyPair("*.example.com", TlsKeyPair.of(...))
       .build();
    ```
- To dynamically update/reload `TlsKeyPair`, a custom `TlsProvider` can
be implemented.
    ```java
    class DynamicTlsProvider implements TlsProvider {
       @OverRide
       public TlsKeyPair keyPair(String hostname) {
           // relodableCache will be updated periodically by a scheduler
           return relodableCache.get(hostname);
       }
    }
    ```
- The newly returned key pair is used for the TLS handshake of new
connections.
- `ServerTlsConfig` and `ClientTlsConfig` are added to override the
default values and customize `SslContextBuilder`.
- Unlike `TlsProvider`, `*TlsConfig` are immutable so all `TlsKeyPair`s
returned by a `TlsProvider` build `SslContext` with the same
configuration.
- Both server and client allow `TlsProvider` and `TlsKeyPair` for TLS
configurations.
  ```java
  Server
    .builder()
    // For dynamic usage
    .tlsProvider(tlsProvider)
    // For customizing TLS
    .tlsProvider(tlsProvider, serverTlsConfig)
    // For sample usage
    .tls(tlsKeyPair)
    .build()

  ClientFactory
    .builder()
    // For dynamic usage
    .tlsProvider(tlsProvider)
    // For customizing TLS
    .tlsProvider(tlsProvider, clientTlsConfig)
    // For sample usage
    .tls(tlsKeyPair)
    .build()
  ```
- Some internal implementations for TLS handshake have been changed to
create `SslContext` dynamically.
  

Modifications:


- Server
- Add `TlsProviderMapping` that converts `TlsProvider` into SslContext
`Mapping` for `SniHandler`.
- A dynamic `TlsProvider` can be used to update the certificates without
`Server.reconfigure()`.
  - Add a setter method for `TlsProvider` to `ServerBuilder`.
- A builder method for `VirtualHost` isn't added because a `TlsProvider`
can contain multiple certificates.
- If necessary, I will consider `TlsProvider` at the virtual host level
later.
- Client
- Fix `Bootstraps` to create a `Bootstrap` with a `TlsKeyPair` returned
by `TlsProvider` when a new connection is created.
- If no `TlsProvider` is set, the original behavior that returns
predefined `BootStraap` is used.
  - Add options for `TlsProvider` to `ClientFactoryBuilder`.
- Common
  - `TlsProvider` provides separate builders for the client and server.
- `TlsKeyPair` provides various factory methods to easily create a key
pair from different resources.
  - Cache `SslContext`s and expire them after 1 hour of inactivity.
- If you think that the caching strategy will not be effective, I am
willing to revert it.
- Add `CloseableMeterBinder` to unregister when the associated resource
is unused.
- Deprecate) The following APIs have been deprecated:
  - `TlsSetters tls(File keyCertChainFile, File keyFile)`
- `TlsSetters tls(File keyCertChainFile, File keyFile, @nullable String
keyPassword)`
- `TlsSetters tls(InputStream keyCertChainInputStream, InputStream
keyInputStream)`
- `TlsSetters tls(InputStream keyCertChainInputStream, InputStream
keyInputStream,
                   @nullable String keyPassword)`
  - `TlsSetters tls(PrivateKey key, X509Certificate... keyCertChain)`
- `TlsSetters tls(PrivateKey key, Iterable<? extends X509Certificate>
keyCertChain)`
- `TlsSetters tls(PrivateKey key, @nullable String keyPassword,
X509Certificate... keyCertChain)`
- `TlsSetters tls(PrivateKey key, @nullable String keyPassword,
Iterable<? extends X509Certificate> keyCertChain)`


Result:

- Closes #5033 
- You can now set TLS configurations dynamically using `TlsProvider`

---------

Co-authored-by: Ikhun Um <[email protected]>
Co-authored-by: jrhee17 <[email protected]>
  • Loading branch information
3 people authored Nov 7, 2024
1 parent 1bb781a commit 6628513
Show file tree
Hide file tree
Showing 56 changed files with 3,703 additions and 371 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class RoutersBenchmark {
FALLBACK_SERVICE = newServiceConfig(Route.ofCatchAll());
HOST = new VirtualHost(
"localhost", "localhost", 0, null,
null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED,
null, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED,
unused -> NOPLogger.NOP_LOGGER, FALLBACK_SERVICE.defaultServiceNaming(),
FALLBACK_SERVICE.defaultLogName(), 0, 0, false,
AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0, SuccessFunction.ofDefault(),
Expand Down
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ allprojects {
doFirst {
addTestOutputListener({ descriptor, event ->
if (event.message.contains('LEAK: ')) {
if (isCi) {
logger.warn("Leak is detected in ${descriptor.className}.${descriptor.displayName}\n" +
"${event.message}")
}
hasLeak.set(true)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import java.util.function.Function;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableList;

import com.linecorp.armeria.client.endpoint.EndpointGroup;
import com.linecorp.armeria.client.redirect.RedirectConfig;
import com.linecorp.armeria.common.HttpHeaderNames;
Expand Down Expand Up @@ -532,20 +534,20 @@ protected final ClientOptions buildOptions() {
*/
protected final ClientOptions buildOptions(@Nullable ClientOptions baseOptions) {
final Collection<ClientOptionValue<?>> optVals = options.values();
final int numOpts = optVals.size();
final int extra = contextCustomizer == null ? 3 : 4;
final ClientOptionValue<?>[] optValArray = optVals.toArray(new ClientOptionValue[numOpts + extra]);
optValArray[numOpts] = ClientOptions.DECORATION.newValue(decoration.build());
optValArray[numOpts + 1] = ClientOptions.HEADERS.newValue(headers.build());
optValArray[numOpts + 2] = ClientOptions.CONTEXT_HOOK.newValue(contextHook);
final ImmutableList.Builder<ClientOptionValue<?>> additionalValues =
ImmutableList.builder();
additionalValues.addAll(optVals);
additionalValues.add(ClientOptions.DECORATION.newValue(decoration.build()));
additionalValues.add(ClientOptions.HEADERS.newValue(headers.build()));
additionalValues.add(ClientOptions.CONTEXT_HOOK.newValue(contextHook));
if (contextCustomizer != null) {
optValArray[numOpts + 3] = ClientOptions.CONTEXT_CUSTOMIZER.newValue(contextCustomizer);
additionalValues.add(ClientOptions.CONTEXT_CUSTOMIZER.newValue(contextCustomizer));
}

if (baseOptions != null) {
return ClientOptions.of(baseOptions, optValArray);
return ClientOptions.of(baseOptions, additionalValues.build());
} else {
return ClientOptions.of(optValArray);
return ClientOptions.of(additionalValues.build());
}
}
}
174 changes: 123 additions & 51 deletions core/src/main/java/com/linecorp/armeria/client/Bootstraps.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.SslContextFactory;
import com.linecorp.armeria.internal.common.SslContextFactory.SslContextMode;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
Expand All @@ -36,65 +38,51 @@

final class Bootstraps {

private final Bootstrap[][] inetBootstraps;
private final Bootstrap @Nullable [][] unixBootstraps;
private final EventLoop eventLoop;
private final SslContext sslCtxHttp1Only;
private final SslContext sslCtxHttp1Or2;
@Nullable
private final SslContextFactory sslContextFactory;

private final HttpClientFactory clientFactory;
private final Bootstrap inetBaseBootstrap;
@Nullable
private final Bootstrap unixBaseBootstrap;
private final Bootstrap[][] inetBootstraps;
private final Bootstrap @Nullable [][] unixBootstraps;

Bootstraps(HttpClientFactory clientFactory, EventLoop eventLoop, SslContext sslCtxHttp1Or2,
SslContext sslCtxHttp1Only) {
Bootstraps(HttpClientFactory clientFactory, EventLoop eventLoop,
SslContext sslCtxHttp1Or2, SslContext sslCtxHttp1Only,
@Nullable SslContextFactory sslContextFactory) {
this.eventLoop = eventLoop;
this.sslCtxHttp1Or2 = sslCtxHttp1Or2;
this.sslCtxHttp1Only = sslCtxHttp1Only;
this.sslContextFactory = sslContextFactory;
this.clientFactory = clientFactory;

inetBaseBootstrap = clientFactory.newInetBootstrap();
inetBaseBootstrap.group(eventLoop);
inetBootstraps = staticBootstrapMap(inetBaseBootstrap);

final Bootstrap inetBaseBootstrap = clientFactory.newInetBootstrap();
final Bootstrap unixBaseBootstrap = clientFactory.newUnixBootstrap();
inetBootstraps = newBootstrapMap(inetBaseBootstrap, clientFactory, eventLoop);
unixBaseBootstrap = clientFactory.newUnixBootstrap();
if (unixBaseBootstrap != null) {
unixBootstraps = newBootstrapMap(unixBaseBootstrap, clientFactory, eventLoop);
unixBaseBootstrap.group(eventLoop);
unixBootstraps = staticBootstrapMap(unixBaseBootstrap);
} else {
unixBootstraps = null;
}
}

/**
* Returns a {@link Bootstrap} corresponding to the specified {@link SocketAddress}
* {@link SessionProtocol} and {@link SerializationFormat}.
*/
Bootstrap get(SocketAddress remoteAddress, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
if (!httpAndHttpsValues().contains(desiredProtocol)) {
throw new IllegalArgumentException("Unsupported session protocol: " + desiredProtocol);
}

if (remoteAddress instanceof InetSocketAddress) {
return select(inetBootstraps, desiredProtocol, serializationFormat);
}

assert remoteAddress instanceof DomainSocketAddress : remoteAddress;

if (unixBootstraps == null) {
throw new IllegalArgumentException("Domain sockets are not supported by " +
eventLoop.getClass().getName());
}

return select(unixBootstraps, desiredProtocol, serializationFormat);
}

private Bootstrap[][] newBootstrapMap(Bootstrap baseBootstrap,
HttpClientFactory clientFactory,
EventLoop eventLoop) {
baseBootstrap.group(eventLoop);
private Bootstrap[][] staticBootstrapMap(Bootstrap baseBootstrap) {
final Set<SessionProtocol> sessionProtocols = httpAndHttpsValues();
final Bootstrap[][] maps = (Bootstrap[][]) Array.newInstance(
Bootstrap.class, SessionProtocol.values().length, 2);
// Attempting to access the array with an unallowed protocol will trigger NPE,
// which will help us find a bug.
for (SessionProtocol p : sessionProtocols) {
final SslContext sslCtx = determineSslContext(p);
setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, true);
setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, false);
createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, true);
createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, false);
}
return maps;
}
Expand All @@ -106,22 +94,18 @@ SslContext determineSslContext(SessionProtocol desiredProtocol) {
return desiredProtocol.isExplicitHttp1() ? sslCtxHttp1Only : sslCtxHttp1Or2;
}

private static Bootstrap select(Bootstrap[][] bootstraps, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
private Bootstrap select(boolean isDomainSocket, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
final Bootstrap[][] bootstraps = isDomainSocket ? unixBootstraps : inetBootstraps;
assert bootstraps != null;
return bootstraps[desiredProtocol.ordinal()][toIndex(serializationFormat)];
}

private static void setBootstrap(Bootstrap bootstrap, HttpClientFactory clientFactory, Bootstrap[][] maps,
SessionProtocol p, SslContext sslCtx, boolean webSocket) {
bootstrap.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new HttpClientPipelineConfigurator(
clientFactory, webSocket, p, sslCtx));
}
}
);
maps[p.ordinal()][toIndex(webSocket)] = bootstrap;
private void createAndSetBootstrap(Bootstrap baseBootstrap, Bootstrap[][] maps,
SessionProtocol desiredProtocol, SslContext sslContext,
boolean webSocket) {
maps[desiredProtocol.ordinal()][toIndex(webSocket)] = newBootstrap(baseBootstrap, desiredProtocol,
sslContext, webSocket, false);
}

private static int toIndex(boolean webSocket) {
Expand All @@ -131,4 +115,92 @@ private static int toIndex(boolean webSocket) {
private static int toIndex(SerializationFormat serializationFormat) {
return toIndex(serializationFormat == SerializationFormat.WS);
}

/**
* Returns a {@link Bootstrap} corresponding to the specified {@link SocketAddress}
* {@link SessionProtocol} and {@link SerializationFormat}.
*/
Bootstrap getOrCreate(SocketAddress remoteAddress, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
if (!httpAndHttpsValues().contains(desiredProtocol)) {
throw new IllegalArgumentException("Unsupported session protocol: " + desiredProtocol);
}

final boolean isDomainSocket = remoteAddress instanceof DomainSocketAddress;
if (isDomainSocket && unixBaseBootstrap == null) {
throw new IllegalArgumentException("Domain sockets are not supported by " +
eventLoop.getClass().getName());
}

if (sslContextFactory == null || !desiredProtocol.isTls()) {
return select(isDomainSocket, desiredProtocol, serializationFormat);
}

final Bootstrap baseBootstrap = isDomainSocket ? unixBaseBootstrap : inetBaseBootstrap;
assert baseBootstrap != null;
return newBootstrap(baseBootstrap, remoteAddress, desiredProtocol, serializationFormat);
}

private Bootstrap newBootstrap(Bootstrap baseBootstrap, SocketAddress remoteAddress,
SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
final boolean webSocket = serializationFormat == SerializationFormat.WS;
final SslContext sslContext = newSslContext(remoteAddress, desiredProtocol);
return newBootstrap(baseBootstrap, desiredProtocol, sslContext, webSocket, true);
}

private Bootstrap newBootstrap(Bootstrap baseBootstrap, SessionProtocol desiredProtocol,
SslContext sslContext, boolean webSocket, boolean closeSslContext) {
final Bootstrap bootstrap = baseBootstrap.clone();
bootstrap.handler(clientChannelInitializer(desiredProtocol, sslContext, webSocket, closeSslContext));
return bootstrap;
}

SslContext getOrCreateSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) {
if (sslContextFactory == null) {
return determineSslContext(desiredProtocol);
} else {
return newSslContext(remoteAddress, desiredProtocol);
}
}

private SslContext newSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) {
final String hostname;
if (remoteAddress instanceof InetSocketAddress) {
hostname = ((InetSocketAddress) remoteAddress).getHostString();
} else {
assert remoteAddress instanceof DomainSocketAddress;
hostname = "unix:" + ((DomainSocketAddress) remoteAddress).path();
}

final SslContextMode sslContextMode =
desiredProtocol.isExplicitHttp1() ? SslContextFactory.SslContextMode.CLIENT_HTTP1_ONLY
: SslContextFactory.SslContextMode.CLIENT;
assert sslContextFactory != null;
return sslContextFactory.getOrCreate(sslContextMode, hostname);
}

boolean shouldReleaseSslContext(SslContext sslContext) {
return sslContext != sslCtxHttp1Only && sslContext != sslCtxHttp1Or2;
}

void releaseSslContext(SslContext sslContext) {
if (sslContextFactory != null) {
sslContextFactory.release(sslContext);
}
}

private ChannelInitializer<Channel> clientChannelInitializer(SessionProtocol p, SslContext sslCtx,
boolean webSocket, boolean closeSslContext) {
return new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
if (closeSslContext) {
ch.closeFuture().addListener(unused -> releaseSslContext(sslCtx));
}
ch.pipeline().addLast(new HttpClientPipelineConfigurator(
clientFactory, webSocket, p, sslCtx));
}
};
}
}
Loading

0 comments on commit 6628513

Please sign in to comment.