Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS first connection #156

Merged
merged 14 commits into from
Oct 17, 2023
7 changes: 7 additions & 0 deletions NATS.Client.sln
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NATS.Client.KeyValueStore.T
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.KeyValueStore.Watcher", "sandbox\Example.KeyValueStore.Watcher\Example.KeyValueStore.Watcher.csproj", "{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.TlsFirst", "sandbox\Example.TlsFirst\Example.TlsFirst.csproj", "{88625045-978F-417F-9F51-A4E3A9718945}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -183,6 +185,10 @@ Global
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Debug|Any CPU.Build.0 = Debug|Any CPU
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.ActiveCfg = Release|Any CPU
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.Build.0 = Release|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.Build.0 = Debug|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.ActiveCfg = Release|Any CPU
{88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -215,6 +221,7 @@ Global
{A102AB7B-A90C-4717-B17C-045240838060} = {4827B3EC-73D8-436D-AE2A-5E29AC95FD0C}
{908F2CED-CAC0-4A4E-AD19-362A413B5DA4} = {C526E8AB-739A-48D7-8FC4-048978C9B650}
{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53} = {95A69671-16CA-4133-981C-CC381B7AAA30}
{88625045-978F-417F-9F51-A4E3A9718945} = {95A69671-16CA-4133-981C-CC381B7AAA30}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {8CBB7278-D093-448E-B3DE-B5991209A1AA}
Expand Down
15 changes: 15 additions & 0 deletions sandbox/Example.TlsFirst/Example.TlsFirst.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
mtmk marked this conversation as resolved.
Show resolved Hide resolved
<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\NATS.Client.Core\NATS.Client.Core.csproj" />
</ItemGroup>

</Project>
7 changes: 7 additions & 0 deletions sandbox/Example.TlsFirst/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using NATS.Client.Core;

// await using var nats = new NatsConnection();
await using var nats = new NatsConnection(NatsOpts.Default with { TlsOpts = new NatsTlsOpts { Mode = TlsMode.Implicit, InsecureSkipVerify = true, } });
await nats.ConnectAsync();
var timeSpan = await nats.PingAsync();
Console.WriteLine($"{timeSpan}");
11 changes: 11 additions & 0 deletions src/NATS.Client.Core/Internal/NatsUri.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ public NatsUri(string urlString, bool isSeed, string defaultScheme = DefaultSche

public int Port => Uri.Port;

public NatsUri CloneWith(string host, int? port = default)
{
var newUri = new UriBuilder(Uri)
{
Host = host,
Port = port ?? Port,
}.Uri.ToString();

return new NatsUri(newUri, IsSeed);
}

public override string ToString()
{
return IsWebSocket && Uri.AbsolutePath != "/" ? Uri.ToString() : Uri.ToString().Trim('/');
Expand Down
21 changes: 14 additions & 7 deletions src/NATS.Client.Core/Internal/SslStreamConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ public void SignalDisconnected(Exception exception)
_waitForClosedSource.TrySetResult(exception);
}

public async Task AuthenticateAsClientAsync(string target)
public async Task AuthenticateAsClientAsync(NatsUri uri)
{
var options = SslClientAuthenticationOptions(target);
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
var options = SslClientAuthenticationOptions(uri);
try
{
await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false);
}
catch (AuthenticationException ex)
{
throw new NatsException($"TLS authentication failed", ex);
}
}

private static X509Certificate LcsCbClientCerts(
Expand Down Expand Up @@ -123,11 +130,11 @@ private bool RcsCbCaCertChain(
return sslPolicyErrors == SslPolicyErrors.None;
}

private SslClientAuthenticationOptions SslClientAuthenticationOptions(string targetHost)
private SslClientAuthenticationOptions SslClientAuthenticationOptions(NatsUri uri)
{
if (_tlsOpts.Disabled)
if (_tlsOpts.EffectiveMode(uri) == TlsMode.Disable)
{
throw new InvalidOperationException("TLS is not permitted when TlsOptions.Disabled is set");
throw new InvalidOperationException("TLS is not permitted when TlsMode is set to Disable");
}

LocalCertificateSelectionCallback? lcsCb = default;
Expand All @@ -148,7 +155,7 @@ private SslClientAuthenticationOptions SslClientAuthenticationOptions(string tar

var options = new SslClientAuthenticationOptions
{
TargetHost = targetHost,
TargetHost = uri.Host,
EnabledSslProtocols = SslProtocols.Tls12,
ClientCertificates = _tlsCerts?.ClientCerts,
LocalCertificateSelectionCallback = lcsCb,
Expand Down
2 changes: 1 addition & 1 deletion src/NATS.Client.Core/Internal/TlsCerts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ internal class TlsCerts
{
public TlsCerts(NatsTlsOpts tlsOpts)
{
if (tlsOpts.Disabled)
if (tlsOpts.Mode == TlsMode.Disable)
{
return;
}
Expand Down
87 changes: 63 additions & 24 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public partial class NatsConnection : IAsyncDisposable, INatsConnection
// when reconnect, make new instance.
private ISocketConnection? _socket;
private CancellationTokenSource? _pingTimerCancellationTokenSource;
private NatsUri? _currentConnectUri;
private NatsUri? _lastSeedConnectUri;
private volatile NatsUri? _currentConnectUri;
private volatile NatsUri? _lastSeedConnectUri;
private NatsReadProtocolProcessor? _socketReader;
private NatsPipeliningWriteProtocolProcessor? _socketWriter;
private TaskCompletionSource _waitForOpenConnection;
Expand Down Expand Up @@ -222,9 +222,14 @@ private async ValueTask InitialConnectAsync()
Debug.Assert(ConnectionState == NatsConnectionState.Connecting, "Connection state");

var uris = Opts.GetSeedUris();
if (Opts.TlsOpts.Disabled && uris.Any(u => u.IsTls))
throw new NatsException($"URI {uris.First(u => u.IsTls)} requires TLS but NatsTlsOpts.Disabled is set to true");
if (Opts.TlsOpts.Required)

foreach (var uri in uris)
{
if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Disable && uri.IsTls)
throw new NatsException($"URI {uri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.HasTlsFile)
_tlsCerts = new TlsCerts(Opts.TlsOpts);

if (!Opts.AuthOpts.IsAnonymous)
Expand Down Expand Up @@ -255,6 +260,14 @@ private async ValueTask InitialConnectAsync()
var conn = new TcpConnection();
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(uri).ConfigureAwait(false);
_socket = sslConnection;
}
}

_currentConnectUri = uri;
Expand Down Expand Up @@ -331,32 +344,24 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// check to see if we should upgrade to TLS
if (_socket is TcpConnection tcpConnection)
{
if (Opts.TlsOpts.Disabled && WritableServerInfo!.TlsRequired)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Disable && WritableServerInfo!.TlsRequired)
{
throw new NatsException(
$"Server {_currentConnectUri} requires TLS but NatsTlsOpts.Disabled is set to true");
$"Server {_currentConnectUri} requires TLS but TlsMode is set to Disable");
}

if (Opts.TlsOpts.Required && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Require && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable)
{
throw new NatsException(
$"Server {_currentConnectUri} does not support TLS but NatsTlsOpts.Disabled is set to true");
$"Server {_currentConnectUri} does not support TLS but TlsMode is set to Require");
}

if (Opts.TlsOpts.Required || WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable)
if (Opts.TlsOpts.TryTls(_currentConnectUri) && (WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable))
{
// do TLS upgrade
// if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the
// last seed hostname if it was a DNS hostname
var targetHost = _currentConnectUri.Host;
if (!_currentConnectUri.IsSeed
&& Uri.CheckHostName(targetHost) != UriHostNameType.Dns
&& Uri.CheckHostName(_lastSeedConnectUri!.Host) == UriHostNameType.Dns)
{
targetHost = _lastSeedConnectUri.Host;
}
var targetUri = FixTlsHost(_currentConnectUri);

_logger.LogDebug("Perform TLS Upgrade to " + targetHost);
_logger.LogDebug("Perform TLS Upgrade to " + targetUri);

// cancel INFO parsed signal and dispose current socket reader
infoParsedSignal.SetCanceled();
Expand All @@ -365,7 +370,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)

// upgrade TcpConnection to SslConnection
var sslConnection = tcpConnection.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(targetHost).ConfigureAwait(false);
await sslConnection.AuthenticateAsClientAsync(targetUri).ConfigureAwait(false);
_socket = sslConnection;

// create new socket reader
Expand Down Expand Up @@ -452,11 +457,17 @@ private async void ReconnectLoop()
if (urlEnumerator.MoveNext())
{
url = urlEnumerator.Current;
var target = (url.Host, url.Port);

if (OnConnectingAsync != null)
{
var target = (url.Host, url.Port);
_logger.LogInformation("Try to invoke OnConnectingAsync before connect to NATS.");
target = await OnConnectingAsync(target).ConfigureAwait(false);
var newTarget = await OnConnectingAsync(target).ConfigureAwait(false);

if (newTarget.Host != target.Host || newTarget.Port != target.Port)
{
url = url.CloneWith(newTarget.Host, newTarget.Port);
}
}

_logger.LogInformation("Try to connect NATS {0}", url);
Expand All @@ -469,8 +480,16 @@ private async void ReconnectLoop()
else
{
var conn = new TcpConnection();
await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false);
await conn.ConnectAsync(url.Host, url.Port, Opts.ConnectTimeout).ConfigureAwait(false);
_socket = conn;

if (Opts.TlsOpts.EffectiveMode(url) == TlsMode.Implicit)
{
// upgrade TcpConnection to SslConnection
var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts);
await sslConnection.AuthenticateAsClientAsync(FixTlsHost(url)).ConfigureAwait(false);
_socket = sslConnection;
}
}

_currentConnectUri = url;
Expand Down Expand Up @@ -515,6 +534,26 @@ private async void ReconnectLoop()
}
}

private NatsUri FixTlsHost(NatsUri uri)
{
var lastSeedConnectUri = _lastSeedConnectUri;
var lastSeedHost = lastSeedConnectUri?.Host;

if (string.IsNullOrEmpty(lastSeedHost))
return uri;

// if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the
// last seed hostname if it was a DNS hostname
if (!uri.IsSeed
&& Uri.CheckHostName(uri.Host) != UriHostNameType.Dns
&& Uri.CheckHostName(lastSeedHost) == UriHostNameType.Dns)
{
return uri.CloneWith(lastSeedHost);
}

return uri;
}

private async Task WaitWithJitterAsync()
{
var jitter = Random.Shared.NextDouble() * Opts.ReconnectJitter.TotalMilliseconds;
Expand Down
54 changes: 50 additions & 4 deletions src/NATS.Client.Core/NatsTlsOpts.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,39 @@
using NATS.Client.Core.Internal;

namespace NATS.Client.Core;

/// <summary>
/// TLS mode to use during connection.
/// </summary>
public enum TlsMode
{
/// <summary>
/// For connections that use the "nats://" scheme and don't supply Client or CA Certificates - same as <c>Prefer</c>
/// For connections that use the "tls://" scheme or supply Client or CA Certificates - same as <c>Require</c>
/// </summary>
Auto,

/// <summary>
/// if the Server supports TLS, then use it, otherwise use plain-text.
/// </summary>
Prefer,

/// <summary>
/// Forces the connection to upgrade to TLS. if the Server does not support TLS, then fail the connection.
/// </summary>
Require,

/// <summary>
/// Upgrades the connection to TLS as soon as the connection is established.
/// </summary>
Implicit,

/// <summary>
/// Disabled mode will not attempt to upgrade the connection to TLS.
/// </summary>
Disable,
}

/// <summary>
/// Immutable options for TlsOptions, you can configure via `with` operator.
/// These options are ignored in WebSocket connections
Expand All @@ -17,11 +51,23 @@ public sealed record NatsTlsOpts
/// <summary>Path to PEM-encoded X509 CA Certificate</summary>
public string? CaFile { get; init; }

/// <summary>When true, disable TLS</summary>
public bool Disabled { get; init; }

/// <summary>When true, skip remote certificate verification and accept any server certificate</summary>
public bool InsecureSkipVerify { get; init; }

internal bool Required => CertFile != default || KeyFile != default || CaFile != default;
/// <summary>TLS mode to use during connection</summary>
public TlsMode Mode { get; init; }

internal bool HasTlsFile => CertFile != default || KeyFile != default || CaFile != default;

internal TlsMode EffectiveMode(NatsUri uri) => Mode switch
{
TlsMode.Auto => HasTlsFile || uri.Uri.Scheme.ToLower() == "tls" ? TlsMode.Require : TlsMode.Prefer,
_ => Mode,
};

internal bool TryTls(NatsUri uri)
{
var effectiveMode = EffectiveMode(uri);
return effectiveMode is TlsMode.Require or TlsMode.Prefer;
}
}
13 changes: 6 additions & 7 deletions tests/NATS.Client.Core.Tests/NatsConnectionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,11 @@ await Retry.Until(
[Fact]
public async Task ReconnectSingleTest()
{
using var options = new NatsServerOpts
{
TransportType = _transportType,
EnableWebSocket = _transportType == TransportType.WebSocket,
ServerDisposeReturnsPorts = false,
};
var options = new NatsServerOptsBuilder()
.UseTransport(_transportType)
.WithServerDisposeReturnsPorts()
.Build();

await using var server = NatsServer.Start(_output, options);
var subject = Guid.NewGuid().ToString();

Expand Down Expand Up @@ -249,7 +248,7 @@ public async Task ReconnectClusterTest()
await connection3.ConnectAsync();

_output.WriteLine("Server1 ClientConnectUrls:" +
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server2 ClientConnectUrls:" +
string.Join(", ", connection2.ServerInfo?.ClientConnectUrls ?? Array.Empty<string>()));
_output.WriteLine("Server3 ClientConnectUrls:" +
Expand Down
Loading