Skip to content

Commit

Permalink
TLS first connection (#156)
Browse files Browse the repository at this point in the history
* TLS first connection

* Update sandbox/Example.TlsFirst/Example.TlsFirst.csproj

Co-authored-by: Caleb Lloyd <[email protected]>

* Use TLS modes

* Update src/NATS.Client.Core/NatsTlsOpts.cs

Co-authored-by: Caleb Lloyd <[email protected]>

* Additional TLS mode checks

* TLS first test fix

* Fixed on connecting callback return value

* Test debug

* Fixing TLS connection test

* Fixed warnings and format

* Fix TLS hosts for implicit connections

* Don't fix TLS host on initial connection

Fixing TLS host relies on the last connected seed host
which isn't available on initial connection.

* Use local for TLS host fix

* Reverted debug value in test

---------

Co-authored-by: Caleb Lloyd <[email protected]>
  • Loading branch information
mtmk and caleblloyd authored Oct 17, 2023
1 parent 0a98860 commit 4100569
Show file tree
Hide file tree
Showing 12 changed files with 384 additions and 84 deletions.
7 changes: 7 additions & 0 deletions NATS.Client.sln
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NATS.Client.KeyValueStore.T
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.KeyValueStore.Watcher", "sandbox\Example.KeyValueStore.Watcher\Example.KeyValueStore.Watcher.csproj", "{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.TlsFirst", "sandbox\Example.TlsFirst\Example.TlsFirst.csproj", "{88625045-978F-417F-9F51-A4E3A9718945}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -183,6 +185,10 @@ Global
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Debug|Any CPU.Build.0 = Debug|Any CPU
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.ActiveCfg = Release|Any CPU
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.Build.0 = Release|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.Build.0 = Debug|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.ActiveCfg = Release|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -215,6 +221,7 @@ Global
{A102AB7B-A90C-4717-B17C-045240838060} = {4827B3EC-73D8-436D-AE2A-5E29AC95FD0C}
{908F2CED-CAC0-4A4E-AD19-362A413B5DA4} = {C526E8AB-739A-48D7-8FC4-048978C9B650}
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53} = {95A69671-16CA-4133-981C-CC381B7AAA30}
{88625045-978F-417F-9F51-A4E3A9718945} = {95A69671-16CA-4133-981C-CC381B7AAA30}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {8CBB7278-D093-448E-B3DE-B5991209A1AA}
Expand Down
15 changes: 15 additions & 0 deletions sandbox/Example.TlsFirst/Example.TlsFirst.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\NATS.Client.Core\NATS.Client.Core.csproj" />
</ItemGroup>

</Project>
7 changes: 7 additions & 0 deletions sandbox/Example.TlsFirst/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using NATS.Client.Core;

// await using var nats = new NatsConnection();
await using var nats = new NatsConnection(NatsOpts.Default with { TlsOpts = new NatsTlsOpts { Mode = TlsMode.Implicit, InsecureSkipVerify = true, } });
await nats.ConnectAsync();
var timeSpan = await nats.PingAsync();
Console.WriteLine($"{timeSpan}");
11 changes: 11 additions & 0 deletions src/NATS.Client.Core/Internal/NatsUri.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ public NatsUri(string urlString, bool isSeed, string defaultScheme = DefaultSche

public int Port => Uri.Port;

public NatsUri CloneWith(string host, int? port = default)
{
var newUri = new UriBuilder(Uri)
{
Host = host,
Port = port ?? Port,
}.Uri.ToString();

return new NatsUri(newUri, IsSeed);
}

public override string ToString()
{
return IsWebSocket && Uri.AbsolutePath != "/" ? Uri.ToString() : Uri.ToString().Trim('/');
Expand Down
21 changes: 14 additions & 7 deletions src/NATS.Client.Core/Internal/SslStreamConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ public void SignalDisconnected(Exception exception)
_waitForClosedSource.TrySetResult(exception);
}

public async Task AuthenticateAsClientAsync(string target)
public async Task AuthenticateAsClientAsync(NatsUri uri)
{
var options = SslClientAuthenticationOptions(target);
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
var options = SslClientAuthenticationOptions(uri);
try
{
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
}
catch (AuthenticationException ex)
{
throw new NatsException($"TLS authentication failed", ex);
}
}

private static X509Certificate LcsCbClientCerts(
Expand Down Expand Up @@ -123,11 +130,11 @@ private bool RcsCbCaCertChain(
return sslPolicyErrors == SslPolicyErrors.None;
}

private SslClientAuthenticationOptions SslClientAuthenticationOptions(string targetHost)
private SslClientAuthenticationOptions SslClientAuthenticationOptions(NatsUri uri)
{
if (_tlsOpts.Disabled)
if (_tlsOpts.EffectiveMode(uri) == TlsMode.Disable)
{
throw new InvalidOperationException("TLS is not permitted when TlsOptions.Disabled is set");
throw new InvalidOperationException("TLS is not permitted when TlsMode is set to Disable");
}

LocalCertificateSelectionCallback? lcsCb = default;
Expand All @@ -148,7 +155,7 @@ private SslClientAuthenticationOptions SslClientAuthenticationOptions(string tar

var options = new SslClientAuthenticationOptions
{
TargetHost = targetHost,
TargetHost = uri.Host,
EnabledSslProtocols = SslProtocols.Tls12,
ClientCertificates = _tlsCerts?.ClientCerts,
LocalCertificateSelectionCallback = lcsCb,
Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/Internal/TlsCerts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ internal class TlsCerts
{
public TlsCerts(NatsTlsOpts tlsOpts)
{
if (tlsOpts.Disabled)
if (tlsOpts.Mode == TlsMode.Disable)
{
return;
}
Expand Down
87 changes: 63 additions & 24 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public partial class NatsConnection : IAsyncDisposable, INatsConnection
// when reconnect, make new instance.
private ISocketConnection? _socket;
private CancellationTokenSource? _pingTimerCancellationTokenSource;
private NatsUri? _currentConnectUri;
private NatsUri? _lastSeedConnectUri;
private volatile NatsUri? _currentConnectUri;
private volatile NatsUri? _lastSeedConnectUri;
private NatsReadProtocolProcessor? _socketReader;
private NatsPipeliningWriteProtocolProcessor? _socketWriter;
private TaskCompletionSource _waitForOpenConnection;
Expand Down Expand Up @@ -222,9 +222,14 @@ private async ValueTask InitialConnectAsync()
Debug.Assert(ConnectionState == NatsConnectionState.Connecting, "Connection state");

var uris = Opts.GetSeedUris();
if (Opts.TlsOpts.Disabled && uris.Any(u => u.IsTls))
throw new NatsException($"URI {uris.First(u => u.IsTls)} requires TLS but NatsTlsOpts.Disabled is set to true");
if (Opts.TlsOpts.Required)

foreach (var uri in uris)
{
if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Disable && uri.IsTls)
throw new NatsException($"URI {uri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.HasTlsFile)
_tlsCerts = new TlsCerts(Opts.TlsOpts);

if (!Opts.AuthOpts.IsAnonymous)
Expand Down Expand Up @@ -255,6 +260,14 @@ private async ValueTask InitialConnectAsync()
var conn = new TcpConnection();
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(uri).ConfigureAwait(false);
_socket = sslConnection;
}
}

_currentConnectUri = uri;
Expand Down Expand Up @@ -331,32 +344,24 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// check to see if we should upgrade to TLS
if (_socket is TcpConnection tcpConnection)
{
if (Opts.TlsOpts.Disabled && WritableServerInfo!.TlsRequired)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Disable && WritableServerInfo!.TlsRequired)
{
throw new NatsException(
$"Server {_currentConnectUri} requires TLS but NatsTlsOpts.Disabled is set to true");
$"Server {_currentConnectUri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.Required && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Require && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
{
throw new NatsException(
$"Server {_currentConnectUri} does not support TLS but NatsTlsOpts.Disabled is set to true");
$"Server {_currentConnectUri} does not support TLS but TlsMode is set to Require");
}

if (Opts.TlsOpts.Required || WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable)
if (Opts.TlsOpts.TryTls(_currentConnectUri) && (WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable))
{
// do TLS upgrade
// if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the
// last seed hostname if it was a DNS hostname
var targetHost = _currentConnectUri.Host;
if (!_currentConnectUri.IsSeed
&& Uri.CheckHostName(targetHost) != UriHostNameType.Dns
&& Uri.CheckHostName(_lastSeedConnectUri!.Host) == UriHostNameType.Dns)
{
targetHost = _lastSeedConnectUri.Host;
}
var targetUri = FixTlsHost(_currentConnectUri);

_logger.LogDebug("Perform TLS Upgrade to " + targetHost);
_logger.LogDebug("Perform TLS Upgrade to " + targetUri);

// cancel INFO parsed signal and dispose current socket reader
infoParsedSignal.SetCanceled();
Expand All @@ -365,7 +370,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)

// upgrade TcpConnection to SslConnection
var sslConnection = tcpConnection.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(targetHost).ConfigureAwait(false);
await sslConnection.AuthenticateAsClientAsync(targetUri).ConfigureAwait(false);
_socket = sslConnection;

// create new socket reader
Expand Down Expand Up @@ -452,11 +457,17 @@ private async void ReconnectLoop()
if (urlEnumerator.MoveNext())
{
url = urlEnumerator.Current;
var target = (url.Host, url.Port);

if (OnConnectingAsync != null)
{
var target = (url.Host, url.Port);
_logger.LogInformation("Try to invoke OnConnectingAsync before connect to NATS.");
target = await OnConnectingAsync(target).ConfigureAwait(false);
var newTarget = await OnConnectingAsync(target).ConfigureAwait(false);

if (newTarget.Host != target.Host || newTarget.Port != target.Port)
{
url = url.CloneWith(newTarget.Host, newTarget.Port);
}
}

_logger.LogInformation("Try to connect NATS {0}", url);
Expand All @@ -469,8 +480,16 @@ private async void ReconnectLoop()
else
{
var conn = new TcpConnection();
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
await conn.ConnectAsync(url.Host, url.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode(url) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(FixTlsHost(url)).ConfigureAwait(false);
_socket = sslConnection;
}
}

_currentConnectUri = url;
Expand Down Expand Up @@ -515,6 +534,26 @@ private async void ReconnectLoop()
}
}

private NatsUri FixTlsHost(NatsUri uri)
{
var lastSeedConnectUri = _lastSeedConnectUri;
var lastSeedHost = lastSeedConnectUri?.Host;

if (string.IsNullOrEmpty(lastSeedHost))
return uri;

// if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the
// last seed hostname if it was a DNS hostname
if (!uri.IsSeed
&& Uri.CheckHostName(uri.Host) != UriHostNameType.Dns
&& Uri.CheckHostName(lastSeedHost) == UriHostNameType.Dns)
{
return uri.CloneWith(lastSeedHost);
}

return uri;
}

private async Task WaitWithJitterAsync()
{
var jitter = Random.Shared.NextDouble() * Opts.ReconnectJitter.TotalMilliseconds;
Expand Down
54 changes: 50 additions & 4 deletions src/NATS.Client.Core/NatsTlsOpts.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,39 @@
using NATS.Client.Core.Internal;

namespace NATS.Client.Core;

/// <summary>
/// TLS mode to use during connection.
/// </summary>
public enum TlsMode
{
/// <summary>
/// For connections that use the "nats://" scheme and don't supply Client or CA Certificates - same as <c>Prefer</c>
/// For connections that use the "tls://" scheme or supply Client or CA Certificates - same as <c>Require</c>
/// </summary>
Auto,

/// <summary>
/// if the Server supports TLS, then use it, otherwise use plain-text.
/// </summary>
Prefer,

/// <summary>
/// Forces the connection to upgrade to TLS. if the Server does not support TLS, then fail the connection.
/// </summary>
Require,

/// <summary>
/// Upgrades the connection to TLS as soon as the connection is established.
/// </summary>
Implicit,

/// <summary>
/// Disabled mode will not attempt to upgrade the connection to TLS.
/// </summary>
Disable,
}

/// <summary>
/// Immutable options for TlsOptions, you can configure via `with` operator.
/// These options are ignored in WebSocket connections
Expand All @@ -17,11 +51,23 @@ public sealed record NatsTlsOpts
/// <summary>Path to PEM-encoded X509 CA Certificate</summary>
public string? CaFile { get; init; }

/// <summary>When true, disable TLS</summary>
public bool Disabled { get; init; }

/// <summary>When true, skip remote certificate verification and accept any server certificate</summary>
public bool InsecureSkipVerify { get; init; }

internal bool Required => CertFile != default || KeyFile != default || CaFile != default;
/// <summary>TLS mode to use during connection</summary>
public TlsMode Mode { get; init; }

internal bool HasTlsFile => CertFile != default || KeyFile != default || CaFile != default;

internal TlsMode EffectiveMode(NatsUri uri) => Mode switch
{
TlsMode.Auto => HasTlsFile || uri.Uri.Scheme.ToLower() == "tls" ? TlsMode.Require : TlsMode.Prefer,
_ => Mode,
};

internal bool TryTls(NatsUri uri)
{
var effectiveMode = EffectiveMode(uri);
return effectiveMode is TlsMode.Require or TlsMode.Prefer;
}
}
13 changes: 6 additions & 7 deletions tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,11 @@ await Retry.Until(
[Fact]
public async Task ReconnectSingleTest()
{
using var options = new NatsServerOpts
{
TransportType = _transportType,
EnableWebSocket = _transportType == TransportType.WebSocket,
ServerDisposeReturnsPorts = false,
};
var options = new NatsServerOptsBuilder()
.UseTransport(_transportType)
.WithServerDisposeReturnsPorts()
.Build();

await using var server = NatsServer.Start(_output, options);
var subject = Guid.NewGuid().ToString();

Expand Down Expand Up @@ -249,7 +248,7 @@ public async Task ReconnectClusterTest()
await connection3.ConnectAsync();

_output.WriteLine("Server1 ClientConnectUrls:" +
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server2 ClientConnectUrls:" +
string.Join(", ", connection2.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server3 ClientConnectUrls:" +
Expand Down
Loading

0 comments on commit 4100569

Please sign in to comment.