diff --git a/NATS.Client.sln b/NATS.Client.sln index 2f0aff886..cf2ced273 100644 --- a/NATS.Client.sln +++ b/NATS.Client.sln @@ -73,6 +73,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "NATS.Client.KeyValueStore.T EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.KeyValueStore.Watcher", "sandbox\Example.KeyValueStore.Watcher\Example.KeyValueStore.Watcher.csproj", "{912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Example.TlsFirst", "sandbox\Example.TlsFirst\Example.TlsFirst.csproj", "{88625045-978F-417F-9F51-A4E3A9718945}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -183,6 +185,10 @@ Global {912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Debug|Any CPU.Build.0 = Debug|Any CPU {912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.ActiveCfg = Release|Any CPU {912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53}.Release|Any CPU.Build.0 = Release|Any CPU + {88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {88625045-978F-417F-9F51-A4E3A9718945}.Debug|Any CPU.Build.0 = Debug|Any CPU + {88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.ActiveCfg = Release|Any CPU + {88625045-978F-417F-9F51-A4E3A9718945}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -215,6 +221,7 @@ Global {A102AB7B-A90C-4717-B17C-045240838060} = {4827B3EC-73D8-436D-AE2A-5E29AC95FD0C} {908F2CED-CAC0-4A4E-AD19-362A413B5DA4} = {C526E8AB-739A-48D7-8FC4-048978C9B650} {912A4F2F-1BD1-4AE2-BAB8-5A49C221DB53} = {95A69671-16CA-4133-981C-CC381B7AAA30} + {88625045-978F-417F-9F51-A4E3A9718945} = {95A69671-16CA-4133-981C-CC381B7AAA30} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {8CBB7278-D093-448E-B3DE-B5991209A1AA} diff --git a/sandbox/Example.TlsFirst/Example.TlsFirst.csproj b/sandbox/Example.TlsFirst/Example.TlsFirst.csproj new file mode 100644 index 000000000..e4a0b5983 --- /dev/null +++ b/sandbox/Example.TlsFirst/Example.TlsFirst.csproj @@ -0,0 +1,15 @@ + + + + Exe + net6.0 + enable + enable + false + + + + + + + diff --git a/sandbox/Example.TlsFirst/Program.cs b/sandbox/Example.TlsFirst/Program.cs new file mode 100644 index 000000000..8a4305799 --- /dev/null +++ b/sandbox/Example.TlsFirst/Program.cs @@ -0,0 +1,7 @@ +using NATS.Client.Core; + +// await using var nats = new NatsConnection(); +await using var nats = new NatsConnection(NatsOpts.Default with { TlsOpts = new NatsTlsOpts { Mode = TlsMode.Implicit, InsecureSkipVerify = true, } }); +await nats.ConnectAsync(); +var timeSpan = await nats.PingAsync(); +Console.WriteLine($"{timeSpan}"); diff --git a/src/NATS.Client.Core/Internal/NatsUri.cs b/src/NATS.Client.Core/Internal/NatsUri.cs index f792b8089..64d3ea4c2 100644 --- a/src/NATS.Client.Core/Internal/NatsUri.cs +++ b/src/NATS.Client.Core/Internal/NatsUri.cs @@ -52,6 +52,17 @@ public NatsUri(string urlString, bool isSeed, string defaultScheme = DefaultSche public int Port => Uri.Port; + public NatsUri CloneWith(string host, int? port = default) + { + var newUri = new UriBuilder(Uri) + { + Host = host, + Port = port ?? Port, + }.Uri.ToString(); + + return new NatsUri(newUri, IsSeed); + } + public override string ToString() { return IsWebSocket && Uri.AbsolutePath != "/" ? Uri.ToString() : Uri.ToString().Trim('/'); diff --git a/src/NATS.Client.Core/Internal/SslStreamConnection.cs b/src/NATS.Client.Core/Internal/SslStreamConnection.cs index b4ab8e0bf..71016ed3e 100644 --- a/src/NATS.Client.Core/Internal/SslStreamConnection.cs +++ b/src/NATS.Client.Core/Internal/SslStreamConnection.cs @@ -69,10 +69,17 @@ public void SignalDisconnected(Exception exception) _waitForClosedSource.TrySetResult(exception); } - public async Task AuthenticateAsClientAsync(string target) + public async Task AuthenticateAsClientAsync(NatsUri uri) { - var options = SslClientAuthenticationOptions(target); - await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false); + var options = SslClientAuthenticationOptions(uri); + try + { + await _sslStream.AuthenticateAsClientAsync(options).ConfigureAwait(false); + } + catch (AuthenticationException ex) + { + throw new NatsException($"TLS authentication failed", ex); + } } private static X509Certificate LcsCbClientCerts( @@ -123,11 +130,11 @@ private bool RcsCbCaCertChain( return sslPolicyErrors == SslPolicyErrors.None; } - private SslClientAuthenticationOptions SslClientAuthenticationOptions(string targetHost) + private SslClientAuthenticationOptions SslClientAuthenticationOptions(NatsUri uri) { - if (_tlsOpts.Disabled) + if (_tlsOpts.EffectiveMode(uri) == TlsMode.Disable) { - throw new InvalidOperationException("TLS is not permitted when TlsOptions.Disabled is set"); + throw new InvalidOperationException("TLS is not permitted when TlsMode is set to Disable"); } LocalCertificateSelectionCallback? lcsCb = default; @@ -148,7 +155,7 @@ private SslClientAuthenticationOptions SslClientAuthenticationOptions(string tar var options = new SslClientAuthenticationOptions { - TargetHost = targetHost, + TargetHost = uri.Host, EnabledSslProtocols = SslProtocols.Tls12, ClientCertificates = _tlsCerts?.ClientCerts, LocalCertificateSelectionCallback = lcsCb, diff --git a/src/NATS.Client.Core/Internal/TlsCerts.cs b/src/NATS.Client.Core/Internal/TlsCerts.cs index b9024c249..65e41b00a 100644 --- a/src/NATS.Client.Core/Internal/TlsCerts.cs +++ b/src/NATS.Client.Core/Internal/TlsCerts.cs @@ -6,7 +6,7 @@ internal class TlsCerts { public TlsCerts(NatsTlsOpts tlsOpts) { - if (tlsOpts.Disabled) + if (tlsOpts.Mode == TlsMode.Disable) { return; } diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 7e036dbae..dce8b9d88 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -42,8 +42,8 @@ public partial class NatsConnection : IAsyncDisposable, INatsConnection // when reconnect, make new instance. private ISocketConnection? _socket; private CancellationTokenSource? _pingTimerCancellationTokenSource; - private NatsUri? _currentConnectUri; - private NatsUri? _lastSeedConnectUri; + private volatile NatsUri? _currentConnectUri; + private volatile NatsUri? _lastSeedConnectUri; private NatsReadProtocolProcessor? _socketReader; private NatsPipeliningWriteProtocolProcessor? _socketWriter; private TaskCompletionSource _waitForOpenConnection; @@ -222,9 +222,14 @@ private async ValueTask InitialConnectAsync() Debug.Assert(ConnectionState == NatsConnectionState.Connecting, "Connection state"); var uris = Opts.GetSeedUris(); - if (Opts.TlsOpts.Disabled && uris.Any(u => u.IsTls)) - throw new NatsException($"URI {uris.First(u => u.IsTls)} requires TLS but NatsTlsOpts.Disabled is set to true"); - if (Opts.TlsOpts.Required) + + foreach (var uri in uris) + { + if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Disable && uri.IsTls) + throw new NatsException($"URI {uri} requires TLS but TlsMode is set to Disable"); + } + + if (Opts.TlsOpts.HasTlsFile) _tlsCerts = new TlsCerts(Opts.TlsOpts); if (!Opts.AuthOpts.IsAnonymous) @@ -255,6 +260,14 @@ private async ValueTask InitialConnectAsync() var conn = new TcpConnection(); await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false); _socket = conn; + + if (Opts.TlsOpts.EffectiveMode(uri) == TlsMode.Implicit) + { + // upgrade TcpConnection to SslConnection + var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts); + await sslConnection.AuthenticateAsClientAsync(uri).ConfigureAwait(false); + _socket = sslConnection; + } } _currentConnectUri = uri; @@ -331,32 +344,24 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) // check to see if we should upgrade to TLS if (_socket is TcpConnection tcpConnection) { - if (Opts.TlsOpts.Disabled && WritableServerInfo!.TlsRequired) + if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Disable && WritableServerInfo!.TlsRequired) { throw new NatsException( - $"Server {_currentConnectUri} requires TLS but NatsTlsOpts.Disabled is set to true"); + $"Server {_currentConnectUri} requires TLS but TlsMode is set to Disable"); } - if (Opts.TlsOpts.Required && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable) + if (Opts.TlsOpts.EffectiveMode(_currentConnectUri) == TlsMode.Require && !WritableServerInfo!.TlsRequired && !WritableServerInfo.TlsAvailable) { throw new NatsException( - $"Server {_currentConnectUri} does not support TLS but NatsTlsOpts.Disabled is set to true"); + $"Server {_currentConnectUri} does not support TLS but TlsMode is set to Require"); } - if (Opts.TlsOpts.Required || WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable) + if (Opts.TlsOpts.TryTls(_currentConnectUri) && (WritableServerInfo!.TlsRequired || WritableServerInfo.TlsAvailable)) { // do TLS upgrade - // if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the - // last seed hostname if it was a DNS hostname - var targetHost = _currentConnectUri.Host; - if (!_currentConnectUri.IsSeed - && Uri.CheckHostName(targetHost) != UriHostNameType.Dns - && Uri.CheckHostName(_lastSeedConnectUri!.Host) == UriHostNameType.Dns) - { - targetHost = _lastSeedConnectUri.Host; - } + var targetUri = FixTlsHost(_currentConnectUri); - _logger.LogDebug("Perform TLS Upgrade to " + targetHost); + _logger.LogDebug("Perform TLS Upgrade to " + targetUri); // cancel INFO parsed signal and dispose current socket reader infoParsedSignal.SetCanceled(); @@ -365,7 +370,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) // upgrade TcpConnection to SslConnection var sslConnection = tcpConnection.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts); - await sslConnection.AuthenticateAsClientAsync(targetHost).ConfigureAwait(false); + await sslConnection.AuthenticateAsClientAsync(targetUri).ConfigureAwait(false); _socket = sslConnection; // create new socket reader @@ -452,11 +457,17 @@ private async void ReconnectLoop() if (urlEnumerator.MoveNext()) { url = urlEnumerator.Current; - var target = (url.Host, url.Port); + if (OnConnectingAsync != null) { + var target = (url.Host, url.Port); _logger.LogInformation("Try to invoke OnConnectingAsync before connect to NATS."); - target = await OnConnectingAsync(target).ConfigureAwait(false); + var newTarget = await OnConnectingAsync(target).ConfigureAwait(false); + + if (newTarget.Host != target.Host || newTarget.Port != target.Port) + { + url = url.CloneWith(newTarget.Host, newTarget.Port); + } } _logger.LogInformation("Try to connect NATS {0}", url); @@ -469,8 +480,16 @@ private async void ReconnectLoop() else { var conn = new TcpConnection(); - await conn.ConnectAsync(target.Host, target.Port, Opts.ConnectTimeout).ConfigureAwait(false); + await conn.ConnectAsync(url.Host, url.Port, Opts.ConnectTimeout).ConfigureAwait(false); _socket = conn; + + if (Opts.TlsOpts.EffectiveMode(url) == TlsMode.Implicit) + { + // upgrade TcpConnection to SslConnection + var sslConnection = conn.UpgradeToSslStreamConnection(Opts.TlsOpts, _tlsCerts); + await sslConnection.AuthenticateAsClientAsync(FixTlsHost(url)).ConfigureAwait(false); + _socket = sslConnection; + } } _currentConnectUri = url; @@ -515,6 +534,26 @@ private async void ReconnectLoop() } } + private NatsUri FixTlsHost(NatsUri uri) + { + var lastSeedConnectUri = _lastSeedConnectUri; + var lastSeedHost = lastSeedConnectUri?.Host; + + if (string.IsNullOrEmpty(lastSeedHost)) + return uri; + + // if the current URI is not a seed URI and is not a DNS hostname, check the server cert against the + // last seed hostname if it was a DNS hostname + if (!uri.IsSeed + && Uri.CheckHostName(uri.Host) != UriHostNameType.Dns + && Uri.CheckHostName(lastSeedHost) == UriHostNameType.Dns) + { + return uri.CloneWith(lastSeedHost); + } + + return uri; + } + private async Task WaitWithJitterAsync() { var jitter = Random.Shared.NextDouble() * Opts.ReconnectJitter.TotalMilliseconds; diff --git a/src/NATS.Client.Core/NatsTlsOpts.cs b/src/NATS.Client.Core/NatsTlsOpts.cs index f23a9fc18..6a386cb9f 100644 --- a/src/NATS.Client.Core/NatsTlsOpts.cs +++ b/src/NATS.Client.Core/NatsTlsOpts.cs @@ -1,5 +1,39 @@ +using NATS.Client.Core.Internal; + namespace NATS.Client.Core; +/// +/// TLS mode to use during connection. +/// +public enum TlsMode +{ + /// + /// For connections that use the "nats://" scheme and don't supply Client or CA Certificates - same as Prefer + /// For connections that use the "tls://" scheme or supply Client or CA Certificates - same as Require + /// + Auto, + + /// + /// if the Server supports TLS, then use it, otherwise use plain-text. + /// + Prefer, + + /// + /// Forces the connection to upgrade to TLS. if the Server does not support TLS, then fail the connection. + /// + Require, + + /// + /// Upgrades the connection to TLS as soon as the connection is established. + /// + Implicit, + + /// + /// Disabled mode will not attempt to upgrade the connection to TLS. + /// + Disable, +} + /// /// Immutable options for TlsOptions, you can configure via `with` operator. /// These options are ignored in WebSocket connections @@ -17,11 +51,23 @@ public sealed record NatsTlsOpts /// Path to PEM-encoded X509 CA Certificate public string? CaFile { get; init; } - /// When true, disable TLS - public bool Disabled { get; init; } - /// When true, skip remote certificate verification and accept any server certificate public bool InsecureSkipVerify { get; init; } - internal bool Required => CertFile != default || KeyFile != default || CaFile != default; + /// TLS mode to use during connection + public TlsMode Mode { get; init; } + + internal bool HasTlsFile => CertFile != default || KeyFile != default || CaFile != default; + + internal TlsMode EffectiveMode(NatsUri uri) => Mode switch + { + TlsMode.Auto => HasTlsFile || uri.Uri.Scheme.ToLower() == "tls" ? TlsMode.Require : TlsMode.Prefer, + _ => Mode, + }; + + internal bool TryTls(NatsUri uri) + { + var effectiveMode = EffectiveMode(uri); + return effectiveMode is TlsMode.Require or TlsMode.Prefer; + } } diff --git a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs index 98c1407d7..b31aeee8b 100644 --- a/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs +++ b/tests/NATS.Client.Core.Tests/NatsConnectionTest.cs @@ -150,12 +150,11 @@ await Retry.Until( [Fact] public async Task ReconnectSingleTest() { - using var options = new NatsServerOpts - { - TransportType = _transportType, - EnableWebSocket = _transportType == TransportType.WebSocket, - ServerDisposeReturnsPorts = false, - }; + var options = new NatsServerOptsBuilder() + .UseTransport(_transportType) + .WithServerDisposeReturnsPorts() + .Build(); + await using var server = NatsServer.Start(_output, options); var subject = Guid.NewGuid().ToString(); @@ -249,7 +248,7 @@ public async Task ReconnectClusterTest() await connection3.ConnectAsync(); _output.WriteLine("Server1 ClientConnectUrls:" + - string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty())); + string.Join(", ", connection1.ServerInfo?.ClientConnectUrls ?? Array.Empty())); _output.WriteLine("Server2 ClientConnectUrls:" + string.Join(", ", connection2.ServerInfo?.ClientConnectUrls ?? Array.Empty())); _output.WriteLine("Server3 ClientConnectUrls:" + diff --git a/tests/NATS.Client.Core.Tests/TlsFirstTest.cs b/tests/NATS.Client.Core.Tests/TlsFirstTest.cs new file mode 100644 index 000000000..605c30d84 --- /dev/null +++ b/tests/NATS.Client.Core.Tests/TlsFirstTest.cs @@ -0,0 +1,50 @@ +namespace NATS.Client.Core.Tests; + +public class TlsFirstTest +{ + private readonly ITestOutputHelper _output; + + public TlsFirstTest(ITestOutputHelper output) => _output = output; + + [Fact] + public async Task Tls_first_connection() + { + if (!NatsServer.SupportsTlsFirst()) + { + _output.WriteLine($"TLS first is NOT supported by the server"); + return; + } + + _output.WriteLine($"TLS first is supported by the server"); + + await using var server = NatsServer.Start( + new NullOutputHelper(), + new NatsServerOptsBuilder() + .UseTransport(TransportType.Tls, tlsFirst: true) + .Build()); + + var clientOpts = server.ClientOpts(NatsOpts.Default); + + Assert.True(clientOpts.TlsOpts.Mode == TlsMode.Implicit); + + // TLS first connection + { + await using var nats = new NatsConnection(clientOpts); + await nats.ConnectAsync(); + var rtt = await nats.PingAsync(); + Assert.True(rtt > TimeSpan.Zero); + _output.WriteLine($"Implicit TLS connection (RTT: {rtt})"); + } + + // Normal TLS connection should fail + { + await using var nats = new NatsConnection(clientOpts with { TlsOpts = clientOpts.TlsOpts with { Mode = TlsMode.Auto } }); + + var exception = await Assert.ThrowsAsync(async () => await nats.ConnectAsync()); + + Assert.Matches(@"can not start to connect nats server: tls://", exception.Message); + + _output.WriteLine($"Auto TLS connection rejected"); + } + } +} diff --git a/tests/NATS.Client.TestUtilities/NatsServer.cs b/tests/NATS.Client.TestUtilities/NatsServer.cs index 1805bd67b..dff9f1115 100644 --- a/tests/NATS.Client.TestUtilities/NatsServer.cs +++ b/tests/NATS.Client.TestUtilities/NatsServer.cs @@ -67,10 +67,7 @@ private NatsServer(ITestOutputHelper outputHelper, NatsServerOpts opts) opts.JetStreamStoreDir = _jetStreamStoreDir; } - _configFileName = Path.GetTempFileName(); - var config = opts.ConfigFileContents; - File.WriteAllText(_configFileName, config); - var cmd = $"{NatsServerPath} -c {_configFileName}"; + (_configFileName, var config, var cmd) = GetCmd(opts); outputHelper.WriteLine("ProcessStart: " + cmd + Environment.NewLine + config); var (p, stdout, stderr) = ProcessX.GetDualAsyncEnumerable(cmd); @@ -159,6 +156,62 @@ public static NatsServer StartJS(ITestOutputHelper outputHelper, TransportType t public static NatsServer Start() => Start(new NullOutputHelper(), TransportType.Tcp); + public static bool SupportsTlsFirst() + { + var (configFileName, _, _) = GetCmd(new NatsServerOptsBuilder().UseTransport(TransportType.Tls, tlsFirst: true).Build()); + + Process? process = null; + try + { + process = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = NatsServerPath, + Arguments = $"-c \"{configFileName}\"", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + }, + }; + + var mre = new ManualResetEventSlim(); + var matched = 0; + DataReceivedEventHandler? handler = (_, e) => + { + if (e.Data != null) + { + if (Regex.IsMatch(e.Data, @"(?:\[INF\] Server is ready|error parsing)")) + { + mre.Set(); + } + + if (Regex.IsMatch(e.Data, @"Clients that are not using ""TLS Handshake First"" option will fail to connect")) + { + Interlocked.Increment(ref matched); + } + } + }; + process.OutputDataReceived += handler; + process.ErrorDataReceived += handler; + + process.Start(); + process.BeginOutputReadLine(); + process.BeginErrorReadLine(); + + if (!mre.Wait(10_000)) + { + throw new Exception("Can't start nats-server"); + } + + return Volatile.Read(ref matched) > 0; + } + finally + { + process?.Kill(); + } + } + public static NatsServer StartWithTrace(ITestOutputHelper outputHelper) => Start( outputHelper, @@ -309,6 +362,18 @@ public NatsConnectionPool CreatePooledClientConnection(NatsOpts opts) public NatsOpts ClientOpts(NatsOpts opts) { + var tls = opts.TlsOpts ?? NatsTlsOpts.Default; + + var natsTlsOpts = Opts.EnableTls + ? tls with + { + CertFile = Opts.TlsClientCertFile, + KeyFile = Opts.TlsClientKeyFile, + CaFile = Opts.TlsCaFile, + Mode = Opts.TlsFirst ? TlsMode.Implicit : TlsMode.Auto, + } + : NatsTlsOpts.Default; + return opts with { LoggerFactory = _loggerFactory, @@ -316,14 +381,7 @@ public NatsOpts ClientOpts(NatsOpts opts) // ConnectTimeout = TimeSpan.FromSeconds(1), // ReconnectWait = TimeSpan.Zero, // ReconnectJitter = TimeSpan.Zero, - TlsOpts = Opts.EnableTls - ? NatsTlsOpts.Default with - { - CertFile = Opts.TlsClientCertFile, - KeyFile = Opts.TlsClientKeyFile, - CaFile = Opts.TlsCaFile, - } - : NatsTlsOpts.Default, + TlsOpts = natsTlsOpts, Url = ClientUrl, }; } @@ -354,6 +412,18 @@ public void LogMessage(string categoryName, LogLevel logLevel, EventId e } } + private static (string configFileName, string config, string cmd) GetCmd(NatsServerOpts opts) + { + var configFileName = Path.GetTempFileName(); + + var config = opts.ConfigFileContents; + File.WriteAllText(configFileName, config); + + var cmd = $"{NatsServerPath} -c {configFileName}"; + + return (configFileName, config, cmd); + } + private async Task EnumerateWithLogsAsync(ProcessAsyncEnumerable enumerable, CancellationToken cancellationToken) { var l = new List(); @@ -384,25 +454,45 @@ public class NatsCluster : IAsyncDisposable { public NatsCluster(ITestOutputHelper outputHelper, TransportType transportType) { - var opts1 = new NatsServerOpts + var opts1 = new NatsServerOptsBuilder() + .UseTransport(transportType) + .EnableClustering() + .Build(); + + var opts2 = new NatsServerOptsBuilder() + .UseTransport(transportType) + .EnableClustering() + .Build(); + + var opts3 = new NatsServerOptsBuilder() + .UseTransport(transportType) + .EnableClustering() + .Build(); + + // By querying the ports we set the values lazily on all the opts. + outputHelper.WriteLine($"opts1.ServerPort={opts1.ServerPort}"); + outputHelper.WriteLine($"opts1.ClusteringPort={opts1.ClusteringPort}"); + if (opts1.EnableWebSocket) { - TransportType = transportType, - EnableWebSocket = transportType == TransportType.WebSocket, - EnableClustering = true, - }; - var opts2 = new NatsServerOpts + outputHelper.WriteLine($"opts1.WebSocketPort={opts1.WebSocketPort}"); + } + + outputHelper.WriteLine($"opts2.ServerPort={opts2.ServerPort}"); + outputHelper.WriteLine($"opts2.ClusteringPort={opts2.ClusteringPort}"); + if (opts2.EnableWebSocket) { - TransportType = transportType, - EnableWebSocket = transportType == TransportType.WebSocket, - EnableClustering = true, - }; - var opts3 = new NatsServerOpts + outputHelper.WriteLine($"opts2.WebSocketPort={opts2.WebSocketPort}"); + } + + outputHelper.WriteLine($"opts3.ServerPort={opts3.ServerPort}"); + outputHelper.WriteLine($"opts3.ClusteringPort={opts3.ClusteringPort}"); + if (opts3.EnableWebSocket) { - TransportType = transportType, - EnableWebSocket = transportType == TransportType.WebSocket, - EnableClustering = true, - }; + outputHelper.WriteLine($"opts3.WebSocketPort={opts3.WebSocketPort}"); + } + var routes = new[] { opts1, opts2, opts3 }; + foreach (var opt in routes) { opt.SetRoutes(routes); diff --git a/tests/NATS.Client.TestUtilities/NatsServerOpts.cs b/tests/NATS.Client.TestUtilities/NatsServerOpts.cs index 7d16e1b72..913186cae 100644 --- a/tests/NATS.Client.TestUtilities/NatsServerOpts.cs +++ b/tests/NATS.Client.TestUtilities/NatsServerOpts.cs @@ -16,27 +16,42 @@ public sealed class NatsServerOptsBuilder private readonly List _extraConfigs = new(); private bool _enableWebSocket; private bool _enableTls; + private bool _tlsFirst; private bool _enableJetStream; private string? _tlsServerCertFile; private string? _tlsServerKeyFile; private string? _tlsCaFile; private TransportType? _transportType; + private bool _serverDisposeReturnsPorts; + private bool _enableClustering; private bool _trace; - public NatsServerOpts Build() + public NatsServerOpts Build() => new() { - return new NatsServerOpts - { - EnableWebSocket = _enableWebSocket, - EnableTls = _enableTls, - EnableJetStream = _enableJetStream, - TlsServerCertFile = _tlsServerCertFile, - TlsServerKeyFile = _tlsServerKeyFile, - TlsCaFile = _tlsCaFile, - ExtraConfigs = _extraConfigs, - TransportType = _transportType ?? TransportType.Tcp, - Trace = _trace, - }; + EnableWebSocket = _enableWebSocket, + EnableTls = _enableTls, + TlsFirst = _tlsFirst, + EnableJetStream = _enableJetStream, + TlsServerCertFile = _tlsServerCertFile, + TlsServerKeyFile = _tlsServerKeyFile, + TlsCaFile = _tlsCaFile, + ExtraConfigs = _extraConfigs, + TransportType = _transportType ?? TransportType.Tcp, + ServerDisposeReturnsPorts = _serverDisposeReturnsPorts, + EnableClustering = _enableClustering, + Trace = _trace, + }; + + public NatsServerOptsBuilder EnableClustering() + { + _enableClustering = true; + return this; + } + + public NatsServerOptsBuilder WithServerDisposeReturnsPorts() + { + _serverDisposeReturnsPorts = true; + return this; } public NatsServerOptsBuilder Trace() @@ -45,16 +60,22 @@ public NatsServerOptsBuilder Trace() return this; } - public NatsServerOptsBuilder UseTransport(TransportType transportType) + public NatsServerOptsBuilder UseTransport(TransportType transportType, bool tlsFirst = false) { _transportType = transportType; + if (transportType != TransportType.Tls && tlsFirst) + { + throw new Exception("tlsFirst is only valid for TLS transport"); + } + if (transportType == TransportType.Tls) { _enableTls = true; _tlsServerCertFile = "resources/certs/server-cert.pem"; _tlsServerKeyFile = "resources/certs/server-key.pem"; _tlsCaFile = "resources/certs/ca-cert.pem"; + _tlsFirst = tlsFirst; } else if (transportType == TransportType.WebSocket) { @@ -130,6 +151,8 @@ public NatsServerOpts() public string? TlsCaFile { get; init; } + public bool TlsFirst { get; init; } = false; + public TransportType TransportType { get; init; } public bool Trace { get; init; } @@ -152,6 +175,7 @@ public string ConfigFileContents if (Trace) { sb.AppendLine($"trace: true"); + sb.AppendLine($"debug: true"); } if (EnableWebSocket) @@ -186,6 +210,11 @@ public string ConfigFileContents sb.AppendLine($" ca_file: {TlsCaFile}"); } + if (TlsFirst) + { + sb.AppendLine($" handshake_first: true"); + } + sb.AppendLine("}"); }