Skip to content

Commit

Permalink
Additional TLS mode checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk committed Oct 16, 2023
1 parent e681a9d commit 251a13c
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 79 deletions.
21 changes: 14 additions & 7 deletions src/NATS.Client.Core/Internal/SslStreamConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ public void SignalDisconnected(Exception exception)
_waitForClosedSource.TrySetResult(exception);
}

public async Task AuthenticateAsClientAsync(string target)
public async Task AuthenticateAsClientAsync(NatsUri uri)
{
var options = SslClientAuthenticationOptions(target);
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
var options = SslClientAuthenticationOptions(uri);
try
{
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
}
catch (AuthenticationException ex)
{
throw new NatsException($"TLS authentication failed", ex);
}
}

private static X509Certificate LcsCbClientCerts(
Expand Down Expand Up @@ -123,11 +130,11 @@ private bool RcsCbCaCertChain(
return sslPolicyErrors == SslPolicyErrors.None;
}

private SslClientAuthenticationOptions SslClientAuthenticationOptions(string targetHost)
private SslClientAuthenticationOptions SslClientAuthenticationOptions(NatsUri uri)
{
if (_tlsOpts.EffectiveMode == TlsMode.Disabled)
if (_tlsOpts.EffectiveMode(uri) == TlsMode.Disable)
{
throw new InvalidOperationException("TLS is not permitted when TlsMode is set to Disabled");
throw new InvalidOperationException("TLS is not permitted when TlsMode is set to Disable");
}

LocalCertificateSelectionCallback? lcsCb = default;
Expand All @@ -148,7 +155,7 @@ private SslClientAuthenticationOptions SslClientAuthenticationOptions(string tar

var options = new SslClientAuthenticationOptions
{
TargetHost = targetHost,
TargetHost = uri.Host,
EnabledSslProtocols = SslProtocols.Tls12,
ClientCertificates = _tlsCerts?.ClientCerts,
LocalCertificateSelectionCallback = lcsCb,
Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/Internal/TlsCerts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ internal class TlsCerts
{
public TlsCerts(NatsTlsOpts tlsOpts)
{
if (tlsOpts.EffectiveMode == TlsMode.Disabled)
if (tlsOpts.Mode == TlsMode.Disable)
{
return;
}
Expand Down
33 changes: 18 additions & 15 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public partial class NatsConnection : IAsyncDisposable, INatsConnection
// when reconnect, make new instance.
private ISocketConnection? _socket;
private CancellationTokenSource? _pingTimerCancellationTokenSource;
private NatsUri? _currentConnectUri;
private NatsUri? _lastSeedConnectUri;
private volatile NatsUri? _currentConnectUri;
private volatile NatsUri? _lastSeedConnectUri;
private NatsReadProtocolProcessor? _socketReader;
private NatsPipeliningWriteProtocolProcessor? _socketWriter;
private TaskCompletionSource _waitForOpenConnection;
Expand Down Expand Up @@ -222,10 +222,13 @@ private async ValueTask InitialConnectAsync()
Debug.Assert(ConnectionState == NatsConnectionState.Connecting, "Connection state");

var uris = Opts.GetSeedUris();
if (Opts.TlsOpts.EffectiveMode == TlsMode.Require && uris.Any(u => !u.IsTls))
throw new NatsException($"URI {uris.First(u => !u.IsTls)} doesn't support TLS but TlsMode is set to Explicit");
if (Opts.TlsOpts.EffectiveMode == TlsMode.Disabled && uris.Any(u => u.IsTls))
throw new NatsException($"URI {uris.First(u => u.IsTls)} requires TLS but TlsMode is set to Disabled");

foreach (var uri in uris)
{
if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Disable && uri.IsTls)
throw new NatsException($"URI {uri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.HasTlsFile)
_tlsCerts = new TlsCerts(Opts.TlsOpts);

Expand Down Expand Up @@ -258,11 +261,11 @@ private async ValueTask InitialConnectAsync()
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode == TlsMode.Implicit)
if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(target.Host).ConfigureAwait(false);
await sslConnection.AuthenticateAsClientAsync(uri).ConfigureAwait(false);
_socket = sslConnection;
}
}
Expand Down Expand Up @@ -341,19 +344,19 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// check to see if we should upgrade to TLS
if (_socket is TcpConnection tcpConnection)
{
if (Opts.TlsOpts.EffectiveMode == TlsMode.Disabled && WritableServerInfo!.TlsRequired)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Disable && WritableServerInfo!.TlsRequired)
{
throw new NatsException(
$"Server {_currentConnectUri} requires TLS but TlsMode is set to Disabled");
$"Server {_currentConnectUri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.EffectiveMode == TlsMode.Require && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Require && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
{
throw new NatsException(
$"Server {_currentConnectUri} does not support TLS but TlsMode is set to Require");
}

if (Opts.TlsOpts.EffectiveMode == TlsMode.Prefer && (WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable))
if (Opts.TlsOpts.TryTls(_currentConnectUri) && (WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable))
{
// do TLS upgrade
// if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the
Expand All @@ -375,7 +378,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)

// upgrade TcpConnection to SslConnection
var sslConnection = tcpConnection.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(targetHost).ConfigureAwait(false);
await sslConnection.AuthenticateAsClientAsync(_currentConnectUri).ConfigureAwait(false);
_socket = sslConnection;

// create new socket reader
Expand Down Expand Up @@ -482,11 +485,11 @@ private async void ReconnectLoop()
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode == TlsMode.Implicit)
if (Opts.TlsOpts.EffectiveMode(url) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(target.Host).ConfigureAwait(false);
await sslConnection.AuthenticateAsClientAsync(url).ConfigureAwait(false);
_socket = sslConnection;
}
}
Expand Down
14 changes: 11 additions & 3 deletions src/NATS.Client.Core/NatsTlsOpts.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using NATS.Client.Core.Internal;

namespace NATS.Client.Core;

/// <summary>
Expand Down Expand Up @@ -29,7 +31,7 @@ public enum TlsMode
/// <summary>
/// Disabled mode will not attempt to upgrade the connection to TLS.
/// </summary>
Disabled,
Disable,
}

/// <summary>
Expand Down Expand Up @@ -57,9 +59,15 @@ public sealed record NatsTlsOpts

internal bool HasTlsFile => CertFile != default || KeyFile != default || CaFile != default;

internal TlsMode EffectiveMode => Mode switch
internal TlsMode EffectiveMode(NatsUri uri) => Mode switch
{
TlsMode.Auto => TlsMode.Prefer,
TlsMode.Auto => HasTlsFile || uri.Uri.Scheme.ToLower() == "tls" ? TlsMode.Require : TlsMode.Prefer,
_ => Mode,
};

internal bool TryTls(NatsUri uri)
{
var effectiveMode = EffectiveMode(uri);
return effectiveMode is TlsMode.Require or TlsMode.Prefer;
}
}
15 changes: 7 additions & 8 deletions tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,11 @@ await Retry.Until(
[Fact]
public async Task ReconnectSingleTest()
{
using var options = new NatsServerOpts
{
TransportType = _transportType,
EnableWebSocket = _transportType == TransportType.WebSocket,
ServerDisposeReturnsPorts = false,
};
var options = new NatsServerOptsBuilder()
.UseTransport(_transportType)
.WithServerDisposeReturnsPorts()
.Build();

await using var server = NatsServer.Start(_output, options);
var subject = Guid.NewGuid().ToString();

Expand Down Expand Up @@ -235,7 +234,7 @@ await Retry.Until(
[Fact(Timeout = 30000)]
public async Task ReconnectClusterTest()
{
await using var cluster = new NatsCluster(_output, _transportType);
await using var cluster = new NatsCluster(new NullOutputHelper(), _transportType);
await Task.Delay(TimeSpan.FromSeconds(5)); // wait for cluster completely connected.

var subject = Guid.NewGuid().ToString();
Expand All @@ -249,7 +248,7 @@ public async Task ReconnectClusterTest()
await connection3.ConnectAsync();

_output.WriteLine("Server1 ClientConnectUrls:" +
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server2 ClientConnectUrls:" +
string.Join(", ", connection2.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server3 ClientConnectUrls:" +
Expand Down
50 changes: 50 additions & 0 deletions tests/NATS.Client.Core.Tests/TlsFirstTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
namespace NATS.Client.Core.Tests;

public class TlsFirstTest
{
private readonly ITestOutputHelper _output;

public TlsFirstTest(ITestOutputHelper output) => _output = output;

[Fact]
public async Task Tls_first_connection()
{
if (!NatsServer.SupportsTlsFirst())
{
_output.WriteLine($"TLS first is NOT supported by the server");
return;
}

_output.WriteLine($"TLS first is supported by the server");

await using var server = NatsServer.Start(
new NullOutputHelper(),
new NatsServerOptsBuilder()
.UseTransport(TransportType.Tls, tlsFirst: true)
.Build());

var clientOpts = server.ClientOpts(NatsOpts.Default);

Assert.True(clientOpts.TlsOpts.Mode == TlsMode.Implicit);

// TLS first connection
{
await using var nats = new NatsConnection(clientOpts);
await nats.ConnectAsync();
var rtt = await nats.PingAsync();
Assert.True(rtt > TimeSpan.Zero);
_output.WriteLine($"Implicit TLS connection (RTT: {rtt})");
}

// Normal TLS connection should fail
{
await using var nats = new NatsConnection(clientOpts with { TlsOpts = clientOpts.TlsOpts with { Mode = TlsMode.Auto } });

var exception = await Assert.ThrowsAsync<NatsException>(async () => await nats.ConnectAsync());

Assert.Matches(@"can not start to connect nats server: tls://", exception.Message);

_output.WriteLine($"Auto TLS connection rejected");
}
}
}
Loading

0 comments on commit 251a13c

Please sign in to comment.