Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for AspNetCore 7 rate limiting #1967

Merged
merged 10 commits into from
Jan 13, 2023
64 changes: 32 additions & 32 deletions src/ReverseProxy/Configuration/ConfigValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ internal sealed class ConfigValidator : IConfigValidator

private readonly ITransformBuilder _transformBuilder;
private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider;
#if NET7_0_OR_GREATER
private readonly IYarpRateLimiterPolicyProvider _rateLimiterPolicyProvider;
#endif
private readonly ICorsPolicyProvider _corsPolicyProvider;
private readonly IDictionary<string, ILoadBalancingPolicy> _loadBalancingPolicies;
private readonly IDictionary<string, IAffinityFailurePolicy> _affinityFailurePolicies;
Expand All @@ -40,6 +43,9 @@ internal sealed class ConfigValidator : IConfigValidator

public ConfigValidator(ITransformBuilder transformBuilder,
IAuthorizationPolicyProvider authorizationPolicyProvider,
#if NET7_0_OR_GREATER
IYarpRateLimiterPolicyProvider rateLimiterPolicyProvider,
#endif
ICorsPolicyProvider corsPolicyProvider,
IEnumerable<ILoadBalancingPolicy> loadBalancingPolicies,
IEnumerable<IAffinityFailurePolicy> affinityFailurePolicies,
Expand All @@ -50,6 +56,9 @@ public ConfigValidator(ITransformBuilder transformBuilder,
{
_transformBuilder = transformBuilder ?? throw new ArgumentNullException(nameof(transformBuilder));
_authorizationPolicyProvider = authorizationPolicyProvider ?? throw new ArgumentNullException(nameof(authorizationPolicyProvider));
#if NET7_0_OR_GREATER
_rateLimiterPolicyProvider = rateLimiterPolicyProvider ?? throw new ArgumentNullException(nameof(rateLimiterPolicyProvider));
#endif
_corsPolicyProvider = corsPolicyProvider ?? throw new ArgumentNullException(nameof(corsPolicyProvider));
_loadBalancingPolicies = loadBalancingPolicies?.ToDictionaryByUniqueId(p => p.Name) ?? throw new ArgumentNullException(nameof(loadBalancingPolicies));
_affinityFailurePolicies = affinityFailurePolicies?.ToDictionaryByUniqueId(p => p.Name) ?? throw new ArgumentNullException(nameof(affinityFailurePolicies));
Expand All @@ -72,7 +81,9 @@ public async ValueTask<IList<Exception>> ValidateRouteAsync(RouteConfig route)

errors.AddRange(_transformBuilder.ValidateRoute(route));
await ValidateAuthorizationPolicyAsync(errors, route.AuthorizationPolicy, route.RouteId);
#if NET7_0_OR_GREATER
await ValidateRateLimiterPolicyAsync(errors, route.RateLimiterPolicy, route.RouteId);
#endif
await ValidateCorsPolicyAsync(errors, route.CorsPolicy, route.RouteId);

if (route.Match is null)
Expand Down Expand Up @@ -288,59 +299,48 @@ private async ValueTask ValidateAuthorizationPolicyAsync(IList<Exception> errors
}
}

private ValueTask ValidateRateLimiterPolicyAsync(IList<Exception> errors, string? rateLimiterPolicyName, string routeId)
#if NET7_0_OR_GREATER
private async ValueTask ValidateRateLimiterPolicyAsync(IList<Exception> errors, string? rateLimiterPolicyName, string routeId)
{
if (string.IsNullOrEmpty(rateLimiterPolicyName))
mburumaxwell marked this conversation as resolved.
Show resolved Hide resolved
{
//return;
return ValueTask.CompletedTask;
return;
}

// TODO: update this once AspNetCore provides a mechanism to validate the RateLimiter policies https://github.com/dotnet/aspnetcore/issues/45684

if (string.Equals(RateLimitingConstants.Default, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase))
{
#if NET7_0_OR_GREATER
//var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
//if (policy is not null)
//{
// errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function."));
//}
#endif
//return;
return ValueTask.CompletedTask;
var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
if (policy is not null)
{
errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function."));
}
return;
}

if (string.Equals(RateLimitingConstants.Disable, rateLimiterPolicyName, StringComparison.OrdinalIgnoreCase))
{
#if NET7_0_OR_GREATER
//var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
//if (policy is not null)
//{
// errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function."));
//}
#endif
//return;
return ValueTask.CompletedTask;
var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
if (policy is not null)
{
errors.Add(new ArgumentException($"The application has registered a RateLimiter policy named '{rateLimiterPolicyName}' that conflicts with the reserved RateLimiter policy name used on this route. The registered policy name needs to be changed for this route to function."));
}
return;
}

try
{
#if NET7_0_OR_GREATER
//var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
//if (policy is null)
//{
// errors.Add(new ArgumentException($"RateLimiter policy '{rateLimiterPolicyName}' not found for route '{routeId}'."));
//}
#endif
var policy = await _rateLimiterPolicyProvider.GetPolicyAsync(rateLimiterPolicyName);
if (policy is null)
{
errors.Add(new ArgumentException($"RateLimiter policy '{rateLimiterPolicyName}' not found for route '{routeId}'."));
}
}
catch (Exception ex)
{
errors.Add(new ArgumentException($"Unable to retrieve the RateLimiter policy '{rateLimiterPolicyName}' for route '{routeId}'.", ex));
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
}

return ValueTask.CompletedTask;
}
#endif

private async ValueTask ValidateCorsPolicyAsync(IList<Exception> errors, string? corsPolicyName, string routeId)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Reflection;
using System.Threading.Tasks;
#if NET7_0_OR_GREATER
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
using Microsoft.AspNetCore.RateLimiting;
#endif
using Microsoft.Extensions.Options;

namespace Yarp.ReverseProxy.Configuration;

// TODO: update this once AspNetCore provides a mechanism to validate the RateLimiter policies https://github.com/dotnet/aspnetcore/issues/45684

#if NET7_0_OR_GREATER

internal interface IYarpRateLimiterPolicyProvider
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
{
ValueTask<object?> GetPolicyAsync(string policyName);
}

internal class YarpRateLimiterPolicyProvider : IYarpRateLimiterPolicyProvider
{
private readonly RateLimiterOptions _rateLimiterOptions;

private readonly System.Collections.IDictionary _policyMap, _unactivatedPolicyMap;
Tratcher marked this conversation as resolved.
Show resolved Hide resolved

public YarpRateLimiterPolicyProvider(IOptions<RateLimiterOptions> rateLimiterOptions)
{
_rateLimiterOptions = rateLimiterOptions?.Value ?? throw new ArgumentNullException(nameof(rateLimiterOptions));

var type = typeof(RateLimiterOptions);
var flags = BindingFlags.Instance | BindingFlags.NonPublic;
_policyMap = (System.Collections.IDictionary)type.GetProperty("PolicyMap", flags)!.GetValue(_rateLimiterOptions, null)!;
_unactivatedPolicyMap = (System.Collections.IDictionary)type.GetProperty("UnactivatedPolicyMap", flags)!.GetValue(_rateLimiterOptions, null)!;
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
}

public ValueTask<object?> GetPolicyAsync(string policyName)
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
{
return ValueTask.FromResult(_policyMap[policyName] ?? _unactivatedPolicyMap[policyName]);
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ internal static class IReverseProxyBuilderExtensions
{
public static IReverseProxyBuilder AddConfigBuilder(this IReverseProxyBuilder builder)
{
#if NET7_0_OR_GREATER
builder.Services.TryAddSingleton<IYarpRateLimiterPolicyProvider, YarpRateLimiterPolicyProvider>();
#endif
builder.Services.TryAddSingleton<IConfigValidator, ConfigValidator>();
builder.Services.TryAddSingleton<IRandomFactory, RandomFactory>();
builder.AddTransformFactory<ForwardedTransformFactory>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System;
using System.Threading.Tasks;
#if NET7_0_OR_GREATER
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
using System.Threading.RateLimiting;
#endif
using Microsoft.AspNetCore.Builder;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.RateLimiting;
#endif
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace Yarp.ReverseProxy.Configuration;

public class YarpRateLimiterPolicyProviderTests
{
#if NET7_0_OR_GREATER
[Fact]
public async Task GetPolicyAsync_Works()
{
var services = new ServiceCollection();

services.AddRateLimiter(options =>
{
options.AddFixedWindowLimiter("customPolicy", opt =>
{
opt.PermitLimit = 4;
opt.Window = TimeSpan.FromSeconds(12);
opt.QueueProcessingOrder = QueueProcessingOrder.OldestFirst;
opt.QueueLimit = 2;
});
});

services.AddReverseProxy();
var provider = services.BuildServiceProvider();
var rateLimiterPolicyProvider = provider.GetRequiredService<IYarpRateLimiterPolicyProvider>();
Assert.Null(await rateLimiterPolicyProvider.GetPolicyAsync("anotherPolicy"));
Assert.NotNull(await rateLimiterPolicyProvider.GetPolicyAsync("customPolicy"));
}
#endif
}