Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2785: Support optionally disabling pipelining on AUTH flow #2787

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ The `ConfigurationOptions` object has a wide range of properties, all of which a
| setlib={bool} | `SetClientLibrary` | `true` | Whether to attempt to use `CLIENT SETINFO` to set the library name/version on the connection |
| protocol={string} | `Protocol` | `null` | Redis protocol to use; see section below |
| highIntegrity={bool} | `HighIntegrity` | `false` | High integrity (incurs overhead) sequence checking on every command; see section below |
| waitForAuth={bool} | `WaitForAuth` | `false` | Wait before the result of the `AUTH` command is returned before trying to send any other commands to the server |

Additional code-only options:
- LoggerFactory (`ILoggerFactory`) - Default: `null`
Expand Down
11 changes: 11 additions & 0 deletions src/StackExchange.Redis/Configuration/DefaultOptionsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ public static DefaultOptionsProvider GetProvider(EndPoint endpoint)
/// </remarks>
public virtual bool HighIntegrity => false;

/// <summary>
/// A Boolean value that specifies whether the client should wait for the server to return
/// response for the initial AUTH command before trying any further commands.
/// </summary>
/// <remarks>
/// This is especially useful when connecting to Envoy proxies with external authentication
/// providers.
/// The default and recommended value is false.
/// </remarks>
public virtual bool WaitForAuth => false;

/// <summary>
/// The number of times to repeat the initial connect cycle if no servers respond promptly.
/// </summary>
Expand Down
28 changes: 25 additions & 3 deletions src/StackExchange.Redis/ConfigurationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ internal const string
Tunnel = "tunnel",
SetClientLibrary = "setlib",
Protocol = "protocol",
HighIntegrity = "highIntegrity";
HighIntegrity = "highIntegrity",
WaitForAuth = "waitForAuth";

private static readonly Dictionary<string, string> normalizedOptions = new[]
{
Expand Down Expand Up @@ -143,6 +144,7 @@ internal const string
CheckCertificateRevocation,
Protocol,
HighIntegrity,
WaitForAuth,
}.ToDictionary(x => x, StringComparer.OrdinalIgnoreCase);

public static string TryNormalize(string value)
Expand All @@ -158,7 +160,7 @@ public static string TryNormalize(string value)
private DefaultOptionsProvider? defaultOptions;

private bool? allowAdmin, abortOnConnectFail, resolveDns, ssl, checkCertificateRevocation, heartbeatConsistencyChecks,
includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity;
includeDetailInExceptions, includePerformanceCountersInExceptions, setClientLibrary, highIntegrity, waitForAuth;

private string? tieBreaker, sslHost, configChannel, user, password;

Expand Down Expand Up @@ -295,6 +297,21 @@ public bool HighIntegrity
set => highIntegrity = value;
}

/// <summary>
/// A Boolean value that specifies whether the client should wait for the server to return
/// response for the initial AUTH command before trying any further commands.
/// </summary>
/// <remarks>
/// This is especially useful when connecting to Envoy proxies with external authentication
/// providers.
/// The default and recommended value is false.
/// </remarks>
public bool WaitForAuth
{
get => waitForAuth ?? Defaults.WaitForAuth;
set => waitForAuth = value;
}

/// <summary>
/// Create a certificate validation check that checks against the supplied issuer even when not known by the machine.
/// </summary>
Expand Down Expand Up @@ -786,6 +803,7 @@ public static ConfigurationOptions Parse(string configuration, bool ignoreUnknow
heartbeatInterval = heartbeatInterval,
heartbeatConsistencyChecks = heartbeatConsistencyChecks,
highIntegrity = highIntegrity,
waitForAuth = waitForAuth,
};

/// <summary>
Expand Down Expand Up @@ -867,6 +885,7 @@ public string ToString(bool includePassword)
Append(sb, OptionKeys.DefaultDatabase, DefaultDatabase);
Append(sb, OptionKeys.SetClientLibrary, setClientLibrary);
Append(sb, OptionKeys.HighIntegrity, highIntegrity);
Append(sb, OptionKeys.WaitForAuth, waitForAuth);
Append(sb, OptionKeys.Protocol, FormatProtocol(Protocol));
if (Tunnel is { IsInbuilt: true } tunnel)
{
Expand Down Expand Up @@ -912,7 +931,7 @@ private void Clear()
{
ClientName = ServiceName = user = password = tieBreaker = sslHost = configChannel = null;
keepAlive = syncTimeout = asyncTimeout = connectTimeout = connectRetry = configCheckSeconds = DefaultDatabase = null;
allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = null;
allowAdmin = abortOnConnectFail = resolveDns = ssl = setClientLibrary = highIntegrity = waitForAuth = null;
SslProtocols = null;
defaultVersion = null;
EndPoints.Clear();
Expand Down Expand Up @@ -1034,6 +1053,9 @@ private ConfigurationOptions DoParse(string configuration, bool ignoreUnknown)
case OptionKeys.HighIntegrity:
HighIntegrity = OptionKeys.ParseBoolean(key, value);
break;
case OptionKeys.WaitForAuth:
WaitForAuth = OptionKeys.ParseBoolean(key, value);
break;
case OptionKeys.Tunnel:
if (value.IsNullOrWhiteSpace())
{
Expand Down
3 changes: 3 additions & 0 deletions src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.get -> System.TimeSpa
StackExchange.Redis.ConfigurationOptions.HeartbeatInterval.set -> void
StackExchange.Redis.ConfigurationOptions.HighIntegrity.get -> bool
StackExchange.Redis.ConfigurationOptions.HighIntegrity.set -> void
StackExchange.Redis.ConfigurationOptions.WaitForAuth.get -> bool
StackExchange.Redis.ConfigurationOptions.WaitForAuth.set -> void
StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.get -> bool
StackExchange.Redis.ConfigurationOptions.HighPrioritySocketThreads.set -> void
StackExchange.Redis.ConfigurationOptions.IncludeDetailInExceptions.get -> bool
Expand Down Expand Up @@ -1846,6 +1848,7 @@ virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.GetSslHostFromE
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatConsistencyChecks.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HeartbeatInterval.get -> System.TimeSpan
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.HighIntegrity.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.WaitForAuth.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludeDetailInExceptions.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IncludePerformanceCountersInExceptions.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.IsMatch(System.Net.EndPoint! endpoint) -> bool
Expand Down
24 changes: 18 additions & 6 deletions src/StackExchange.Redis/ServerEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -990,16 +990,14 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log)
if (!string.IsNullOrWhiteSpace(user) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH))
{
log?.LogInformation($"{Format.ToString(this)}: Authenticating (user/password)");
msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password);
msg.SetInternalCall();
await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait();
msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)user, (RedisValue)password);
await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait();
}
else if (!string.IsNullOrWhiteSpace(password) && Multiplexer.CommandMap.IsAvailable(RedisCommand.AUTH))
{
log?.LogInformation($"{Format.ToString(this)}: Authenticating (password)");
msg = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.AUTH, (RedisValue)password);
msg.SetInternalCall();
await WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK).ForAwait();
msg = Message.Create(-1, CommandFlags.None, RedisCommand.AUTH, (RedisValue)password);
await SendAuthMessageAsync(connection, msg, ResultProcessor.DemandOK).ForAwait();
}

if (Multiplexer.CommandMap.IsAvailable(RedisCommand.CLIENT))
Expand Down Expand Up @@ -1073,6 +1071,20 @@ private async Task HandshakeAsync(PhysicalConnection connection, ILogger? log)
await connection.FlushAsync().ForAwait();
}

private ValueTask SendAuthMessageAsync(PhysicalConnection connection, Message msg, ResultProcessor<bool> demandOK)
{
if (Multiplexer.RawConfig.WaitForAuth)
{
return new ValueTask(WriteDirectAsync(msg, ResultProcessor.DemandOK));
}
else
{
msg.Flags = CommandFlags.FireAndForget;
msg.SetInternalCall();
return WriteDirectOrQueueFireAndForgetAsync(connection, msg, ResultProcessor.DemandOK);
}
}

private void SetConfig<T>(ref T field, T value, [CallerMemberName] string? caller = null)
{
if (!EqualityComparer<T>.Default.Equals(field, value))
Expand Down
21 changes: 21 additions & 0 deletions tests/StackExchange.Redis.Tests/ConfigTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public void ExpectedFields()
"tieBreaker",
"Tunnel",
"user",
"waitForAuth",
},
fields);
}
Expand Down Expand Up @@ -811,4 +812,24 @@ public void CheckHighIntegrity(bool? assigned, bool expected, string cs)
var parsed = ConfigurationOptions.Parse(cs);
Assert.Equal(expected, options.HighIntegrity);
}

[Theory]
[InlineData(null, false, "dummy")]
[InlineData(false, false, "dummy,waitForAuth=False")]
[InlineData(true, true, "dummy,waitForAuth=True")]
public void CheckWaitForAuth(bool? assigned, bool expected, string cs)
{
var options = ConfigurationOptions.Parse("dummy");
if (assigned.HasValue) options.WaitForAuth = assigned.Value;

Assert.Equal(expected, options.WaitForAuth);
Assert.Equal(cs, options.ToString());

var clone = options.Clone();
Assert.Equal(expected, clone.WaitForAuth);
Assert.Equal(cs, clone.ToString());

var parsed = ConfigurationOptions.Parse(cs);
Assert.Equal(expected, options.WaitForAuth);
}
}
4 changes: 4 additions & 0 deletions tests/StackExchange.Redis.Tests/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ internal virtual IInternalConnectionMultiplexer Create(
BacklogPolicy? backlogPolicy = null,
Version? require = null,
RedisProtocol? protocol = null,
bool? waitForAuth = null,
[CallerMemberName] string caller = "")
{
if (Output == null)
Expand Down Expand Up @@ -314,6 +315,7 @@ internal virtual IInternalConnectionMultiplexer Create(
backlogPolicy,
protocol,
highIntegrity,
waitForAuth,
caller);

ThrowIfIncorrectProtocol(conn, protocol);
Expand Down Expand Up @@ -409,6 +411,7 @@ public static ConnectionMultiplexer CreateDefault(
BacklogPolicy? backlogPolicy = null,
RedisProtocol? protocol = null,
bool highIntegrity = false,
bool? waitForAuth = null,
[CallerMemberName] string caller = "")
{
StringWriter? localLog = null;
Expand Down Expand Up @@ -445,6 +448,7 @@ public static ConnectionMultiplexer CreateDefault(
if (backlogPolicy is not null) config.BacklogPolicy = backlogPolicy;
if (protocol is not null) config.Protocol = protocol;
if (highIntegrity) config.HighIntegrity = highIntegrity;
if (waitForAuth is not null) config.WaitForAuth = waitForAuth.Value;
var watch = Stopwatch.StartNew();
var task = ConnectionMultiplexer.ConnectAsync(config, log);
if (!task.Wait(config.ConnectTimeout >= (int.MaxValue / 2) ? int.MaxValue : config.ConnectTimeout * 2))
Expand Down
Loading