From 4cdcd95afed68f1fbc1eea0e0dd8b88b2c717d92 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sat, 16 Feb 2019 20:40:59 +0200 Subject: [PATCH 01/13] Refactor duplicate code --- .../Core/ClientRateLimitProcessor.cs | 118 +---------- src/AspNetCoreRateLimit/Core/Extensions.cs | 27 ++- .../Core/IRateLimitProcessor.cs | 15 ++ .../Core/IpRateLimitProcessor.cs | 126 ++--------- src/AspNetCoreRateLimit/Core/RateLimitCore.cs | 132 ------------ .../Core/RateLimitProcessor.cs | 200 ++++++++++++++++++ .../Middleware/ClientRateLimitMiddleware.cs | 134 +----------- .../Middleware/IpRateLimitMiddleware.cs | 156 +------------- .../Middleware/RateLimitMiddleware.cs | 162 ++++++++++++++ .../Models/ClientRateLimitOptions.cs | 2 +- .../Models/ClientRateLimitPolicy.cs | 7 +- .../Models/IpRateLimitOptions.cs | 2 +- .../Models/IpRateLimitPolicy.cs | 7 +- ...imitCoreOptions.cs => RateLimitOptions.cs} | 2 +- .../Models/RateLimitPolicy.cs | 9 + src/AspNetCoreRateLimit/Net/IpParser.cs | 28 +++ src/AspNetCoreRateLimit/Net/RemoteIpParser.cs | 34 --- .../Net/ReversProxyIpParser.cs | 37 ++-- .../DistributedCacheClientPolicyStore.cs | 19 +- .../Store/DistributedCacheIpPolicyStore.cs | 13 +- .../DistributedCacheRateLimitCounterStore.cs | 3 + .../Store/IClientPolicyStore.cs | 6 +- .../Store/IIpPolicyStore.cs | 6 +- src/AspNetCoreRateLimit/Store/IPolicyStore.cs | 10 + .../Store/MemoryCacheClientPolicyStore.cs | 16 +- .../Store/MemoryCacheIpPolicyStore.cs | 12 +- 26 files changed, 560 insertions(+), 723 deletions(-) create mode 100644 src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs delete mode 100644 src/AspNetCoreRateLimit/Core/RateLimitCore.cs create mode 100644 src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs create mode 100644 src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs rename src/AspNetCoreRateLimit/Models/{RateLimitCoreOptions.cs => RateLimitOptions.cs} (97%) create mode 100644 src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs create mode 100644 src/AspNetCoreRateLimit/Net/IpParser.cs delete mode 100644 src/AspNetCoreRateLimit/Net/RemoteIpParser.cs create mode 100644 src/AspNetCoreRateLimit/Store/IPolicyStore.cs diff --git a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs index 8a296eeb..e7209751 100644 --- a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs @@ -1,134 +1,38 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; namespace AspNetCoreRateLimit { - public class ClientRateLimitProcessor + public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor { private readonly ClientRateLimitOptions _options; - private readonly IRateLimitCounterStore _counterStore; - private readonly IClientPolicyStore _policyStore; - private readonly RateLimitCore _core; + private readonly IPolicyStore _policyStore; - public ClientRateLimitProcessor(ClientRateLimitOptions options, + public ClientRateLimitProcessor( + ClientRateLimitOptions options, IRateLimitCounterStore counterStore, IClientPolicyStore policyStore) + : base(options, counterStore) { _options = options; - _counterStore = counterStore; _policyStore = policyStore; - - _core = new RateLimitCore(false, options, _counterStore); } - public List GetMatchingRules(ClientRequestIdentity identity) + public IEnumerable GetMatchingRules(ClientRequestIdentity identity) { - var limits = new List(); var policy = _policyStore.Get($"{_options.ClientPolicyPrefix}_{identity.ClientId}"); if (policy != null) { - if (_options.EnableEndpointRateLimiting) - { - // search for rules with endpoints like "*" and "*:/matching_path" - var pathLimits = policy.Rules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); - limits.AddRange(pathLimits); - - // search for rules with endpoints like "matching_verb:/matching_path" - var verbLimits = policy.Rules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); - limits.AddRange(verbLimits); - } - else - { - //ignore endpoint rules and search for global rules only - var genericLimits = policy.Rules.Where(l => l.Endpoint == "*").AsEnumerable(); - limits.AddRange(genericLimits); - } - } - - // get the most restrictive limit for each period - limits = limits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); - - // search for matching general rules - if (_options.GeneralRules != null) - { - var matchingGeneralLimits = new List(); - if (_options.EnableEndpointRateLimiting) - { - // search for rules with endpoints like "*" and "*:/matching_path" in general rules - var pathLimits = _options.GeneralRules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); - matchingGeneralLimits.AddRange(pathLimits); - - // search for rules with endpoints like "matching_verb:/matching_path" in general rules - var verbLimits = _options.GeneralRules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); - matchingGeneralLimits.AddRange(verbLimits); - } - else - { - //ignore endpoint rules and search for global rules in general rules - var genericLimits = _options.GeneralRules.Where(l => l.Endpoint == "*").AsEnumerable(); - matchingGeneralLimits.AddRange(genericLimits); - } - - // get the most restrictive general limit for each period - var generalLimits = matchingGeneralLimits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); - - foreach (var generalLimit in generalLimits) - { - // add general rule if no specific rule is declared for the specified period - if(!limits.Exists(l => l.Period == generalLimit.Period)) - { - limits.Add(generalLimit); - } - } - } - - foreach (var item in limits) - { - //parse period text into time spans - item.PeriodTimespan = _core.ConvertToTimeSpan(item.Period); - } - - limits = limits.OrderBy(l => l.PeriodTimespan).ToList(); - if(_options.StackBlockedRequests) - { - limits.Reverse(); + return GetMatchingRules(identity, policy.Rules); } - return limits; - } - - public bool IsWhitelisted(ClientRequestIdentity requestIdentity) - { - if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId)) - { - return true; - } - - if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) - { - if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ContainsIgnoreCase(x)) || - _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ContainsIgnoreCase(x))) - return true; - } - - return false; - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return _core.ProcessRequest(requestIdentity, rule); - } - - public RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return _core.GetRateLimitHeaders(requestIdentity, rule); + return Enumerable.Empty(); } - public string RetryAfterFrom(DateTime timestamp, RateLimitRule rule) + protected override string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) { - return _core.RetryAfterFrom(timestamp, rule); + return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientId}_{rule.Period}"; } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/Extensions.cs b/src/AspNetCoreRateLimit/Core/Extensions.cs index b91ab9cb..3ee1d46c 100644 --- a/src/AspNetCoreRateLimit/Core/Extensions.cs +++ b/src/AspNetCoreRateLimit/Core/Extensions.cs @@ -1,4 +1,5 @@ using System; +using System.Globalization; namespace AspNetCoreRateLimit { @@ -8,5 +9,29 @@ public static bool ContainsIgnoreCase(this string source, string value, StringCo { return source != null && value != null && source.IndexOf(value, stringComparison) >= 0; } + + public static string RetryAfterFrom(this DateTime timestamp, RateLimitRule rule) + { + var secondsPast = Convert.ToInt32((DateTime.UtcNow - timestamp).TotalSeconds); + var retryAfter = Convert.ToInt32(rule.PeriodTimespan.Value.TotalSeconds); + retryAfter = retryAfter > 1 ? retryAfter - secondsPast : 1; + return retryAfter.ToString(CultureInfo.InvariantCulture); + } + + public static TimeSpan ToTimeSpan(this string timeSpan) + { + var l = timeSpan.Length - 1; + var value = timeSpan.Substring(0, l); + var type = timeSpan.Substring(l, 1); + + switch (type) + { + case "d": return TimeSpan.FromDays(double.Parse(value)); + case "h": return TimeSpan.FromHours(double.Parse(value)); + case "m": return TimeSpan.FromMinutes(double.Parse(value)); + case "s": return TimeSpan.FromSeconds(double.Parse(value)); + default: throw new FormatException($"{timeSpan} can't be converted to TimeSpan, unknown type {type}"); + } + } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs new file mode 100644 index 00000000..83d90754 --- /dev/null +++ b/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs @@ -0,0 +1,15 @@ +using System.Collections.Generic; + +namespace AspNetCoreRateLimit +{ + public interface IRateLimitProcessor + { + IEnumerable GetMatchingRules(ClientRequestIdentity identity); + + RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule); + + bool IsWhitelisted(ClientRequestIdentity requestIdentity); + + RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule); + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs index 077dee2e..c7256cc8 100644 --- a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs @@ -1,32 +1,23 @@ -using AspNetCoreRateLimit.Core; -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; namespace AspNetCoreRateLimit { - public class IpRateLimitProcessor + public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor { private readonly IpRateLimitOptions _options; - private readonly IRateLimitCounterStore _counterStore; - private readonly IIpPolicyStore _policyStore; - private readonly IIpAddressParser _ipParser; - private readonly RateLimitCore _core; + private readonly IPolicyStore _policyStore; public IpRateLimitProcessor(IpRateLimitOptions options, IRateLimitCounterStore counterStore, - IIpPolicyStore policyStore, - IIpAddressParser ipParser) + IIpPolicyStore policyStore) + : base(options, counterStore) { _options = options; - _counterStore = counterStore; _policyStore = policyStore; - _ipParser = ipParser; - - _core = new RateLimitCore(true, options, _counterStore); } - public List GetMatchingRules(ClientRequestIdentity identity) + public IEnumerable GetMatchingRules(ClientRequestIdentity identity) { var limits = new List(); var policies = _policyStore.Get($"{_options.IpPolicyPrefix}"); @@ -34,118 +25,33 @@ public List GetMatchingRules(ClientRequestIdentity identity) if (policies != null && policies.IpRules != null && policies.IpRules.Any()) { // search for rules with IP intervals containing client IP - var matchPolicies = policies.IpRules.Where(r => _ipParser.ContainsIp(r.Ip, identity.ClientIp)).AsEnumerable(); + var matchPolicies = policies.IpRules.Where(r => IpParser.ContainsIp(r.Ip, identity.ClientIp)).AsEnumerable(); var rules = new List(); + foreach (var item in matchPolicies) { rules.AddRange(item.Rules); } - if (_options.EnableEndpointRateLimiting) - { - // search for rules with endpoints like "*" and "*:/matching_path" - var pathLimits = rules.Where(l => $"*:{identity.Path}".ToLowerInvariant().Contains(l.Endpoint.ToLowerInvariant())).AsEnumerable(); - limits.AddRange(pathLimits); - - // search for rules with endpoints like "matching_verb:/matching_path" - var verbLimits = rules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ToLowerInvariant().Contains(l.Endpoint.ToLowerInvariant())).AsEnumerable(); - limits.AddRange(verbLimits); - } - else - { - //ignore endpoint rules and search for global rules only - var genericLimits = rules.Where(l => l.Endpoint == "*").AsEnumerable(); - limits.AddRange(genericLimits); - } - } - - // get the most restrictive limit for each period - limits = limits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); - - // search for matching general rules - if (_options.GeneralRules != null) - { - var matchingGeneralLimits = new List(); - if (_options.EnableEndpointRateLimiting) - { - // search for rules with endpoints like "*" and "*:/matching_path" in general rules - var pathLimits = _options.GeneralRules.Where(l => $"*:{identity.Path}".ToLowerInvariant().Contains(l.Endpoint.ToLowerInvariant())).AsEnumerable(); - matchingGeneralLimits.AddRange(pathLimits); - - // search for rules with endpoints like "matching_verb:/matching_path" in general rules - var verbLimits = _options.GeneralRules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ToLowerInvariant().IsMatch(l.Endpoint.ToLowerInvariant())).AsEnumerable(); - matchingGeneralLimits.AddRange(verbLimits); - } - else - { - //ignore endpoint rules and search for global rules in general rules - var genericLimits = _options.GeneralRules.Where(l => l.Endpoint == "*").AsEnumerable(); - matchingGeneralLimits.AddRange(genericLimits); - } - - // get the most restrictive general limit for each period - var generalLimits = matchingGeneralLimits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); - - foreach (var generalLimit in generalLimits) - { - // add general rule if no specific rule is declared for the specified period - if(!limits.Exists(l => l.Period == generalLimit.Period)) - { - limits.Add(generalLimit); - } - } + return GetMatchingRules(identity, rules); } - foreach (var item in limits) - { - //parse period text into time spans - item.PeriodTimespan = _core.ConvertToTimeSpan(item.Period); - } - - limits = limits.OrderBy(l => l.PeriodTimespan).ToList(); - if(_options.StackBlockedRequests) - { - limits.Reverse(); - } - - return limits; + return Enumerable.Empty(); } - public bool IsWhitelisted(ClientRequestIdentity requestIdentity) + public override bool IsWhitelisted(ClientRequestIdentity requestIdentity) { - if (_options.IpWhitelist != null && _ipParser.ContainsIp(_options.IpWhitelist, requestIdentity.ClientIp)) - { - return true; - } - - if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId)) + if (_options.IpWhitelist != null && IpParser.ContainsIp(_options.IpWhitelist, requestIdentity.ClientIp)) { return true; } - if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) - { - if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ToLowerInvariant().Contains(x.ToLowerInvariant())) || - _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ToLowerInvariant().Contains(x.ToLowerInvariant()))) - return true; - } - - return false; - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return _core.ProcessRequest(requestIdentity, rule); - } - - public RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return _core.GetRateLimitHeaders(requestIdentity, rule); + return base.IsWhitelisted(requestIdentity); } - public string RetryAfterFrom(DateTime timestamp, RateLimitRule rule) + protected override string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) { - return _core.RetryAfterFrom(timestamp, rule); + return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientIp}_{rule.Period}"; } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/RateLimitCore.cs b/src/AspNetCoreRateLimit/Core/RateLimitCore.cs deleted file mode 100644 index 36c25275..00000000 --- a/src/AspNetCoreRateLimit/Core/RateLimitCore.cs +++ /dev/null @@ -1,132 +0,0 @@ -using System; -using System.Globalization; - -namespace AspNetCoreRateLimit -{ - public class RateLimitCore - { - private readonly RateLimitCoreOptions _options; - private readonly IRateLimitCounterStore _counterStore; - private readonly bool _ipRateLimiting; - - private static readonly object _processLocker = new object(); - - public RateLimitCore(bool ipRateLimiting, - RateLimitCoreOptions options, - IRateLimitCounterStore counterStore) - { - _ipRateLimiting = ipRateLimiting; - _options = options; - _counterStore = counterStore; - } - - public string ComputeCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - var key = _ipRateLimiting ? - $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientIp}_{rule.Period}" : - $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientId}_{rule.Period}"; - - if(_options.EnableEndpointRateLimiting) - { - key += $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}"; - - // TODO: consider using the rule endpoint as key, this will allow to rate limit /api/values/1 and api/values/2 under same counter - //key += $"_{rule.Endpoint}"; - } - - var idBytes = System.Text.Encoding.UTF8.GetBytes(key); - - byte[] hashBytes; - - using (var algorithm = System.Security.Cryptography.SHA1.Create()) - { - hashBytes = algorithm.ComputeHash(idBytes); - } - - return BitConverter.ToString(hashBytes).Replace("-", string.Empty); - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - var counter = new RateLimitCounter - { - Timestamp = DateTime.UtcNow, - TotalRequests = 1 - }; - - var counterId = ComputeCounterKey(requestIdentity, rule); - - // serial reads and writes - lock (_processLocker) - { - var entry = _counterStore.Get(counterId); - if (entry.HasValue) - { - // entry has not expired - if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow) - { - // increment request count - var totalRequests = entry.Value.TotalRequests + 1; - - // deep copy - counter = new RateLimitCounter - { - Timestamp = entry.Value.Timestamp, - TotalRequests = totalRequests - }; - } - } - - // stores: id (string) - timestamp (datetime) - total_requests (long) - _counterStore.Set(counterId, counter, rule.PeriodTimespan.Value); - } - - return counter; - } - - public RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - var headers = new RateLimitHeaders(); - var counterId = ComputeCounterKey(requestIdentity, rule); - var entry = _counterStore.Get(counterId); - if (entry.HasValue) - { - headers.Reset = (entry.Value.Timestamp + ConvertToTimeSpan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); - headers.Limit = rule.Period; - headers.Remaining = (rule.Limit - entry.Value.TotalRequests).ToString(); - } - else - { - headers.Reset = (DateTime.UtcNow + ConvertToTimeSpan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); - headers.Limit = rule.Period; - headers.Remaining = rule.Limit .ToString(); - } - - return headers; - } - - public string RetryAfterFrom(DateTime timestamp, RateLimitRule rule) - { - var secondsPast = Convert.ToInt32((DateTime.UtcNow - timestamp).TotalSeconds); - var retryAfter = Convert.ToInt32(rule.PeriodTimespan.Value.TotalSeconds); - retryAfter = retryAfter > 1 ? retryAfter - secondsPast : 1; - return retryAfter.ToString(CultureInfo.InvariantCulture); - } - - public TimeSpan ConvertToTimeSpan(string timeSpan) - { - var l = timeSpan.Length - 1; - var value = timeSpan.Substring(0, l); - var type = timeSpan.Substring(l, 1); - - switch (type) - { - case "d": return TimeSpan.FromDays(double.Parse(value)); - case "h": return TimeSpan.FromHours(double.Parse(value)); - case "m": return TimeSpan.FromMinutes(double.Parse(value)); - case "s": return TimeSpan.FromSeconds(double.Parse(value)); - default: throw new FormatException($"{timeSpan} can't be converted to TimeSpan, unknown type {type}"); - } - } - } -} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs new file mode 100644 index 00000000..417bb741 --- /dev/null +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -0,0 +1,200 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; + +namespace AspNetCoreRateLimit +{ + public abstract class RateLimitProcessor + { + private readonly RateLimitOptions _options; + private readonly IRateLimitCounterStore _counterStore; + + protected RateLimitProcessor( + RateLimitOptions options, + IRateLimitCounterStore counterStore) + { + _options = options; + _counterStore = counterStore; + } + + private static readonly object _processLocker = new object(); + + protected abstract string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule); + + protected string ComputeCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + var key = GetCounterKey(requestIdentity, rule); + + if (_options.EnableEndpointRateLimiting) + { + key += $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}"; + + // TODO: consider using the rule endpoint as key, this will allow to rate limit /api/values/1 and api/values/2 under same counter + //key += $"_{rule.Endpoint}"; + } + + var idBytes = System.Text.Encoding.UTF8.GetBytes(key); + + byte[] hashBytes; + + using (var algorithm = System.Security.Cryptography.SHA1.Create()) + { + hashBytes = algorithm.ComputeHash(idBytes); + } + + return BitConverter.ToString(hashBytes).Replace("-", string.Empty); + } + + protected List GetMatchingRules(ClientRequestIdentity identity, List rules) + { + var limits = new List(); + + if (_options.EnableEndpointRateLimiting) + { + // search for rules with endpoints like "*" and "*:/matching_path" + var pathLimits = rules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + limits.AddRange(pathLimits); + + // search for rules with endpoints like "matching_verb:/matching_path" + var verbLimits = rules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + limits.AddRange(verbLimits); + } + else + { + //ignore endpoint rules and search for global rules only + var genericLimits = rules.Where(l => l.Endpoint == "*").AsEnumerable(); + limits.AddRange(genericLimits); + } + + // get the most restrictive limit for each period + limits = limits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); + + // search for matching general rules + if (_options.GeneralRules != null) + { + var matchingGeneralLimits = new List(); + if (_options.EnableEndpointRateLimiting) + { + // search for rules with endpoints like "*" and "*:/matching_path" in general rules + var pathLimits = _options.GeneralRules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + matchingGeneralLimits.AddRange(pathLimits); + + // search for rules with endpoints like "matching_verb:/matching_path" in general rules + var verbLimits = _options.GeneralRules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + matchingGeneralLimits.AddRange(verbLimits); + } + else + { + //ignore endpoint rules and search for global rules in general rules + var genericLimits = _options.GeneralRules.Where(l => l.Endpoint == "*").AsEnumerable(); + matchingGeneralLimits.AddRange(genericLimits); + } + + // get the most restrictive general limit for each period + var generalLimits = matchingGeneralLimits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); + + foreach (var generalLimit in generalLimits) + { + // add general rule if no specific rule is declared for the specified period + if (!limits.Exists(l => l.Period == generalLimit.Period)) + { + limits.Add(generalLimit); + } + } + } + + foreach (var item in limits) + { + //parse period text into time spans + item.PeriodTimespan = item.Period.ToTimeSpan(); + } + + limits = limits.OrderBy(l => l.PeriodTimespan).ToList(); + + if (_options.StackBlockedRequests) + { + limits.Reverse(); + } + + return limits; + } + + public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) + { + if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId)) + { + return true; + } + + if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) + { + if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ContainsIgnoreCase(x)) || + _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ContainsIgnoreCase(x))) + return true; + } + + return false; + } + + public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + var counter = new RateLimitCounter + { + Timestamp = DateTime.UtcNow, + TotalRequests = 1 + }; + + var counterId = ComputeCounterKey(requestIdentity, rule); + + // serial reads and writes + lock (_processLocker) + { + var entry = _counterStore.Get(counterId); + + if (entry.HasValue) + { + // entry has not expired + if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow) + { + // increment request count + var totalRequests = entry.Value.TotalRequests + 1; + + // deep copy + counter = new RateLimitCounter + { + Timestamp = entry.Value.Timestamp, + TotalRequests = totalRequests + }; + } + } + + // stores: id (string) - timestamp (datetime) - total_requests (long) + _counterStore.Set(counterId, counter, rule.PeriodTimespan.Value); + } + + return counter; + } + + public RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + var headers = new RateLimitHeaders(); + var counterId = ComputeCounterKey(requestIdentity, rule); + var entry = _counterStore.Get(counterId); + if (entry.HasValue) + { + headers.Reset = (entry.Value.Timestamp + rule.Period.ToTimeSpan()).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); + headers.Limit = rule.Period; + headers.Remaining = (rule.Limit - entry.Value.TotalRequests).ToString(); + } + else + { + headers.Reset = (DateTime.UtcNow + rule.Period.ToTimeSpan()).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); + headers.Limit = rule.Period; + headers.Remaining = rule.Limit.ToString(); + } + + return headers; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs index ccd8566a..37f027bb 100644 --- a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs @@ -1,152 +1,26 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using System; -using System.Linq; -using System.Threading.Tasks; namespace AspNetCoreRateLimit { - public class ClientRateLimitMiddleware + public class ClientRateLimitMiddleware : RateLimitMiddleware { - private readonly RequestDelegate _next; private readonly ILogger _logger; - private readonly ClientRateLimitProcessor _processor; - private readonly ClientRateLimitOptions _options; public ClientRateLimitMiddleware(RequestDelegate next, IOptions options, IRateLimitCounterStore counterStore, IClientPolicyStore policyStore, - ILogger logger - ) + ILogger logger) + : base(next, options.Value, new ClientRateLimitProcessor(options.Value, counterStore, policyStore)) { - _next = next; - _options = options.Value; _logger = logger; - - _processor = new ClientRateLimitProcessor(_options, counterStore, policyStore); - } - - public async Task Invoke(HttpContext httpContext) - { - // check if rate limiting is enabled - if (_options == null) - { - await _next.Invoke(httpContext); - return; - } - - // compute identity from request - var identity = SetIdentity(httpContext); - - // check white list - if (_processor.IsWhitelisted(identity)) - { - await _next.Invoke(httpContext); - return; - } - - var rules = _processor.GetMatchingRules(identity); - - foreach (var rule in rules) - { - if(rule.Limit > 0) - { - // increment counter - var counter = _processor.ProcessRequest(identity, rule); - - // check if key expired - if (counter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow) - { - continue; - } - - // check if limit is reached - if (counter.TotalRequests > rule.Limit) - { - //compute retry after value - var retryAfter = _processor.RetryAfterFrom(counter.Timestamp, rule); - - // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); - - // break execution - await ReturnQuotaExceededResponse(httpContext, rule, retryAfter); - return; - } - } - // if limit is zero or less, block the request. - else - { - // process request count - var counter = _processor.ProcessRequest(identity, rule); - - // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); - - // break execution (Int32 max used to represent infinity) - await ReturnQuotaExceededResponse(httpContext, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture)); - return; - } - } - - //set X-Rate-Limit headers for the longest period - if(rules.Any() && !_options.DisableRateLimitHeaders) - { - var rule = rules.OrderByDescending(x => x.PeriodTimespan.Value).First(); - var headers = _processor.GetRateLimitHeaders(identity, rule); - headers.Context = httpContext; - - httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers); - } - - await _next.Invoke(httpContext); - } - - public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext) - { - var clientId = "anon"; - if (httpContext.Request.Headers.Keys.Contains(_options.ClientIdHeader,StringComparer.CurrentCultureIgnoreCase)) - { - clientId = httpContext.Request.Headers[_options.ClientIdHeader].First(); - } - - return new ClientRequestIdentity - { - Path = httpContext.Request.Path.ToString().ToLowerInvariant(), - HttpVerb = httpContext.Request.Method.ToLowerInvariant(), - ClientId = clientId - }; } - public virtual Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) - { - var message = string.Format(_options.QuotaExceededMessage ?? "API calls quota exceeded! maximum admitted {0} per {1}.", rule.Limit, rule.Period); - - if (!_options.DisableRateLimitHeaders) - { - httpContext.Response.Headers["Retry-After"] = retryAfter; - } - - httpContext.Response.StatusCode = _options.HttpStatusCode; - return httpContext.Response.WriteAsync(message); - } - - public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) + protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) { _logger.LogInformation($"Request {identity.HttpVerb}:{identity.Path} from ClientId {identity.ClientId} has been blocked, quota {rule.Limit}/{rule.Period} exceeded by {counter.TotalRequests}. Blocked by rule {rule.Endpoint}, TraceIdentifier {httpContext.TraceIdentifier}."); } - - private Task SetRateLimitHeaders(object rateLimitHeaders) - { - var headers = (RateLimitHeaders)rateLimitHeaders; - - headers.Context.Response.Headers["X-Rate-Limit-Limit"] = headers.Limit; - headers.Context.Response.Headers["X-Rate-Limit-Remaining"] = headers.Remaining; - headers.Context.Response.Headers["X-Rate-Limit-Reset"] = headers.Reset; - - return Task.CompletedTask; - } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs index bf55f927..52bb418a 100644 --- a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs @@ -1,171 +1,27 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using System; -using System.Linq; -using System.Threading.Tasks; namespace AspNetCoreRateLimit { - public class IpRateLimitMiddleware + public class IpRateLimitMiddleware : RateLimitMiddleware { - private readonly RequestDelegate _next; private readonly ILogger _logger; - private readonly IIpAddressParser _ipParser; - private readonly IpRateLimitProcessor _processor; - private readonly IpRateLimitOptions _options; public IpRateLimitMiddleware(RequestDelegate next, IOptions options, IRateLimitCounterStore counterStore, IIpPolicyStore policyStore, - ILogger logger, - IIpAddressParser ipParser = null) - { - _next = next; - _options = options.Value; - _logger = logger; - _ipParser = ipParser ?? new ReversProxyIpParser(_options.RealIpHeader); - - _processor = new IpRateLimitProcessor(_options, counterStore, policyStore, _ipParser); - } - - public async Task Invoke(HttpContext httpContext) - { - // check if rate limiting is enabled - if (_options == null) - { - await _next.Invoke(httpContext); - return; - } - - // compute identity from request - var identity = SetIdentity(httpContext); - - // check white list - if (_processor.IsWhitelisted(identity)) - { - await _next.Invoke(httpContext); - return; - } - - var rules = _processor.GetMatchingRules(identity); - - foreach (var rule in rules) - { - if (rule.Limit > 0) - { - // increment counter - var counter = _processor.ProcessRequest(identity, rule); - - // check if key expired - if (counter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow) - { - continue; - } - - // check if limit is reached - if (counter.TotalRequests > rule.Limit) - { - //compute retry after value - var retryAfter = _processor.RetryAfterFrom(counter.Timestamp, rule); - - // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); - - // break execution - await ReturnQuotaExceededResponse(httpContext, rule, retryAfter); - return; - } - } - // if limit is zero or less, block the request. - else - { - // process request count - var counter = _processor.ProcessRequest(identity, rule); - - // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); - - // break execution (Int32 max used to represent infinity) - await ReturnQuotaExceededResponse(httpContext, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture)); - return; - } - } - - //set X-Rate-Limit headers for the longest period - if (rules.Any() && !_options.DisableRateLimitHeaders) - { - var rule = rules.OrderByDescending(x => x.PeriodTimespan.Value).First(); - var headers = _processor.GetRateLimitHeaders(identity, rule); - headers.Context = httpContext; + ILogger logger) + : base(next, options.Value, new IpRateLimitProcessor(options.Value, counterStore, policyStore)) - httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers); - } - - await _next.Invoke(httpContext); - } - - public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext) - { - var clientId = "anon"; - if (httpContext.Request.Headers.Keys.Contains(_options.ClientIdHeader,StringComparer.CurrentCultureIgnoreCase)) - { - clientId = httpContext.Request.Headers[_options.ClientIdHeader].First(); - } - - string clientIp; - try - { - var ip = _ipParser.GetClientIp(httpContext); - if(ip == null) - { - throw new Exception("IpRateLimitMiddleware can't parse caller IP"); - } - - clientIp = ip.ToString(); - } - catch (Exception ex) - { - throw new Exception("IpRateLimitMiddleware can't parse caller IP", ex); - } - - return new ClientRequestIdentity - { - ClientIp = clientIp, - Path = httpContext.Request.Path.ToString().ToLowerInvariant(), - HttpVerb = httpContext.Request.Method.ToLowerInvariant(), - ClientId = clientId - }; - } - - public virtual Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) { - var message = string.Format(_options.QuotaExceededMessage ?? "API calls quota exceeded! maximum admitted {0} per {1}.", rule.Limit, rule.Period); - - if (!_options.DisableRateLimitHeaders) - { - httpContext.Response.Headers["Retry-After"] = retryAfter; - } - - httpContext.Response.StatusCode = _options.HttpStatusCode; - return httpContext.Response.WriteAsync(message); + _logger = logger; } - public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) + protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) { _logger.LogInformation($"Request {identity.HttpVerb}:{identity.Path} from IP {identity.ClientIp} has been blocked, quota {rule.Limit}/{rule.Period} exceeded by {counter.TotalRequests}. Blocked by rule {rule.Endpoint}, TraceIdentifier {httpContext.TraceIdentifier}."); } - - private Task SetRateLimitHeaders(object rateLimitHeaders) - { - var headers = (RateLimitHeaders)rateLimitHeaders; - - headers.Context.Response.Headers["X-Rate-Limit-Limit"] = headers.Limit; - headers.Context.Response.Headers["X-Rate-Limit-Remaining"] = headers.Remaining; - headers.Context.Response.Headers["X-Rate-Limit-Reset"] = headers.Reset; - - return Task.CompletedTask; - } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs new file mode 100644 index 00000000..1c551418 --- /dev/null +++ b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs @@ -0,0 +1,162 @@ +using Microsoft.AspNetCore.Http; +using System; +using System.Linq; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + public abstract class RateLimitMiddleware + where TProcessor : IRateLimitProcessor + { + private readonly RequestDelegate _next; + private readonly TProcessor _processor; + private readonly RateLimitOptions _options; + + protected RateLimitMiddleware( + RequestDelegate next, + RateLimitOptions options, + TProcessor processor) + { + _next = next; + _options = options; + _processor = processor; + } + + public async Task Invoke(HttpContext httpContext) + { + // check if rate limiting is enabled + if (_options == null) + { + await _next.Invoke(httpContext); + return; + } + + // compute identity from request + var identity = SetIdentity(httpContext); + + // check white list + if (_processor.IsWhitelisted(identity)) + { + await _next.Invoke(httpContext); + return; + } + + var rules = _processor.GetMatchingRules(identity); + + foreach (var rule in rules) + { + if (rule.Limit > 0) + { + // increment counter + var counter = _processor.ProcessRequest(identity, rule); + + // check if key expired + if (counter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow) + { + continue; + } + + // check if limit is reached + if (counter.TotalRequests > rule.Limit) + { + //compute retry after value + var retryAfter = counter.Timestamp.RetryAfterFrom(rule); + + // log blocked request + LogBlockedRequest(httpContext, identity, counter, rule); + + // break execution + await ReturnQuotaExceededResponse(httpContext, rule, retryAfter); + return; + } + } + // if limit is zero or less, block the request. + else + { + // process request count + var counter = _processor.ProcessRequest(identity, rule); + + // log blocked request + LogBlockedRequest(httpContext, identity, counter, rule); + + // break execution (Int32 max used to represent infinity) + await ReturnQuotaExceededResponse(httpContext, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + } + } + + //set X-Rate-Limit headers for the longest period + if (rules.Any() && !_options.DisableRateLimitHeaders) + { + var rule = rules.OrderByDescending(x => x.PeriodTimespan.Value).First(); + var headers = _processor.GetRateLimitHeaders(identity, rule); + headers.Context = httpContext; + + httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers); + } + + await _next.Invoke(httpContext); + } + + public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext) + { + //var clientId = "anon"; + //if (httpContext.Request.Headers.Keys.Contains(_options.ClientIdHeader, StringComparer.CurrentCultureIgnoreCase)) + //{ + // clientId = httpContext.Request.Headers[_options.ClientIdHeader].First(); + //} + + //string clientIp; + //try + //{ + // var ip = _ipParser.GetClientIp(httpContext); + + // if (ip == null) + // { + // throw new Exception("IpRateLimitMiddleware can't parse caller IP"); + // } + + // clientIp = ip.ToString(); + //} + //catch (Exception ex) + //{ + // throw new Exception("IpRateLimitMiddleware can't parse caller IP", ex); + //} + + return new ClientRequestIdentity + { + //ClientIp = clientIp, + Path = httpContext.Request.Path.ToString().ToLowerInvariant(), + HttpVerb = httpContext.Request.Method.ToLowerInvariant(), + //ClientId = clientId + }; + } + + public virtual Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) + { + var message = string.Format(_options.QuotaExceededMessage ?? "API calls quota exceeded! maximum admitted {0} per {1}.", rule.Limit, rule.Period); + + if (!_options.DisableRateLimitHeaders) + { + httpContext.Response.Headers["Retry-After"] = retryAfter; + } + + httpContext.Response.StatusCode = _options.HttpStatusCode; + + return httpContext.Response.WriteAsync(message); + } + + protected abstract void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule); + + private Task SetRateLimitHeaders(object rateLimitHeaders) + { + var headers = (RateLimitHeaders)rateLimitHeaders; + + headers.Context.Response.Headers["X-Rate-Limit-Limit"] = headers.Limit; + headers.Context.Response.Headers["X-Rate-Limit-Remaining"] = headers.Remaining; + headers.Context.Response.Headers["X-Rate-Limit-Reset"] = headers.Reset; + + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Models/ClientRateLimitOptions.cs b/src/AspNetCoreRateLimit/Models/ClientRateLimitOptions.cs index 3af967f2..f7269f7b 100644 --- a/src/AspNetCoreRateLimit/Models/ClientRateLimitOptions.cs +++ b/src/AspNetCoreRateLimit/Models/ClientRateLimitOptions.cs @@ -1,6 +1,6 @@ namespace AspNetCoreRateLimit { - public class ClientRateLimitOptions : RateLimitCoreOptions + public class ClientRateLimitOptions : RateLimitOptions { /// /// Gets or sets the HTTP header that holds the client identifier, by default is X-ClientId diff --git a/src/AspNetCoreRateLimit/Models/ClientRateLimitPolicy.cs b/src/AspNetCoreRateLimit/Models/ClientRateLimitPolicy.cs index 5920bf0c..dbe5158e 100644 --- a/src/AspNetCoreRateLimit/Models/ClientRateLimitPolicy.cs +++ b/src/AspNetCoreRateLimit/Models/ClientRateLimitPolicy.cs @@ -1,10 +1,7 @@ -using System.Collections.Generic; - -namespace AspNetCoreRateLimit +namespace AspNetCoreRateLimit { - public class ClientRateLimitPolicy + public class ClientRateLimitPolicy : RateLimitPolicy { public string ClientId { get; set; } - public List Rules { get; set; } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Models/IpRateLimitOptions.cs b/src/AspNetCoreRateLimit/Models/IpRateLimitOptions.cs index 410b6068..9cf503df 100644 --- a/src/AspNetCoreRateLimit/Models/IpRateLimitOptions.cs +++ b/src/AspNetCoreRateLimit/Models/IpRateLimitOptions.cs @@ -2,7 +2,7 @@ namespace AspNetCoreRateLimit { - public class IpRateLimitOptions : RateLimitCoreOptions + public class IpRateLimitOptions : RateLimitOptions { /// /// Gets or sets the HTTP header of the real ip header injected by reverse proxy, by default is X-Real-IP diff --git a/src/AspNetCoreRateLimit/Models/IpRateLimitPolicy.cs b/src/AspNetCoreRateLimit/Models/IpRateLimitPolicy.cs index 15b594de..699e1e86 100644 --- a/src/AspNetCoreRateLimit/Models/IpRateLimitPolicy.cs +++ b/src/AspNetCoreRateLimit/Models/IpRateLimitPolicy.cs @@ -1,10 +1,7 @@ -using System.Collections.Generic; - -namespace AspNetCoreRateLimit +namespace AspNetCoreRateLimit { - public class IpRateLimitPolicy + public class IpRateLimitPolicy : RateLimitPolicy { public string Ip { get; set; } - public List Rules { get; set; } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Models/RateLimitCoreOptions.cs b/src/AspNetCoreRateLimit/Models/RateLimitOptions.cs similarity index 97% rename from src/AspNetCoreRateLimit/Models/RateLimitCoreOptions.cs rename to src/AspNetCoreRateLimit/Models/RateLimitOptions.cs index eda19cbc..00d6a93e 100644 --- a/src/AspNetCoreRateLimit/Models/RateLimitCoreOptions.cs +++ b/src/AspNetCoreRateLimit/Models/RateLimitOptions.cs @@ -2,7 +2,7 @@ namespace AspNetCoreRateLimit { - public class RateLimitCoreOptions + public class RateLimitOptions { public List GeneralRules { get; set; } diff --git a/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs b/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs new file mode 100644 index 00000000..ba1d40dd --- /dev/null +++ b/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace AspNetCoreRateLimit +{ + public class RateLimitPolicy + { + public List Rules { get; set; } + } +} diff --git a/src/AspNetCoreRateLimit/Net/IpParser.cs b/src/AspNetCoreRateLimit/Net/IpParser.cs new file mode 100644 index 00000000..6c5fc65a --- /dev/null +++ b/src/AspNetCoreRateLimit/Net/IpParser.cs @@ -0,0 +1,28 @@ +using System.Collections.Generic; +using System.Net; + +namespace AspNetCoreRateLimit +{ + public static class IpParser + { + public static bool ContainsIp(string ipRule, string clientIp) + { + return IpAddressUtil.ContainsIp(ipRule, clientIp); + } + + public static bool ContainsIp(List ipRules, string clientIp) + { + return IpAddressUtil.ContainsIp(ipRules, clientIp); + } + + public static bool ContainsIp(List ipRules, string clientIp, out string rule) + { + return IpAddressUtil.ContainsIp(ipRules, clientIp, out rule); + } + + public static IPAddress ParseIp(string ipAddress) + { + return IpAddressUtil.ParseIp(ipAddress); + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Net/RemoteIpParser.cs b/src/AspNetCoreRateLimit/Net/RemoteIpParser.cs deleted file mode 100644 index 44418e8e..00000000 --- a/src/AspNetCoreRateLimit/Net/RemoteIpParser.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System.Collections.Generic; -using System.Net; -using Microsoft.AspNetCore.Http; - -namespace AspNetCoreRateLimit -{ - public class RemoteIpParser : IIpAddressParser - { - public bool ContainsIp(string ipRule, string clientIp) - { - return IpAddressUtil.ContainsIp(ipRule, clientIp); - } - - public bool ContainsIp(List ipRules, string clientIp) - { - return IpAddressUtil.ContainsIp(ipRules, clientIp); - } - - public bool ContainsIp(List ipRules, string clientIp, out string rule) - { - return IpAddressUtil.ContainsIp(ipRules, clientIp, out rule); - } - - public virtual IPAddress GetClientIp(HttpContext context) - { - return context.Connection.RemoteIpAddress; - } - - public IPAddress ParseIp(string ipAddress) - { - return IpAddressUtil.ParseIp(ipAddress); - } - } -} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Net/ReversProxyIpParser.cs b/src/AspNetCoreRateLimit/Net/ReversProxyIpParser.cs index 375d99cf..a69fb24c 100644 --- a/src/AspNetCoreRateLimit/Net/ReversProxyIpParser.cs +++ b/src/AspNetCoreRateLimit/Net/ReversProxyIpParser.cs @@ -5,23 +5,28 @@ namespace AspNetCoreRateLimit { - public class ReversProxyIpParser : RemoteIpParser - { - private readonly string _realIpHeader; +// public class ReversProxyIpParser : RemoteIpParser +// { +// private readonly string _realIpHeader; - public ReversProxyIpParser(string realIpHeader) - { - _realIpHeader = realIpHeader; - } +// public ReversProxyIpParser(string realIpHeader) +// { +// _realIpHeader = realIpHeader; +// } - public override IPAddress GetClientIp(HttpContext context) - { - if (context.Request.Headers.Keys.Contains(_realIpHeader, StringComparer.CurrentCultureIgnoreCase)) - { - return ParseIp(context.Request.Headers[_realIpHeader].Last()); - } +// public override IPAddress GetClientIp(HttpContext context) +// { +// if (context.Request.Headers.Keys.Contains(_realIpHeader, StringComparer.CurrentCultureIgnoreCase)) +// { +// return ParseIp(context.Request.Headers[_realIpHeader].Last()); +// } - return base.GetClientIp(context); - } - } +// return base.GetClientIp(context); +// } + +// public virtual IPAddress GetClientIp(HttpContext context) +// { +// return context.Connection.RemoteIpAddress; +// } +// } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs index 14879d20..c295e768 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs @@ -8,18 +8,22 @@ public class DistributedCacheClientPolicyStore : IClientPolicyStore { private readonly IDistributedCache _memoryCache; - public DistributedCacheClientPolicyStore(IDistributedCache memoryCache, + public DistributedCacheClientPolicyStore( + IDistributedCache memoryCache, IOptions options = null, IOptions policies = null) { _memoryCache = memoryCache; - //save client rules defined in appsettings in distributed cache on startup - if (options != null && options.Value != null && policies != null && policies.Value != null && policies.Value.ClientRules != null) + var clientOptions = options?.Value; + var clientPolicyRules = policies?.Value?.ClientRules; + + //save client rules defined in appsettings in cache on startup + if (clientOptions != null && clientPolicyRules != null) { - foreach (var rule in policies.Value.ClientRules) + foreach (var rule in clientPolicyRules) { - Set($"{options.Value.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + Set($"{clientOptions.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); } } } @@ -32,16 +36,19 @@ public void Set(string id, ClientRateLimitPolicy policy) public bool Exists(string id) { var stored = _memoryCache.GetString(id); + return !string.IsNullOrEmpty(stored); } public ClientRateLimitPolicy Get(string id) { var stored = _memoryCache.GetString(id); + if (!string.IsNullOrEmpty(stored)) { return JsonConvert.DeserializeObject(stored); } + return null; } @@ -50,4 +57,4 @@ public void Remove(string id) _memoryCache.Remove(id); } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs index 6d044c8c..4ee20315 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs @@ -8,16 +8,21 @@ public class DistributedCacheIpPolicyStore : IIpPolicyStore { private readonly IDistributedCache _memoryCache; - public DistributedCacheIpPolicyStore(IDistributedCache memoryCache, + public DistributedCacheIpPolicyStore( + IDistributedCache memoryCache, IOptions options = null, IOptions policies = null) { _memoryCache = memoryCache; - //save ip rules defined in appsettings in distributed cache on startup - if (options != null && options.Value != null && policies != null && policies.Value != null && policies.Value.IpRules != null) + var ipOptions = options?.Value; + var ipPolicyRules = policies?.Value; + + //save IP rules defined in appsettings in cache on startup + if (ipOptions != null && ipPolicyRules != null) { - Set($"{options.Value.IpPolicyPrefix}", policies.Value); + Set($"{ipOptions.IpPolicyPrefix}", ipPolicyRules); + } } diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs index bba9c2f4..7d164b5d 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs @@ -21,16 +21,19 @@ public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) public bool Exists(string id) { var stored = _memoryCache.GetString(id); + return !string.IsNullOrEmpty(stored); } public RateLimitCounter? Get(string id) { var stored = _memoryCache.GetString(id); + if(!string.IsNullOrEmpty(stored)) { return JsonConvert.DeserializeObject(stored); } + return null; } diff --git a/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs index 90bbc0f1..f3a998eb 100644 --- a/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs @@ -1,10 +1,6 @@ namespace AspNetCoreRateLimit { - public interface IClientPolicyStore + public interface IClientPolicyStore : IPolicyStore { - bool Exists(string id); - ClientRateLimitPolicy Get(string id); - void Remove(string id); - void Set(string id, ClientRateLimitPolicy policy); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs index fb22c1d5..ab6d7c05 100644 --- a/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs @@ -1,10 +1,6 @@ namespace AspNetCoreRateLimit { - public interface IIpPolicyStore + public interface IIpPolicyStore : IPolicyStore { - bool Exists(string id); - IpRateLimitPolicies Get(string id); - void Remove(string id); - void Set(string id, IpRateLimitPolicies policy); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IPolicyStore.cs new file mode 100644 index 00000000..bcbe42f4 --- /dev/null +++ b/src/AspNetCoreRateLimit/Store/IPolicyStore.cs @@ -0,0 +1,10 @@ +namespace AspNetCoreRateLimit +{ + public interface IPolicyStore + { + bool Exists(string id); + TPolicy Get(string id); + void Remove(string id); + void Set(string id, TPolicy policy); + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs index 07dd6608..ecaea11c 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs @@ -3,22 +3,26 @@ namespace AspNetCoreRateLimit { - public class MemoryCacheClientPolicyStore: IClientPolicyStore + public class MemoryCacheClientPolicyStore : IClientPolicyStore { private readonly IMemoryCache _memoryCache; - public MemoryCacheClientPolicyStore(IMemoryCache memoryCache, + public MemoryCacheClientPolicyStore( + IMemoryCache memoryCache, IOptions options = null, IOptions policies = null) { _memoryCache = memoryCache; + var clientOptions = options?.Value; + var clientPolicyRules = policies?.Value?.ClientRules; + //save client rules defined in appsettings in cache on startup - if(options != null && options.Value != null && policies != null && policies.Value != null && policies.Value.ClientRules != null) + if (clientOptions != null && clientPolicyRules != null) { - foreach (var rule in policies.Value.ClientRules) + foreach (var rule in clientPolicyRules) { - Set($"{options.Value.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + Set($"{clientOptions.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); } } } @@ -48,4 +52,4 @@ public void Remove(string id) _memoryCache.Remove(id); } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs index 65c94855..aa18264f 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs @@ -7,16 +7,20 @@ public class MemoryCacheIpPolicyStore : IIpPolicyStore { private readonly IMemoryCache _memoryCache; - public MemoryCacheIpPolicyStore(IMemoryCache memoryCache, + public MemoryCacheIpPolicyStore( + IMemoryCache memoryCache, IOptions options = null, IOptions policies = null) { _memoryCache = memoryCache; - //save ip rules defined in appsettings in cache on startup - if (options != null && options.Value != null && policies != null && policies.Value != null && policies.Value.IpRules != null) + var ipOptions = options?.Value; + var ipPolicyRules = policies?.Value; + + //save IP rules defined in appsettings in cache on startup + if (ipOptions != null && ipPolicyRules != null) { - Set($"{options.Value.IpPolicyPrefix}", policies.Value); + Set($"{ipOptions.IpPolicyPrefix}", ipPolicyRules); } } From 63c12eb40dfbd1de80d11c9efdf1d61058d19db0 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sat, 16 Feb 2019 22:20:36 +0200 Subject: [PATCH 02/13] Go async with the client/ip/counter stores --- .../AspNetCoreRateLimit.csproj | 1 + .../Core/ClientRateLimitProcessor.cs | 8 +- .../Core/IRateLimitProcessor.cs | 10 +- .../Core/IpRateLimitProcessor.cs | 8 +- .../Core/RateLimitProcessor.cs | 170 ++++++++++-------- .../Middleware/RateLimitMiddleware.cs | 35 ++-- .../DistributedCacheClientPolicyStore.cs | 61 ++----- .../Store/DistributedCacheIpPolicyStore.cs | 55 ++---- .../DistributedCacheRateLimitCounterStore.cs | 40 +---- .../Store/DistributedCacheRateLimitStore.cs | 50 ++++++ .../Store/IClientPolicyStore.cs | 7 +- .../Store/IIpPolicyStore.cs | 7 +- src/AspNetCoreRateLimit/Store/IPolicyStore.cs | 10 -- .../Store/IRateLimitCounterStore.cs | 10 +- .../Store/IRateLimitStore.cs | 14 ++ .../Store/MemoryCacheClientPolicyStore.cs | 56 ++---- .../Store/MemoryCacheIpPolicyStore.cs | 51 ++---- .../Store/MemoryCacheRateLimitCounterStore.cs | 33 +--- .../Store/MemoryCacheRateLimitStore.cs | 46 +++++ .../AspNetCoreRateLimit.Demo.csproj | 1 + .../Controllers/ClientRateLimitController.cs | 13 +- .../Controllers/IpRateLimitController.cs | 11 +- test/AspNetCoreRateLimit.Demo/Program.cs | 23 ++- .../AspNetCoreRateLimit.Tests.csproj | 1 + 24 files changed, 355 insertions(+), 366 deletions(-) create mode 100644 src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitStore.cs delete mode 100644 src/AspNetCoreRateLimit/Store/IPolicyStore.cs create mode 100644 src/AspNetCoreRateLimit/Store/IRateLimitStore.cs create mode 100644 src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs diff --git a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj index a32c5011..d1aa1b36 100644 --- a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj +++ b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj @@ -12,6 +12,7 @@ http://opensource.org/licenses/MIT git https://github.com/stefanprodan/AspNetCoreRateLimit + 7.1 diff --git a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs index e7209751..8fb17a49 100644 --- a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs @@ -1,12 +1,14 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace AspNetCoreRateLimit { public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor { private readonly ClientRateLimitOptions _options; - private readonly IPolicyStore _policyStore; + private readonly IRateLimitStore _policyStore; public ClientRateLimitProcessor( ClientRateLimitOptions options, @@ -18,9 +20,9 @@ public ClientRateLimitProcessor( _policyStore = policyStore; } - public IEnumerable GetMatchingRules(ClientRequestIdentity identity) + public async Task> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default) { - var policy = _policyStore.Get($"{_options.ClientPolicyPrefix}_{identity.ClientId}"); + var policy = await _policyStore.GetAsync($"{_options.ClientPolicyPrefix}_{identity.ClientId}", cancellationToken); if (policy != null) { diff --git a/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs index 83d90754..788a80b0 100644 --- a/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs @@ -1,15 +1,17 @@ using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; namespace AspNetCoreRateLimit { public interface IRateLimitProcessor { - IEnumerable GetMatchingRules(ClientRequestIdentity identity); + Task> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default); - RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule); + Task GetRateLimitHeadersAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default); - bool IsWhitelisted(ClientRequestIdentity requestIdentity); + Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default); - RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule); + bool IsWhitelisted(ClientRequestIdentity requestIdentity); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs index c7256cc8..a6522e12 100644 --- a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs @@ -1,12 +1,14 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace AspNetCoreRateLimit { public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor { private readonly IpRateLimitOptions _options; - private readonly IPolicyStore _policyStore; + private readonly IRateLimitStore _policyStore; public IpRateLimitProcessor(IpRateLimitOptions options, IRateLimitCounterStore counterStore, @@ -17,10 +19,10 @@ public IpRateLimitProcessor(IpRateLimitOptions options, _policyStore = policyStore; } - public IEnumerable GetMatchingRules(ClientRequestIdentity identity) + public async Task> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default) { var limits = new List(); - var policies = _policyStore.Get($"{_options.IpPolicyPrefix}"); + var policies = await _policyStore.GetAsync($"{_options.IpPolicyPrefix}", cancellationToken); if (policies != null && policies.IpRules != null && policies.IpRules.Any()) { diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 417bb741..9a2c88ba 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace AspNetCoreRateLimit { @@ -18,7 +20,96 @@ protected RateLimitProcessor( _counterStore = counterStore; } - private static readonly object _processLocker = new object(); + private static readonly SemaphoreSlim Semaphore = new SemaphoreSlim(1); + + public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) + { + if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId)) + { + return true; + } + + if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) + { + if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ContainsIgnoreCase(x)) || + _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ContainsIgnoreCase(x))) + return true; + } + + return false; + } + + public async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) + { + var counter = new RateLimitCounter + { + Timestamp = DateTime.UtcNow, + TotalRequests = 1 + }; + + var counterId = ComputeCounterKey(requestIdentity, rule); + + // serial reads and writes + await Semaphore.WaitAsync(cancellationToken); + + try + { + var entry = await _counterStore.GetAsync(counterId, cancellationToken); + + if (entry.HasValue) + { + // entry has not expired + if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow) + { + // increment request count + var totalRequests = entry.Value.TotalRequests + 1; + + // deep copy + counter = new RateLimitCounter + { + Timestamp = entry.Value.Timestamp, + TotalRequests = totalRequests + }; + } + } + + // stores: id (string) - timestamp (datetime) - total_requests (long) + await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken); + } + finally + { + Semaphore.Release(); + } + + return counter; + } + + public async Task GetRateLimitHeadersAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) + { + var headers = new RateLimitHeaders(); + var counterId = ComputeCounterKey(requestIdentity, rule); + var entry = await _counterStore.GetAsync(counterId, cancellationToken); + + long remaining; + DateTime reset; + + if (entry.HasValue) + { + reset = entry.Value.Timestamp + rule.Period.ToTimeSpan(); + remaining = rule.Limit - entry.Value.TotalRequests; + } + else + { + reset = DateTime.UtcNow + rule.Period.ToTimeSpan(); + remaining = rule.Limit; + } + + headers.Reset = reset.ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); + headers.Limit = rule.Period; + headers.Remaining = remaining.ToString(); + + return headers; + } protected abstract string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule); @@ -119,82 +210,5 @@ protected List GetMatchingRules(ClientRequestIdentity identity, L return limits; } - - public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) - { - if (_options.ClientWhitelist != null && _options.ClientWhitelist.Contains(requestIdentity.ClientId)) - { - return true; - } - - if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) - { - if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ContainsIgnoreCase(x)) || - _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ContainsIgnoreCase(x))) - return true; - } - - return false; - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - var counter = new RateLimitCounter - { - Timestamp = DateTime.UtcNow, - TotalRequests = 1 - }; - - var counterId = ComputeCounterKey(requestIdentity, rule); - - // serial reads and writes - lock (_processLocker) - { - var entry = _counterStore.Get(counterId); - - if (entry.HasValue) - { - // entry has not expired - if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow) - { - // increment request count - var totalRequests = entry.Value.TotalRequests + 1; - - // deep copy - counter = new RateLimitCounter - { - Timestamp = entry.Value.Timestamp, - TotalRequests = totalRequests - }; - } - } - - // stores: id (string) - timestamp (datetime) - total_requests (long) - _counterStore.Set(counterId, counter, rule.PeriodTimespan.Value); - } - - return counter; - } - - public RateLimitHeaders GetRateLimitHeaders(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - var headers = new RateLimitHeaders(); - var counterId = ComputeCounterKey(requestIdentity, rule); - var entry = _counterStore.Get(counterId); - if (entry.HasValue) - { - headers.Reset = (entry.Value.Timestamp + rule.Period.ToTimeSpan()).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); - headers.Limit = rule.Period; - headers.Remaining = (rule.Limit - entry.Value.TotalRequests).ToString(); - } - else - { - headers.Reset = (DateTime.UtcNow + rule.Period.ToTimeSpan()).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo); - headers.Limit = rule.Period; - headers.Remaining = rule.Limit.ToString(); - } - - return headers; - } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs index 1c551418..0f7b5a9a 100644 --- a/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs @@ -22,33 +22,33 @@ protected RateLimitMiddleware( _processor = processor; } - public async Task Invoke(HttpContext httpContext) + public async Task Invoke(HttpContext context) { // check if rate limiting is enabled if (_options == null) { - await _next.Invoke(httpContext); + await _next.Invoke(context); return; } // compute identity from request - var identity = SetIdentity(httpContext); + var identity = SetIdentity(context); // check white list if (_processor.IsWhitelisted(identity)) { - await _next.Invoke(httpContext); + await _next.Invoke(context); return; } - var rules = _processor.GetMatchingRules(identity); + var rules = await _processor.GetMatchingRulesAsync(identity, context.RequestAborted); foreach (var rule in rules) { if (rule.Limit > 0) { // increment counter - var counter = _processor.ProcessRequest(identity, rule); + var counter = await _processor.ProcessRequestAsync(identity, rule, context.RequestAborted); // check if key expired if (counter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow) @@ -63,10 +63,11 @@ public async Task Invoke(HttpContext httpContext) var retryAfter = counter.Timestamp.RetryAfterFrom(rule); // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); + LogBlockedRequest(context, identity, counter, rule); // break execution - await ReturnQuotaExceededResponse(httpContext, rule, retryAfter); + await ReturnQuotaExceededResponse(context, rule, retryAfter); + return; } } @@ -74,28 +75,30 @@ public async Task Invoke(HttpContext httpContext) else { // process request count - var counter = _processor.ProcessRequest(identity, rule); + var counter = await _processor.ProcessRequestAsync(identity, rule, context.RequestAborted); // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule); + LogBlockedRequest(context, identity, counter, rule); // break execution (Int32 max used to represent infinity) - await ReturnQuotaExceededResponse(httpContext, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture)); + await ReturnQuotaExceededResponse(context, rule, int.MaxValue.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; } } - //set X-Rate-Limit headers for the longest period + // set X-Rate-Limit headers for the longest period if (rules.Any() && !_options.DisableRateLimitHeaders) { var rule = rules.OrderByDescending(x => x.PeriodTimespan.Value).First(); - var headers = _processor.GetRateLimitHeaders(identity, rule); - headers.Context = httpContext; + var headers = await _processor.GetRateLimitHeadersAsync(identity, rule, context.RequestAborted); + + headers.Context = context; - httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers); + context.Response.OnStarting(SetRateLimitHeaders, state: headers); } - await _next.Invoke(httpContext); + await _next.Invoke(context); } public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext) diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs index c295e768..a5a21074 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs @@ -1,60 +1,33 @@ -using Microsoft.Extensions.Caching.Distributed; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Options; -using Newtonsoft.Json; namespace AspNetCoreRateLimit { - public class DistributedCacheClientPolicyStore : IClientPolicyStore + public class DistributedCacheClientPolicyStore : DistributedCacheRateLimitStore, IClientPolicyStore { - private readonly IDistributedCache _memoryCache; + private readonly ClientRateLimitOptions _options; + private readonly ClientRateLimitPolicies _policies; public DistributedCacheClientPolicyStore( - IDistributedCache memoryCache, - IOptions options = null, - IOptions policies = null) + IDistributedCache cache, + IOptions options = null, + IOptions policies = null) : base(cache) { - _memoryCache = memoryCache; - - var clientOptions = options?.Value; - var clientPolicyRules = policies?.Value?.ClientRules; - - //save client rules defined in appsettings in cache on startup - if (clientOptions != null && clientPolicyRules != null) - { - foreach (var rule in clientPolicyRules) - { - Set($"{clientOptions.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); - } - } - } - - public void Set(string id, ClientRateLimitPolicy policy) - { - _memoryCache.SetString(id, JsonConvert.SerializeObject(policy)); - } - - public bool Exists(string id) - { - var stored = _memoryCache.GetString(id); - - return !string.IsNullOrEmpty(stored); + _options = options?.Value; + _policies = policies?.Value; } - public ClientRateLimitPolicy Get(string id) + public async Task SeedAsync() { - var stored = _memoryCache.GetString(id); - - if (!string.IsNullOrEmpty(stored)) + // on startup, save the IP rules defined in appsettings + if (_options != null && _policies?.ClientRules != null) { - return JsonConvert.DeserializeObject(stored); + foreach (var rule in _policies.ClientRules) + { + await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + } } - - return null; - } - - public void Remove(string id) - { - _memoryCache.Remove(id); } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs index 4ee20315..51094fe7 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs @@ -1,55 +1,30 @@ -using Microsoft.Extensions.Caching.Distributed; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Options; -using Newtonsoft.Json; namespace AspNetCoreRateLimit { - public class DistributedCacheIpPolicyStore : IIpPolicyStore + public class DistributedCacheIpPolicyStore : DistributedCacheRateLimitStore, IIpPolicyStore { - private readonly IDistributedCache _memoryCache; + private readonly IpRateLimitOptions _options; + private readonly IpRateLimitPolicies _policies; public DistributedCacheIpPolicyStore( - IDistributedCache memoryCache, - IOptions options = null, - IOptions policies = null) + IDistributedCache cache, + IOptions options = null, + IOptions policies = null) : base(cache) { - _memoryCache = memoryCache; - - var ipOptions = options?.Value; - var ipPolicyRules = policies?.Value; - - //save IP rules defined in appsettings in cache on startup - if (ipOptions != null && ipPolicyRules != null) - { - Set($"{ipOptions.IpPolicyPrefix}", ipPolicyRules); - - } - } - - public void Set(string id, IpRateLimitPolicies policy) - { - _memoryCache.SetString(id, JsonConvert.SerializeObject(policy)); - } - - public bool Exists(string id) - { - var stored = _memoryCache.GetString(id); - return !string.IsNullOrEmpty(stored); + _options = options?.Value; + _policies = policies?.Value; } - public IpRateLimitPolicies Get(string id) + public async Task SeedAsync() { - var stored = _memoryCache.GetString(id); - if (!string.IsNullOrEmpty(stored)) + // on startup, save the IP rules defined in appsettings + if (_options != null && _policies != null) { - return JsonConvert.DeserializeObject(stored); + await SetAsync($"{_options.IpPolicyPrefix}", _policies); } - return null; - } - - public void Remove(string id) - { - _memoryCache.Remove(id); } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs index 7d164b5d..abe82410 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitCounterStore.cs @@ -1,45 +1,11 @@ using Microsoft.Extensions.Caching.Distributed; -using Newtonsoft.Json; -using System; namespace AspNetCoreRateLimit { - public class DistributedCacheRateLimitCounterStore : IRateLimitCounterStore + public class DistributedCacheRateLimitCounterStore : DistributedCacheRateLimitStore, IRateLimitCounterStore { - private readonly IDistributedCache _memoryCache; - - public DistributedCacheRateLimitCounterStore(IDistributedCache memoryCache) - { - _memoryCache = memoryCache; - } - - public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) - { - _memoryCache.SetString(id, JsonConvert.SerializeObject(counter), new DistributedCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); - } - - public bool Exists(string id) - { - var stored = _memoryCache.GetString(id); - - return !string.IsNullOrEmpty(stored); - } - - public RateLimitCounter? Get(string id) - { - var stored = _memoryCache.GetString(id); - - if(!string.IsNullOrEmpty(stored)) - { - return JsonConvert.DeserializeObject(stored); - } - - return null; - } - - public void Remove(string id) + public DistributedCacheRateLimitCounterStore(IDistributedCache cache) : base(cache) { - _memoryCache.Remove(id); } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitStore.cs new file mode 100644 index 00000000..6ea6683c --- /dev/null +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheRateLimitStore.cs @@ -0,0 +1,50 @@ +using Microsoft.Extensions.Caching.Distributed; +using Newtonsoft.Json; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + public class DistributedCacheRateLimitStore : IRateLimitStore + { + private readonly IDistributedCache _cache; + + public DistributedCacheRateLimitStore(IDistributedCache cache) + { + _cache = cache; + } + + public Task SetAsync(string id, T entry, TimeSpan? expirationTime = null, CancellationToken cancellationToken = default) + { + return _cache.SetStringAsync(id, + JsonConvert.SerializeObject(entry), + expirationTime.HasValue ? new DistributedCacheEntryOptions().SetAbsoluteExpiration(expirationTime.Value) : null, + cancellationToken); + } + + public async Task ExistsAsync(string id, CancellationToken cancellationToken = default) + { + var stored = await _cache.GetStringAsync(id, cancellationToken); + + return !string.IsNullOrEmpty(stored); + } + + public async Task GetAsync(string id, CancellationToken cancellationToken = default) + { + var stored = await _cache.GetStringAsync(id, cancellationToken); + + if (!string.IsNullOrEmpty(stored)) + { + return JsonConvert.DeserializeObject(stored); + } + + return default; + } + + public Task RemoveAsync(string id, CancellationToken cancellationToken = default) + { + return _cache.RemoveAsync(id, cancellationToken); + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs index f3a998eb..b6f32775 100644 --- a/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/IClientPolicyStore.cs @@ -1,6 +1,9 @@ -namespace AspNetCoreRateLimit +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit { - public interface IClientPolicyStore : IPolicyStore + public interface IClientPolicyStore : IRateLimitStore { + Task SeedAsync(); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs index ab6d7c05..6a2b8fd0 100644 --- a/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/IIpPolicyStore.cs @@ -1,6 +1,9 @@ -namespace AspNetCoreRateLimit +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit { - public interface IIpPolicyStore : IPolicyStore + public interface IIpPolicyStore : IRateLimitStore { + Task SeedAsync(); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IPolicyStore.cs b/src/AspNetCoreRateLimit/Store/IPolicyStore.cs deleted file mode 100644 index bcbe42f4..00000000 --- a/src/AspNetCoreRateLimit/Store/IPolicyStore.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace AspNetCoreRateLimit -{ - public interface IPolicyStore - { - bool Exists(string id); - TPolicy Get(string id); - void Remove(string id); - void Set(string id, TPolicy policy); - } -} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IRateLimitCounterStore.cs b/src/AspNetCoreRateLimit/Store/IRateLimitCounterStore.cs index d9c9aad1..cd55ee6f 100644 --- a/src/AspNetCoreRateLimit/Store/IRateLimitCounterStore.cs +++ b/src/AspNetCoreRateLimit/Store/IRateLimitCounterStore.cs @@ -1,12 +1,6 @@ -using System; - -namespace AspNetCoreRateLimit +namespace AspNetCoreRateLimit { - public interface IRateLimitCounterStore + public interface IRateLimitCounterStore : IRateLimitStore { - bool Exists(string id); - RateLimitCounter? Get(string id); - void Remove(string id); - void Set(string id, RateLimitCounter counter, TimeSpan expirationTime); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/IRateLimitStore.cs b/src/AspNetCoreRateLimit/Store/IRateLimitStore.cs new file mode 100644 index 00000000..9c02fe32 --- /dev/null +++ b/src/AspNetCoreRateLimit/Store/IRateLimitStore.cs @@ -0,0 +1,14 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + public interface IRateLimitStore + { + Task ExistsAsync(string id, CancellationToken cancellationToken = default); + Task GetAsync(string id, CancellationToken cancellationToken = default); + Task RemoveAsync(string id, CancellationToken cancellationToken = default); + Task SetAsync(string id, T entry, TimeSpan? expirationTime = null, CancellationToken cancellationToken = default); + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs index ecaea11c..da20685d 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs @@ -1,55 +1,33 @@ -using Microsoft.Extensions.Caching.Memory; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; namespace AspNetCoreRateLimit { - public class MemoryCacheClientPolicyStore : IClientPolicyStore + public class MemoryCacheClientPolicyStore : MemoryCacheRateLimitStore, IClientPolicyStore { - private readonly IMemoryCache _memoryCache; + private readonly ClientRateLimitOptions _options; + private readonly ClientRateLimitPolicies _policies; public MemoryCacheClientPolicyStore( - IMemoryCache memoryCache, - IOptions options = null, - IOptions policies = null) + IMemoryCache cache, + IOptions options = null, + IOptions policies = null) : base(cache) { - _memoryCache = memoryCache; - - var clientOptions = options?.Value; - var clientPolicyRules = policies?.Value?.ClientRules; - - //save client rules defined in appsettings in cache on startup - if (clientOptions != null && clientPolicyRules != null) - { - foreach (var rule in clientPolicyRules) - { - Set($"{clientOptions.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); - } - } - } - - public void Set(string id, ClientRateLimitPolicy policy) - { - _memoryCache.Set(id, policy); + _options = options?.Value; + _policies = policies?.Value; } - public bool Exists(string id) + public async Task SeedAsync() { - return _memoryCache.TryGetValue(id, out _); - } - - public ClientRateLimitPolicy Get(string id) - { - if (_memoryCache.TryGetValue(id, out ClientRateLimitPolicy stored)) + // on startup, save the IP rules defined in appsettings + if (_options != null && _policies?.ClientRules != null) { - return stored; + foreach (var rule in _policies.ClientRules) + { + await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + } } - - return null; - } - - public void Remove(string id) - { - _memoryCache.Remove(id); } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs index aa18264f..da12bafe 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs @@ -1,53 +1,30 @@ -using Microsoft.Extensions.Caching.Memory; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; namespace AspNetCoreRateLimit { - public class MemoryCacheIpPolicyStore : IIpPolicyStore + public class MemoryCacheIpPolicyStore : MemoryCacheRateLimitStore, IIpPolicyStore { - private readonly IMemoryCache _memoryCache; + private readonly IpRateLimitOptions _options; + private readonly IpRateLimitPolicies _policies; public MemoryCacheIpPolicyStore( - IMemoryCache memoryCache, - IOptions options = null, - IOptions policies = null) + IMemoryCache cache, + IOptions options = null, + IOptions policies = null) : base(cache) { - _memoryCache = memoryCache; - - var ipOptions = options?.Value; - var ipPolicyRules = policies?.Value; - - //save IP rules defined in appsettings in cache on startup - if (ipOptions != null && ipPolicyRules != null) - { - Set($"{ipOptions.IpPolicyPrefix}", ipPolicyRules); - - } - } - - public void Set(string id, IpRateLimitPolicies policy) - { - _memoryCache.Set(id, policy); + _options = options?.Value; + _policies = policies?.Value; } - public bool Exists(string id) + public async Task SeedAsync() { - return _memoryCache.TryGetValue(id, out _); - } - - public IpRateLimitPolicies Get(string id) - { - if (_memoryCache.TryGetValue(id, out IpRateLimitPolicies stored)) + // on startup, save the IP rules defined in appsettings + if (_options != null && _policies != null) { - return stored; + await SetAsync($"{_options.IpPolicyPrefix}", _policies); } - - return null; - } - - public void Remove(string id) - { - _memoryCache.Remove(id); } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitCounterStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitCounterStore.cs index 00483be9..8f8faa63 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitCounterStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitCounterStore.cs @@ -1,40 +1,11 @@ using Microsoft.Extensions.Caching.Memory; -using System; namespace AspNetCoreRateLimit { - public class MemoryCacheRateLimitCounterStore: IRateLimitCounterStore + public class MemoryCacheRateLimitCounterStore : MemoryCacheRateLimitStore, IRateLimitCounterStore { - private readonly IMemoryCache _memoryCache; - - public MemoryCacheRateLimitCounterStore(IMemoryCache memoryCache) - { - _memoryCache = memoryCache; - } - - public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) - { - _memoryCache.Set(id, counter, new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); - } - - public bool Exists(string id) - { - return _memoryCache.TryGetValue(id, out _); - } - - public RateLimitCounter? Get(string id) - { - if (_memoryCache.TryGetValue(id, out RateLimitCounter stored)) - { - return stored; - } - - return null; - } - - public void Remove(string id) + public MemoryCacheRateLimitCounterStore(IMemoryCache cache) : base(cache) { - _memoryCache.Remove(id); } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs new file mode 100644 index 00000000..3b812f15 --- /dev/null +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs @@ -0,0 +1,46 @@ +using Microsoft.Extensions.Caching.Memory; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + public class MemoryCacheRateLimitStore : IRateLimitStore + { + private readonly IMemoryCache _cache; + + public MemoryCacheRateLimitStore(IMemoryCache cache) + { + _cache = cache; + } + + public Task ExistsAsync(string id, CancellationToken cancellationToken = default) + { + return Task.FromResult(_cache.TryGetValue(id, out _)); + } + + public Task GetAsync(string id, CancellationToken cancellationToken = default) + { + if (_cache.TryGetValue(id, out T stored)) + { + return Task.FromResult(stored); + } + + return Task.FromResult(default(T)); + } + + public Task RemoveAsync(string id, CancellationToken cancellationToken = default) + { + _cache.Remove(id); + + return Task.CompletedTask; + } + + public Task SetAsync(string id, T entry, TimeSpan? expirationTime = null, CancellationToken cancellationToken = default) + { + _cache.Set(id, entry, expirationTime.HasValue ? new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime.Value) : null); + + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj b/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj index 1ca804a9..b445d1d4 100644 --- a/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj +++ b/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj @@ -4,6 +4,7 @@ netcoreapp3.0 true true + 7.1 diff --git a/test/AspNetCoreRateLimit.Demo/Controllers/ClientRateLimitController.cs b/test/AspNetCoreRateLimit.Demo/Controllers/ClientRateLimitController.cs index 8b3b3504..9fc13d31 100644 --- a/test/AspNetCoreRateLimit.Demo/Controllers/ClientRateLimitController.cs +++ b/test/AspNetCoreRateLimit.Demo/Controllers/ClientRateLimitController.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; +using System.Threading.Tasks; namespace AspNetCoreRateLimit.Demo.Controllers { @@ -16,23 +17,25 @@ public ClientRateLimitController(IOptions optionsAccesso } [HttpGet] - public ClientRateLimitPolicy Get() + public async Task Get() { - return _clientPolicyStore.Get($"{_options.ClientPolicyPrefix}_cl-key-1"); + return await _clientPolicyStore.GetAsync($"{_options.ClientPolicyPrefix}_cl-key-1", HttpContext.RequestAborted); } [HttpPost] - public void Post() + public async Task Post() { var id = $"{_options.ClientPolicyPrefix}_cl-key-1"; - var policy = _clientPolicyStore.Get(id); + var policy = await _clientPolicyStore.GetAsync(id, HttpContext.RequestAborted); + policy.Rules.Add(new RateLimitRule { Endpoint = "*/api/testpolicyupdate", Period = "1h", Limit = 100 }); - _clientPolicyStore.Set(id, policy); + + await _clientPolicyStore.SetAsync(id, policy, cancellationToken: HttpContext.RequestAborted); } } } \ No newline at end of file diff --git a/test/AspNetCoreRateLimit.Demo/Controllers/IpRateLimitController.cs b/test/AspNetCoreRateLimit.Demo/Controllers/IpRateLimitController.cs index 2f8e3fbc..878e5eef 100644 --- a/test/AspNetCoreRateLimit.Demo/Controllers/IpRateLimitController.cs +++ b/test/AspNetCoreRateLimit.Demo/Controllers/IpRateLimitController.cs @@ -1,6 +1,7 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; using System.Collections.Generic; +using System.Threading.Tasks; namespace AspNetCoreRateLimit.Demo.Controllers { @@ -17,15 +18,15 @@ public IpRateLimitController(IOptions optionsAccessor, IIpPo } [HttpGet] - public IpRateLimitPolicies Get() + public async Task Get() { - return _ipPolicyStore.Get(_options.IpPolicyPrefix); + return await _ipPolicyStore.GetAsync(_options.IpPolicyPrefix, HttpContext.RequestAborted); } [HttpPost] - public void Post() + public async Task Post() { - var policy = _ipPolicyStore.Get(_options.IpPolicyPrefix); + var policy = await _ipPolicyStore.GetAsync(_options.IpPolicyPrefix, HttpContext.RequestAborted); policy.IpRules.Add(new IpRateLimitPolicy { @@ -38,7 +39,7 @@ public void Post() }) }); - _ipPolicyStore.Set(_options.IpPolicyPrefix, policy); + await _ipPolicyStore.SetAsync(_options.IpPolicyPrefix, policy, cancellationToken: HttpContext.RequestAborted); } } } \ No newline at end of file diff --git a/test/AspNetCoreRateLimit.Demo/Program.cs b/test/AspNetCoreRateLimit.Demo/Program.cs index 90b4fff2..edfa6d7b 100644 --- a/test/AspNetCoreRateLimit.Demo/Program.cs +++ b/test/AspNetCoreRateLimit.Demo/Program.cs @@ -1,15 +1,34 @@ using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using System.Threading.Tasks; namespace AspNetCoreRateLimit.Demo { public class Program { - public static void Main(string[] args) + public static async Task Main(string[] args) { - CreateWebHostBuilder(args).Build().Run(); + IWebHost webHost = CreateWebHostBuilder(args).Build(); + + using (var scope = webHost.Services.CreateScope()) + { + // get the ClientPolicyStore instance + var clientPolicyStore = scope.ServiceProvider.GetRequiredService(); + + // seed Client data from appsettings + await clientPolicyStore.SeedAsync(); + + // get the IpPolicyStore instance + var ipPolicyStore = scope.ServiceProvider.GetRequiredService(); + + // seed IP data from appsettings + await ipPolicyStore.SeedAsync(); + } + + await webHost.RunAsync(); } public static IWebHostBuilder CreateWebHostBuilder(string[] args) => diff --git a/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj b/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj index 1ff4436e..d590625b 100644 --- a/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj +++ b/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj @@ -5,6 +5,7 @@ netcoreapp3.0 true true + 7.1 From 09c53042b3d3396296c764730b1d7ff52f31e429 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sat, 16 Feb 2019 22:49:52 +0200 Subject: [PATCH 03/13] Fix wildcard match #29 #41 --- src/AspNetCoreRateLimit/Core/Extensions.cs | 6 ++++-- .../Core/RateLimitProcessor.cs | 16 ++++++++-------- src/AspNetCoreRateLimit/Core/WildcardMatcher.cs | 2 +- src/AspNetCoreRateLimit/Net/IpAddressUtil.cs | 7 +++++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/AspNetCoreRateLimit/Core/Extensions.cs b/src/AspNetCoreRateLimit/Core/Extensions.cs index 3ee1d46c..afe6e4b1 100644 --- a/src/AspNetCoreRateLimit/Core/Extensions.cs +++ b/src/AspNetCoreRateLimit/Core/Extensions.cs @@ -5,16 +5,18 @@ namespace AspNetCoreRateLimit { public static class Extensions { - public static bool ContainsIgnoreCase(this string source, string value, StringComparison stringComparison = StringComparison.CurrentCultureIgnoreCase) + public static bool IsWildcardMatch(this string source, string value) { - return source != null && value != null && source.IndexOf(value, stringComparison) >= 0; + return source != null && value != null && source.ToLowerInvariant().IsMatch(value.ToLowerInvariant()); } public static string RetryAfterFrom(this DateTime timestamp, RateLimitRule rule) { var secondsPast = Convert.ToInt32((DateTime.UtcNow - timestamp).TotalSeconds); var retryAfter = Convert.ToInt32(rule.PeriodTimespan.Value.TotalSeconds); + retryAfter = retryAfter > 1 ? retryAfter - secondsPast : 1; + return retryAfter.ToString(CultureInfo.InvariantCulture); } diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 9a2c88ba..70117fda 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -31,8 +31,8 @@ public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) if (_options.EndpointWhitelist != null && _options.EndpointWhitelist.Any()) { - if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".ContainsIgnoreCase(x)) || - _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".ContainsIgnoreCase(x))) + if (_options.EndpointWhitelist.Any(x => $"{requestIdentity.HttpVerb}:{requestIdentity.Path}".IsWildcardMatch(x)) || + _options.EndpointWhitelist.Any(x => $"*:{requestIdentity.Path}".IsWildcardMatch(x))) return true; } @@ -144,17 +144,17 @@ protected List GetMatchingRules(ClientRequestIdentity identity, L if (_options.EnableEndpointRateLimiting) { // search for rules with endpoints like "*" and "*:/matching_path" - var pathLimits = rules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + var pathLimits = rules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); limits.AddRange(pathLimits); // search for rules with endpoints like "matching_verb:/matching_path" - var verbLimits = rules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + var verbLimits = rules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); limits.AddRange(verbLimits); } else { //ignore endpoint rules and search for global rules only - var genericLimits = rules.Where(l => l.Endpoint == "*").AsEnumerable(); + var genericLimits = rules.Where(r => r.Endpoint == "*").AsEnumerable(); limits.AddRange(genericLimits); } @@ -168,17 +168,17 @@ protected List GetMatchingRules(ClientRequestIdentity identity, L if (_options.EnableEndpointRateLimiting) { // search for rules with endpoints like "*" and "*:/matching_path" in general rules - var pathLimits = _options.GeneralRules.Where(l => $"*:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + var pathLimits = _options.GeneralRules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); matchingGeneralLimits.AddRange(pathLimits); // search for rules with endpoints like "matching_verb:/matching_path" in general rules - var verbLimits = _options.GeneralRules.Where(l => $"{identity.HttpVerb}:{identity.Path}".ContainsIgnoreCase(l.Endpoint)).AsEnumerable(); + var verbLimits = _options.GeneralRules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); matchingGeneralLimits.AddRange(verbLimits); } else { //ignore endpoint rules and search for global rules in general rules - var genericLimits = _options.GeneralRules.Where(l => l.Endpoint == "*").AsEnumerable(); + var genericLimits = _options.GeneralRules.Where(r => r.Endpoint == "*").AsEnumerable(); matchingGeneralLimits.AddRange(genericLimits); } diff --git a/src/AspNetCoreRateLimit/Core/WildcardMatcher.cs b/src/AspNetCoreRateLimit/Core/WildcardMatcher.cs index 43788515..02d0e403 100644 --- a/src/AspNetCoreRateLimit/Core/WildcardMatcher.cs +++ b/src/AspNetCoreRateLimit/Core/WildcardMatcher.cs @@ -1,4 +1,4 @@ -namespace AspNetCoreRateLimit.Core +namespace AspNetCoreRateLimit { public static class WildcardMatcher { diff --git a/src/AspNetCoreRateLimit/Net/IpAddressUtil.cs b/src/AspNetCoreRateLimit/Net/IpAddressUtil.cs index d60b16b9..fa7bc85b 100644 --- a/src/AspNetCoreRateLimit/Net/IpAddressUtil.cs +++ b/src/AspNetCoreRateLimit/Net/IpAddressUtil.cs @@ -12,6 +12,7 @@ public static bool ContainsIp(string rule, string clientIp) var ip = ParseIp(clientIp); var range = new IpAddressRange(rule); + if (range.Contains(ip)) { return true; @@ -23,11 +24,13 @@ public static bool ContainsIp(string rule, string clientIp) public static bool ContainsIp(List ipRules, string clientIp) { var ip = ParseIp(clientIp); + if (ipRules != null && ipRules.Any()) { foreach (var rule in ipRules) { var range = new IpAddressRange(rule); + if (range.Contains(ip)) { return true; @@ -42,11 +45,13 @@ public static bool ContainsIp(List ipRules, string clientIp, out string { rule = null; var ip = ParseIp(clientIp); + if (ipRules != null && ipRules.Any()) { foreach (var r in ipRules) { var range = new IpAddressRange(r); + if (range.Contains(ip)) { rule = r; @@ -62,9 +67,11 @@ public static IPAddress ParseIp(string ipAddress) { //remove port number from ip address if any ipAddress = ipAddress.Split(',').First().Trim(); + var portDelimiterPos = ipAddress.LastIndexOf(":", StringComparison.CurrentCultureIgnoreCase); var ipv6WithPortStart = ipAddress.StartsWith("["); var ipv6End = ipAddress.IndexOf("]"); + if (portDelimiterPos != -1 && portDelimiterPos == ipAddress.IndexOf(":", StringComparison.CurrentCultureIgnoreCase) || ipv6WithPortStart && ipv6End != -1 && ipv6End < portDelimiterPos) From 47e2722401ff6a09c6006ff5aedcdb33cb05771d Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 11:36:23 +0200 Subject: [PATCH 04/13] Add ip/client resolvers, fix tests seed #47 #54 --- .../Core/IpRateLimitProcessor.cs | 3 +- .../Core/RateLimitProcessor.cs | 1 + .../Middleware/ClientRateLimitMiddleware.cs | 3 +- .../Middleware/IRateLimitConfiguration.cs | 13 ++++ .../Middleware/IpRateLimitMiddleware.cs | 3 +- .../Middleware/RateLimitConfiguration.cs | 38 +++++++++++ .../Middleware/RateLimitMiddleware.cs | 63 +++++++++++-------- .../ClientHeaderResolveContributor.cs | 31 +++++++++ .../Resolvers/IClientResolveContributor.cs | 7 +++ .../Resolvers/IIpResolveContributor.cs | 7 +++ .../IpConnectionResolveContributor.cs | 19 ++++++ .../Resolvers/IpHeaderResolveContributor.cs | 34 ++++++++++ test/AspNetCoreRateLimit.Demo/Program.cs | 2 + test/AspNetCoreRateLimit.Demo/Startup.cs | 13 +++- .../ClientRateLimitTests.cs | 4 +- .../IpRateLimitTests.cs | 4 +- .../RateLimitWebApplicationFactory.cs | 35 +++++++++++ 17 files changed, 244 insertions(+), 36 deletions(-) create mode 100644 src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs create mode 100644 src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs create mode 100644 src/AspNetCoreRateLimit/Resolvers/ClientHeaderResolveContributor.cs create mode 100644 src/AspNetCoreRateLimit/Resolvers/IClientResolveContributor.cs create mode 100644 src/AspNetCoreRateLimit/Resolvers/IIpResolveContributor.cs create mode 100644 src/AspNetCoreRateLimit/Resolvers/IpConnectionResolveContributor.cs create mode 100644 src/AspNetCoreRateLimit/Resolvers/IpHeaderResolveContributor.cs create mode 100644 test/AspNetCoreRateLimit.Tests/RateLimitWebApplicationFactory.cs diff --git a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs index a6522e12..a22294a4 100644 --- a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs @@ -10,7 +10,8 @@ public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor private readonly IpRateLimitOptions _options; private readonly IRateLimitStore _policyStore; - public IpRateLimitProcessor(IpRateLimitOptions options, + public IpRateLimitProcessor( + IpRateLimitOptions options, IRateLimitCounterStore counterStore, IIpPolicyStore policyStore) : base(options, counterStore) diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 70117fda..299b9bdb 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -165,6 +165,7 @@ protected List GetMatchingRules(ClientRequestIdentity identity, L if (_options.GeneralRules != null) { var matchingGeneralLimits = new List(); + if (_options.EnableEndpointRateLimiting) { // search for rules with endpoints like "*" and "*:/matching_path" in general rules diff --git a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs index 37f027bb..87ba0180 100644 --- a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs @@ -12,8 +12,9 @@ public ClientRateLimitMiddleware(RequestDelegate next, IOptions options, IRateLimitCounterStore counterStore, IClientPolicyStore policyStore, + IRateLimitConfiguration config, ILogger logger) - : base(next, options.Value, new ClientRateLimitProcessor(options.Value, counterStore, policyStore)) + : base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore), config) { _logger = logger; } diff --git a/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs b/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs new file mode 100644 index 00000000..8f72851e --- /dev/null +++ b/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; + +namespace AspNetCoreRateLimit +{ + public interface IRateLimitConfiguration + { + //bool Enabled { get; set; } + + IList ClientResolvers { get; } + + IList IpResolvers { get; } + } +} diff --git a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs index 52bb418a..bad8f63e 100644 --- a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs @@ -12,8 +12,9 @@ public IpRateLimitMiddleware(RequestDelegate next, IOptions options, IRateLimitCounterStore counterStore, IIpPolicyStore policyStore, + IRateLimitConfiguration config, ILogger logger) - : base(next, options.Value, new IpRateLimitProcessor(options.Value, counterStore, policyStore)) + : base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore), config) { _logger = logger; diff --git a/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs b/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs new file mode 100644 index 00000000..4e9b554c --- /dev/null +++ b/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs @@ -0,0 +1,38 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Options; +using System.Collections.Generic; + +namespace AspNetCoreRateLimit +{ + public class RateLimitConfiguration : IRateLimitConfiguration + { + //public bool Enabled { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public IList ClientResolvers { get; } + + public IList IpResolvers { get; } + + public RateLimitConfiguration( + IHttpContextAccessor httpContextAccessor, + IOptions ipOptions, + IOptions clientOptions) + { + ClientResolvers = new List(); + + if (!string.IsNullOrEmpty(clientOptions?.Value.ClientIdHeader)) + { + ClientResolvers.Add(new ClientHeaderResolveContributor(httpContextAccessor, clientOptions.Value.ClientIdHeader)); + } + + IpResolvers = new List(); + + // the contributors are resolved in the order of their collection index + if (!string.IsNullOrEmpty(ipOptions?.Value.RealIpHeader)) + { + IpResolvers.Add(new IpHeaderResolveContributor(httpContextAccessor, ipOptions.Value.RealIpHeader)); + } + + IpResolvers.Add(new IpConnectionResolveContributor(httpContextAccessor)); + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs index 0f7b5a9a..dae95a02 100644 --- a/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/RateLimitMiddleware.cs @@ -11,15 +11,18 @@ public abstract class RateLimitMiddleware private readonly RequestDelegate _next; private readonly TProcessor _processor; private readonly RateLimitOptions _options; + private readonly IRateLimitConfiguration _config; protected RateLimitMiddleware( RequestDelegate next, RateLimitOptions options, - TProcessor processor) + TProcessor processor, + IRateLimitConfiguration config) { _next = next; _options = options; _processor = processor; + _config = config; } public async Task Invoke(HttpContext context) @@ -32,7 +35,7 @@ public async Task Invoke(HttpContext context) } // compute identity from request - var identity = SetIdentity(context); + var identity = ResolveIdentity(context); // check white list if (_processor.IsWhitelisted(identity)) @@ -101,37 +104,43 @@ public async Task Invoke(HttpContext context) await _next.Invoke(context); } - public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext) + public virtual ClientRequestIdentity ResolveIdentity(HttpContext httpContext) { - //var clientId = "anon"; - //if (httpContext.Request.Headers.Keys.Contains(_options.ClientIdHeader, StringComparer.CurrentCultureIgnoreCase)) - //{ - // clientId = httpContext.Request.Headers[_options.ClientIdHeader].First(); - //} - - //string clientIp; - //try - //{ - // var ip = _ipParser.GetClientIp(httpContext); - - // if (ip == null) - // { - // throw new Exception("IpRateLimitMiddleware can't parse caller IP"); - // } - - // clientIp = ip.ToString(); - //} - //catch (Exception ex) - //{ - // throw new Exception("IpRateLimitMiddleware can't parse caller IP", ex); - //} + string clientIp = null; + string clientId = null; + + if (_config.ClientResolvers?.Any() == true) + { + foreach(var resolver in _config.ClientResolvers) + { + clientId = resolver.ResolveClient(); + + if (!string.IsNullOrEmpty(clientId)) + { + break; + } + } + } + + if (_config.IpResolvers?.Any() == true) + { + foreach (var resolver in _config.IpResolvers) + { + clientIp = resolver.ResolveIp(); + + if (!string.IsNullOrEmpty(clientIp)) + { + break; + } + } + } return new ClientRequestIdentity { - //ClientIp = clientIp, + ClientIp = clientIp, Path = httpContext.Request.Path.ToString().ToLowerInvariant(), HttpVerb = httpContext.Request.Method.ToLowerInvariant(), - //ClientId = clientId + ClientId = clientId }; } diff --git a/src/AspNetCoreRateLimit/Resolvers/ClientHeaderResolveContributor.cs b/src/AspNetCoreRateLimit/Resolvers/ClientHeaderResolveContributor.cs new file mode 100644 index 00000000..7df1d738 --- /dev/null +++ b/src/AspNetCoreRateLimit/Resolvers/ClientHeaderResolveContributor.cs @@ -0,0 +1,31 @@ +using Microsoft.AspNetCore.Http; +using System.Linq; + +namespace AspNetCoreRateLimit +{ + public class ClientHeaderResolveContributor : IClientResolveContributor + { + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly string _headerName; + + public ClientHeaderResolveContributor( + IHttpContextAccessor httpContextAccessor, + string headerName) + { + _httpContextAccessor = httpContextAccessor; + _headerName = headerName; + } + public string ResolveClient() + { + var clientId = "anon"; + var httpContext = _httpContextAccessor.HttpContext; + + if (httpContext.Request.Headers.TryGetValue(_headerName, out var values)) + { + clientId = values.First(); + } + + return clientId; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Resolvers/IClientResolveContributor.cs b/src/AspNetCoreRateLimit/Resolvers/IClientResolveContributor.cs new file mode 100644 index 00000000..a92a72df --- /dev/null +++ b/src/AspNetCoreRateLimit/Resolvers/IClientResolveContributor.cs @@ -0,0 +1,7 @@ +namespace AspNetCoreRateLimit +{ + public interface IClientResolveContributor + { + string ResolveClient(); + } +} diff --git a/src/AspNetCoreRateLimit/Resolvers/IIpResolveContributor.cs b/src/AspNetCoreRateLimit/Resolvers/IIpResolveContributor.cs new file mode 100644 index 00000000..33a3722a --- /dev/null +++ b/src/AspNetCoreRateLimit/Resolvers/IIpResolveContributor.cs @@ -0,0 +1,7 @@ +namespace AspNetCoreRateLimit +{ + public interface IIpResolveContributor + { + string ResolveIp(); + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Resolvers/IpConnectionResolveContributor.cs b/src/AspNetCoreRateLimit/Resolvers/IpConnectionResolveContributor.cs new file mode 100644 index 00000000..e7bdb662 --- /dev/null +++ b/src/AspNetCoreRateLimit/Resolvers/IpConnectionResolveContributor.cs @@ -0,0 +1,19 @@ +using Microsoft.AspNetCore.Http; + +namespace AspNetCoreRateLimit +{ + public class IpConnectionResolveContributor : IIpResolveContributor + { + private readonly IHttpContextAccessor _httpContextAccessor; + + public IpConnectionResolveContributor(IHttpContextAccessor httpContextAccessor) + { + _httpContextAccessor = httpContextAccessor; + } + + public string ResolveIp() + { + return _httpContextAccessor.HttpContext.Connection.RemoteIpAddress?.ToString(); + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Resolvers/IpHeaderResolveContributor.cs b/src/AspNetCoreRateLimit/Resolvers/IpHeaderResolveContributor.cs new file mode 100644 index 00000000..68a64d93 --- /dev/null +++ b/src/AspNetCoreRateLimit/Resolvers/IpHeaderResolveContributor.cs @@ -0,0 +1,34 @@ +using Microsoft.AspNetCore.Http; +using System.Linq; +using System.Net; + +namespace AspNetCoreRateLimit +{ + public class IpHeaderResolveContributor : IIpResolveContributor + { + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly string _headerName; + + public IpHeaderResolveContributor( + IHttpContextAccessor httpContextAccessor, + string headerName) + { + _httpContextAccessor = httpContextAccessor; + _headerName = headerName; + } + + public string ResolveIp() + { + IPAddress clientIp = null; + + var httpContent = _httpContextAccessor.HttpContext; + + if (httpContent.Request.Headers.TryGetValue(_headerName, out var values)) + { + clientIp = IpAddressUtil.ParseIp(values.Last()); + } + + return clientIp?.ToString(); + } + } +} \ No newline at end of file diff --git a/test/AspNetCoreRateLimit.Demo/Program.cs b/test/AspNetCoreRateLimit.Demo/Program.cs index edfa6d7b..001fb75c 100644 --- a/test/AspNetCoreRateLimit.Demo/Program.cs +++ b/test/AspNetCoreRateLimit.Demo/Program.cs @@ -7,6 +7,8 @@ namespace AspNetCoreRateLimit.Demo { + // https://andrewlock.net/running-async-tasks-on-app-startup-in-asp-net-core-part-1/ + // https://andrewlock.net/running-async-tasks-on-app-startup-in-asp-net-core-part-2/ public class Program { public static async Task Main(string[] args) diff --git a/test/AspNetCoreRateLimit.Demo/Startup.cs b/test/AspNetCoreRateLimit.Demo/Startup.cs index ae8b9aa3..87a141c9 100644 --- a/test/AspNetCoreRateLimit.Demo/Startup.cs +++ b/test/AspNetCoreRateLimit.Demo/Startup.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -23,13 +24,13 @@ public void ConfigureServices(IServiceCollection services) // needed to store rate limit counters and ip rules services.AddMemoryCache(); - //configure ip rate limiting middle-ware + // configure ip rate limiting middle-ware services.Configure(Configuration.GetSection("IpRateLimiting")); services.Configure(Configuration.GetSection("IpRateLimitPolicies")); services.AddSingleton(); services.AddSingleton(); - //configure client rate limiting middleware + // configure client rate limiting middleware services.Configure(Configuration.GetSection("ClientRateLimiting")); services.Configure(Configuration.GetSection("ClientRateLimitPolicies")); services.AddSingleton(); @@ -39,6 +40,14 @@ public void ConfigureServices(IServiceCollection services) ConfigurationBinder.Bind(Configuration.GetSection("ClientRateLimiting"), opt); services.AddMvc().AddNewtonsoftJson(); + + // https://github.com/aspnet/Hosting/issues/793 + // the IHttpContextAccessor service is not registered by default. + // the clientId/clientIp resolvers use it. + services.AddSingleton(); + + // configure the resolvers + services.AddSingleton(); } // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. diff --git a/test/AspNetCoreRateLimit.Tests/ClientRateLimitTests.cs b/test/AspNetCoreRateLimit.Tests/ClientRateLimitTests.cs index 2b62968c..ed18b020 100644 --- a/test/AspNetCoreRateLimit.Tests/ClientRateLimitTests.cs +++ b/test/AspNetCoreRateLimit.Tests/ClientRateLimitTests.cs @@ -5,7 +5,7 @@ namespace AspNetCoreRateLimit.Tests { - public class ClientRateLimitTests : IClassFixture> + public class ClientRateLimitTests : IClassFixture { private const string apiPath = "/api/clients"; private const string apiRateLimitPath = "/api/clientratelimit"; @@ -13,7 +13,7 @@ public class ClientRateLimitTests : IClassFixture factory) + public ClientRateLimitTests(RateLimitWebApplicationFactory factory) { _client = factory.CreateClient(options: new WebApplicationFactoryClientOptions { diff --git a/test/AspNetCoreRateLimit.Tests/IpRateLimitTests.cs b/test/AspNetCoreRateLimit.Tests/IpRateLimitTests.cs index 23538ab8..3d22f803 100644 --- a/test/AspNetCoreRateLimit.Tests/IpRateLimitTests.cs +++ b/test/AspNetCoreRateLimit.Tests/IpRateLimitTests.cs @@ -5,14 +5,14 @@ namespace AspNetCoreRateLimit.Tests { - public class IpRateLimitTests: IClassFixture> + public class IpRateLimitTests: IClassFixture { private const string apiValuesPath = "/api/values"; private const string apiRateLimitPath = "/api/ipratelimit"; private readonly HttpClient _client; - public IpRateLimitTests(WebApplicationFactory factory) + public IpRateLimitTests(RateLimitWebApplicationFactory factory) { _client = factory.CreateClient(options: new WebApplicationFactoryClientOptions { diff --git a/test/AspNetCoreRateLimit.Tests/RateLimitWebApplicationFactory.cs b/test/AspNetCoreRateLimit.Tests/RateLimitWebApplicationFactory.cs new file mode 100644 index 00000000..cb985a31 --- /dev/null +++ b/test/AspNetCoreRateLimit.Tests/RateLimitWebApplicationFactory.cs @@ -0,0 +1,35 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; + +namespace AspNetCoreRateLimit.Tests +{ + // https://docs.microsoft.com/en-us/aspnet/core/test/integration-tests?view=aspnetcore-2.2 + // TestServer.cs https://github.com/aspnet/AspNetCore/blob/93a24b03bbda1aa0ab9b553a50b70dd36d554934/src/Hosting/TestHost/src/TestServer.cs + // WebApplicationFactory.cs https://github.com/aspnet/AspNetCore/blob/c565386a3ed135560bc2e9017aa54a950b4e35dd/src/Mvc/Mvc.Testing/src/WebApplicationFactory.cs + public class RateLimitWebApplicationFactory : WebApplicationFactory + { + protected override TestServer CreateServer(IWebHostBuilder builder) + { + var server = base.CreateServer(builder); + + using (var scope = server.Host.Services.CreateScope()) + { + // get the ClientPolicyStore instance + var clientPolicyStore = scope.ServiceProvider.GetRequiredService(); + + // seed Client data from appsettings + clientPolicyStore.SeedAsync().Wait(); + + // get the IpPolicyStore instance + var ipPolicyStore = scope.ServiceProvider.GetRequiredService(); + + // seed IP data from appsettings + ipPolicyStore.SeedAsync().Wait(); + } + + return server; + } + } +} \ No newline at end of file From 481bc70feb21865e9a2b3908b0f3adab8d495b98 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 12:44:02 +0200 Subject: [PATCH 05/13] Add counter key builders #13 #29 #41 --- .../Core/ClientRateLimitProcessor.cs | 10 ++---- .../Core/IpRateLimitProcessor.cs | 10 ++---- .../Core/RateLimitProcessor.cs | 27 +++++++-------- .../ClientCounterKeyBuilder.cs | 17 ++++++++++ .../EndpointCounterKeyBuilder.cs | 11 +++++++ .../CounterKeyBuilders/ICounterKeyBuilder.cs | 7 ++++ .../CounterKeyBuilders/IpCounterKeyBuilder.cs | 17 ++++++++++ .../PathCounterKeyBuilder.cs | 10 ++++++ .../Middleware/ClientRateLimitMiddleware.cs | 2 +- .../Middleware/IRateLimitConfiguration.cs | 6 ++-- .../Middleware/IpRateLimitMiddleware.cs | 2 +- .../Middleware/RateLimitConfiguration.cs | 33 ++++++++++++------- 12 files changed, 109 insertions(+), 43 deletions(-) create mode 100644 src/AspNetCoreRateLimit/CounterKeyBuilders/ClientCounterKeyBuilder.cs create mode 100644 src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs create mode 100644 src/AspNetCoreRateLimit/CounterKeyBuilders/ICounterKeyBuilder.cs create mode 100644 src/AspNetCoreRateLimit/CounterKeyBuilders/IpCounterKeyBuilder.cs create mode 100644 src/AspNetCoreRateLimit/CounterKeyBuilders/PathCounterKeyBuilder.cs diff --git a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs index 8fb17a49..503ffa4b 100644 --- a/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs @@ -13,8 +13,9 @@ public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor public ClientRateLimitProcessor( ClientRateLimitOptions options, IRateLimitCounterStore counterStore, - IClientPolicyStore policyStore) - : base(options, counterStore) + IClientPolicyStore policyStore, + IRateLimitConfiguration config) + : base(options, counterStore, new ClientCounterKeyBuilder(options), config) { _options = options; _policyStore = policyStore; @@ -31,10 +32,5 @@ public async Task> GetMatchingRulesAsync(ClientReques return Enumerable.Empty(); } - - protected override string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientId}_{rule.Period}"; - } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs index a22294a4..19990a82 100644 --- a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs @@ -13,8 +13,9 @@ public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor public IpRateLimitProcessor( IpRateLimitOptions options, IRateLimitCounterStore counterStore, - IIpPolicyStore policyStore) - : base(options, counterStore) + IIpPolicyStore policyStore, + IRateLimitConfiguration config) + : base(options, counterStore, new IpCounterKeyBuilder(options), config) { _options = options; _policyStore = policyStore; @@ -51,10 +52,5 @@ public override bool IsWhitelisted(ClientRequestIdentity requestIdentity) return base.IsWhitelisted(requestIdentity); } - - protected override string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) - { - return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientIp}_{rule.Period}"; - } } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 299b9bdb..072f5971 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -11,13 +11,19 @@ public abstract class RateLimitProcessor { private readonly RateLimitOptions _options; private readonly IRateLimitCounterStore _counterStore; + private readonly ICounterKeyBuilder _counterKeyBuilder; + private readonly IRateLimitConfiguration _config; protected RateLimitProcessor( RateLimitOptions options, - IRateLimitCounterStore counterStore) + IRateLimitCounterStore counterStore, + ICounterKeyBuilder counterKeyBuilder, + IRateLimitConfiguration config) { _options = options; _counterStore = counterStore; + _counterKeyBuilder = counterKeyBuilder; + _config = config; } private static readonly SemaphoreSlim Semaphore = new SemaphoreSlim(1); @@ -47,7 +53,7 @@ public async Task ProcessRequestAsync(ClientRequestIdentity re TotalRequests = 1 }; - var counterId = ComputeCounterKey(requestIdentity, rule); + var counterId = BuildCounterKey(requestIdentity, rule); // serial reads and writes await Semaphore.WaitAsync(cancellationToken); @@ -87,7 +93,7 @@ public async Task ProcessRequestAsync(ClientRequestIdentity re public async Task GetRateLimitHeadersAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) { var headers = new RateLimitHeaders(); - var counterId = ComputeCounterKey(requestIdentity, rule); + var counterId = BuildCounterKey(requestIdentity, rule); var entry = await _counterStore.GetAsync(counterId, cancellationToken); long remaining; @@ -111,18 +117,13 @@ public async Task GetRateLimitHeadersAsync(ClientRequestIdenti return headers; } - protected abstract string GetCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule); - - protected string ComputeCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) + protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule) { - var key = GetCounterKey(requestIdentity, rule); + var key = _counterKeyBuilder.Build(requestIdentity, rule); - if (_options.EnableEndpointRateLimiting) + if (_options.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null) { - key += $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}"; - - // TODO: consider using the rule endpoint as key, this will allow to rate limit /api/values/1 and api/values/2 under same counter - //key += $"_{rule.Endpoint}"; + key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule); } var idBytes = System.Text.Encoding.UTF8.GetBytes(key); @@ -137,7 +138,7 @@ protected string ComputeCounterKey(ClientRequestIdentity requestIdentity, RateLi return BitConverter.ToString(hashBytes).Replace("-", string.Empty); } - protected List GetMatchingRules(ClientRequestIdentity identity, List rules) + protected virtual List GetMatchingRules(ClientRequestIdentity identity, List rules) { var limits = new List(); diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/ClientCounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/ClientCounterKeyBuilder.cs new file mode 100644 index 00000000..32aca93f --- /dev/null +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/ClientCounterKeyBuilder.cs @@ -0,0 +1,17 @@ +namespace AspNetCoreRateLimit +{ + public class ClientCounterKeyBuilder : ICounterKeyBuilder + { + private readonly ClientRateLimitOptions _options; + + public ClientCounterKeyBuilder(ClientRateLimitOptions options) + { + _options = options; + } + + public string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientId}_{rule.Period}"; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs new file mode 100644 index 00000000..5a25b521 --- /dev/null +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs @@ -0,0 +1,11 @@ +namespace AspNetCoreRateLimit +{ + public class EndpointCounterKeyBuilder : ICounterKeyBuilder + { + public string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + // This will allow to rate limit /api/values/1 and api/values/2 under same counter + return $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}_{rule.Endpoint}"; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/ICounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/ICounterKeyBuilder.cs new file mode 100644 index 00000000..136f301e --- /dev/null +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/ICounterKeyBuilder.cs @@ -0,0 +1,7 @@ +namespace AspNetCoreRateLimit +{ + public interface ICounterKeyBuilder + { + string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule); + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/IpCounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/IpCounterKeyBuilder.cs new file mode 100644 index 00000000..6c610d06 --- /dev/null +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/IpCounterKeyBuilder.cs @@ -0,0 +1,17 @@ +namespace AspNetCoreRateLimit +{ + public class IpCounterKeyBuilder : ICounterKeyBuilder + { + private readonly IpRateLimitOptions _options; + + public IpCounterKeyBuilder(IpRateLimitOptions options) + { + _options = options; + } + + public string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + return $"{_options.RateLimitCounterPrefix}_{requestIdentity.ClientIp}_{rule.Period}"; + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/PathCounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/PathCounterKeyBuilder.cs new file mode 100644 index 00000000..7c194cdd --- /dev/null +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/PathCounterKeyBuilder.cs @@ -0,0 +1,10 @@ +namespace AspNetCoreRateLimit +{ + public class PathCounterKeyBuilder : ICounterKeyBuilder + { + public string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule) + { + return $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}"; + } + } +} diff --git a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs index 87ba0180..aa3385b9 100644 --- a/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/ClientRateLimitMiddleware.cs @@ -14,7 +14,7 @@ public ClientRateLimitMiddleware(RequestDelegate next, IClientPolicyStore policyStore, IRateLimitConfiguration config, ILogger logger) - : base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore), config) + : base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore, config), config) { _logger = logger; } diff --git a/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs b/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs index 8f72851e..0de75dce 100644 --- a/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs +++ b/src/AspNetCoreRateLimit/Middleware/IRateLimitConfiguration.cs @@ -4,10 +4,10 @@ namespace AspNetCoreRateLimit { public interface IRateLimitConfiguration { - //bool Enabled { get; set; } - IList ClientResolvers { get; } IList IpResolvers { get; } + + ICounterKeyBuilder EndpointCounterKeyBuilder { get; } } -} +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs index bad8f63e..c2c95ae7 100644 --- a/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs +++ b/src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs @@ -14,7 +14,7 @@ public IpRateLimitMiddleware(RequestDelegate next, IIpPolicyStore policyStore, IRateLimitConfiguration config, ILogger logger) - : base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore), config) + : base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore, config), config) { _logger = logger; diff --git a/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs b/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs index 4e9b554c..ecb98dc8 100644 --- a/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs +++ b/src/AspNetCoreRateLimit/Middleware/RateLimitConfiguration.cs @@ -6,33 +6,44 @@ namespace AspNetCoreRateLimit { public class RateLimitConfiguration : IRateLimitConfiguration { - //public bool Enabled { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public IList ClientResolvers { get; } = new List(); + public IList IpResolvers { get; } = new List(); - public IList ClientResolvers { get; } - - public IList IpResolvers { get; } + public virtual ICounterKeyBuilder EndpointCounterKeyBuilder { get; } = new PathCounterKeyBuilder(); public RateLimitConfiguration( IHttpContextAccessor httpContextAccessor, IOptions ipOptions, IOptions clientOptions) { + IpRateLimitOptions = ipOptions?.Value; + ClientRateLimitOptions = clientOptions?.Value; + HttpContextAccessor = httpContextAccessor; + ClientResolvers = new List(); + IpResolvers = new List(); + + RegisterResolvers(); + } - if (!string.IsNullOrEmpty(clientOptions?.Value.ClientIdHeader)) + protected readonly IpRateLimitOptions IpRateLimitOptions; + protected readonly ClientRateLimitOptions ClientRateLimitOptions; + protected readonly IHttpContextAccessor HttpContextAccessor; + + protected virtual void RegisterResolvers() + { + if (!string.IsNullOrEmpty(ClientRateLimitOptions?.ClientIdHeader)) { - ClientResolvers.Add(new ClientHeaderResolveContributor(httpContextAccessor, clientOptions.Value.ClientIdHeader)); + ClientResolvers.Add(new ClientHeaderResolveContributor(HttpContextAccessor, ClientRateLimitOptions.ClientIdHeader)); } - IpResolvers = new List(); - // the contributors are resolved in the order of their collection index - if (!string.IsNullOrEmpty(ipOptions?.Value.RealIpHeader)) + if (!string.IsNullOrEmpty(IpRateLimitOptions?.RealIpHeader)) { - IpResolvers.Add(new IpHeaderResolveContributor(httpContextAccessor, ipOptions.Value.RealIpHeader)); + IpResolvers.Add(new IpHeaderResolveContributor(HttpContextAccessor, IpRateLimitOptions.RealIpHeader)); } - IpResolvers.Add(new IpConnectionResolveContributor(httpContextAccessor)); + IpResolvers.Add(new IpConnectionResolveContributor(HttpContextAccessor)); } } } \ No newline at end of file From 63d44aa7611900d88528c874218c980082b649e8 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 12:46:52 +0200 Subject: [PATCH 06/13] Set correct version, change authors --- src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj index d1aa1b36..c9161a0d 100644 --- a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj +++ b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj @@ -3,8 +3,7 @@ netstandard2.0 ASP.NET Core rate limiting middleware - 1.0.6 - Stefan Prodan + Stefan Prodan, Cristi Pufu AspNetCoreRateLimit AspNetCoreRateLimit aspnetcore;rate-limit;throttle @@ -13,6 +12,7 @@ git https://github.com/stefanprodan/AspNetCoreRateLimit 7.1 + 3.0.0 From 1642cdbd38689022a1c128602b4b732a2a6b5e03 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 13:58:39 +0200 Subject: [PATCH 07/13] Extensions cosmetic --- src/AspNetCoreRateLimit/Core/Extensions.cs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/AspNetCoreRateLimit/Core/Extensions.cs b/src/AspNetCoreRateLimit/Core/Extensions.cs index afe6e4b1..1443ee21 100644 --- a/src/AspNetCoreRateLimit/Core/Extensions.cs +++ b/src/AspNetCoreRateLimit/Core/Extensions.cs @@ -1,5 +1,4 @@ using System; -using System.Globalization; namespace AspNetCoreRateLimit { @@ -12,12 +11,10 @@ public static bool IsWildcardMatch(this string source, string value) public static string RetryAfterFrom(this DateTime timestamp, RateLimitRule rule) { - var secondsPast = Convert.ToInt32((DateTime.UtcNow - timestamp).TotalSeconds); - var retryAfter = Convert.ToInt32(rule.PeriodTimespan.Value.TotalSeconds); + var diff = timestamp + rule.PeriodTimespan.Value - DateTime.UtcNow; + var seconds = Math.Max(diff.TotalSeconds, 1); - retryAfter = retryAfter > 1 ? retryAfter - secondsPast : 1; - - return retryAfter.ToString(CultureInfo.InvariantCulture); + return $"{seconds:F0}"; } public static TimeSpan ToTimeSpan(this string timeSpan) From 0118aa7ecc0dfeef46d3f6a1d3de4be02b1c83e8 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 14:10:06 +0200 Subject: [PATCH 08/13] Memory cache options - NeverRemove, cosmetics --- src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs | 2 +- src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs | 12 ++++++------ .../Store/MemoryCacheRateLimitStore.cs | 12 +++++++++++- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs index 19990a82..657c958e 100644 --- a/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs @@ -29,7 +29,7 @@ public async Task> GetMatchingRulesAsync(ClientReques if (policies != null && policies.IpRules != null && policies.IpRules.Any()) { // search for rules with IP intervals containing client IP - var matchPolicies = policies.IpRules.Where(r => IpParser.ContainsIp(r.Ip, identity.ClientIp)).AsEnumerable(); + var matchPolicies = policies.IpRules.Where(r => IpParser.ContainsIp(r.Ip, identity.ClientIp)); var rules = new List(); foreach (var item in matchPolicies) diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 072f5971..88e3503d 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -145,17 +145,17 @@ protected virtual List GetMatchingRules(ClientRequestIdentity ide if (_options.EnableEndpointRateLimiting) { // search for rules with endpoints like "*" and "*:/matching_path" - var pathLimits = rules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); + var pathLimits = rules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)); limits.AddRange(pathLimits); // search for rules with endpoints like "matching_verb:/matching_path" - var verbLimits = rules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); + var verbLimits = rules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)); limits.AddRange(verbLimits); } else { //ignore endpoint rules and search for global rules only - var genericLimits = rules.Where(r => r.Endpoint == "*").AsEnumerable(); + var genericLimits = rules.Where(r => r.Endpoint == "*"); limits.AddRange(genericLimits); } @@ -170,17 +170,17 @@ protected virtual List GetMatchingRules(ClientRequestIdentity ide if (_options.EnableEndpointRateLimiting) { // search for rules with endpoints like "*" and "*:/matching_path" in general rules - var pathLimits = _options.GeneralRules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); + var pathLimits = _options.GeneralRules.Where(r => $"*:{identity.Path}".IsWildcardMatch(r.Endpoint)); matchingGeneralLimits.AddRange(pathLimits); // search for rules with endpoints like "matching_verb:/matching_path" in general rules - var verbLimits = _options.GeneralRules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)).AsEnumerable(); + var verbLimits = _options.GeneralRules.Where(r => $"{identity.HttpVerb}:{identity.Path}".IsWildcardMatch(r.Endpoint)); matchingGeneralLimits.AddRange(verbLimits); } else { //ignore endpoint rules and search for global rules in general rules - var genericLimits = _options.GeneralRules.Where(r => r.Endpoint == "*").AsEnumerable(); + var genericLimits = _options.GeneralRules.Where(r => r.Endpoint == "*"); matchingGeneralLimits.AddRange(genericLimits); } diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs index 3b812f15..619c14db 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheRateLimitStore.cs @@ -38,7 +38,17 @@ public Task RemoveAsync(string id, CancellationToken cancellationToken = default public Task SetAsync(string id, T entry, TimeSpan? expirationTime = null, CancellationToken cancellationToken = default) { - _cache.Set(id, entry, expirationTime.HasValue ? new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime.Value) : null); + MemoryCacheEntryOptions options = new MemoryCacheEntryOptions + { + Priority = CacheItemPriority.NeverRemove + }; + + if (expirationTime.HasValue) + { + options.SetAbsoluteExpiration(expirationTime.Value); + } + + _cache.Set(id, entry, options); return Task.CompletedTask; } From 8ab12be5bdb84197f4db9a8561382d12cab004a4 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 14:20:00 +0200 Subject: [PATCH 09/13] Consistent order for limits --- src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 88e3503d..820d01cc 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -185,7 +185,11 @@ protected virtual List GetMatchingRules(ClientRequestIdentity ide } // get the most restrictive general limit for each period - var generalLimits = matchingGeneralLimits.GroupBy(l => l.Period).Select(l => l.OrderBy(x => x.Limit)).Select(l => l.First()).ToList(); + var generalLimits = matchingGeneralLimits + .GroupBy(l => l.Period) + .Select(l => l.OrderBy(x => x.Limit).ThenBy(x => x.Endpoint)) + .Select(l => l.First()) + .ToList(); foreach (var generalLimit in generalLimits) { From 4d61dd86332ef0b6addadabc0b979fc0534faa8f Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 14:22:50 +0200 Subject: [PATCH 10/13] Empty lists for policies & rules --- src/AspNetCoreRateLimit/Models/IpRateLimitPolicies.cs | 2 +- src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AspNetCoreRateLimit/Models/IpRateLimitPolicies.cs b/src/AspNetCoreRateLimit/Models/IpRateLimitPolicies.cs index 044907d1..20ad6713 100644 --- a/src/AspNetCoreRateLimit/Models/IpRateLimitPolicies.cs +++ b/src/AspNetCoreRateLimit/Models/IpRateLimitPolicies.cs @@ -4,6 +4,6 @@ namespace AspNetCoreRateLimit { public class IpRateLimitPolicies { - public List IpRules { get; set; } + public List IpRules { get; set; } = new List(); } } \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs b/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs index ba1d40dd..7403efdc 100644 --- a/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs +++ b/src/AspNetCoreRateLimit/Models/RateLimitPolicy.cs @@ -4,6 +4,6 @@ namespace AspNetCoreRateLimit { public class RateLimitPolicy { - public List Rules { get; set; } + public List Rules { get; set; } = new List(); } } From 2bf430f0b5650b4ca3d50fe7990e345c20258d8b Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 14:48:02 +0200 Subject: [PATCH 11/13] Minor fix --- src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 820d01cc..695d536a 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -101,12 +101,12 @@ public async Task GetRateLimitHeadersAsync(ClientRequestIdenti if (entry.HasValue) { - reset = entry.Value.Timestamp + rule.Period.ToTimeSpan(); + reset = entry.Value.Timestamp + (rule.PeriodTimespan ?? rule.Period.ToTimeSpan()); remaining = rule.Limit - entry.Value.TotalRequests; } else { - reset = DateTime.UtcNow + rule.Period.ToTimeSpan(); + reset = DateTime.UtcNow + (rule.PeriodTimespan ?? rule.Period.ToTimeSpan()); remaining = rule.Limit; } @@ -203,7 +203,7 @@ protected virtual List GetMatchingRules(ClientRequestIdentity ide foreach (var item in limits) { - //parse period text into time spans + // parse period text into time spans item.PeriodTimespan = item.Period.ToTimeSpan(); } From 0a54eacfde51bd7088a3d94b9b56c9c401b70ca4 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Sun, 17 Feb 2019 15:38:31 +0200 Subject: [PATCH 12/13] Implement named locker by counter Id #12 --- .../AspNetCoreRateLimit.csproj | 2 +- .../AsyncKeyLock/AsyncKeyLock.cs | 133 ++++++++++++++ .../AsyncKeyLock/AsyncKeyLockDoorman.cs | 171 ++++++++++++++++++ .../Core/RateLimitProcessor.cs | 14 +- .../DistributedCacheClientPolicyStore.cs | 2 +- .../Store/DistributedCacheIpPolicyStore.cs | 2 +- .../Store/MemoryCacheClientPolicyStore.cs | 2 +- .../Store/MemoryCacheIpPolicyStore.cs | 2 +- .../AspNetCoreRateLimit.Demo.csproj | 3 +- test/AspNetCoreRateLimit.Demo/Startup.cs | 3 + .../AspNetCoreRateLimit.Tests.csproj | 2 +- 11 files changed, 320 insertions(+), 16 deletions(-) create mode 100644 src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLock.cs create mode 100644 src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLockDoorman.cs diff --git a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj index c9161a0d..882dedae 100644 --- a/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj +++ b/src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj @@ -11,7 +11,7 @@ http://opensource.org/licenses/MIT git https://github.com/stefanprodan/AspNetCoreRateLimit - 7.1 + 7.3 3.0.0 diff --git a/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLock.cs b/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLock.cs new file mode 100644 index 00000000..fb9eb367 --- /dev/null +++ b/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLock.cs @@ -0,0 +1,133 @@ +// Copyright (c) Six Labors and contributors. +// Licensed under the Apache License, Version 2.0. +// Thanks to https://github.com/SixLabors/ImageSharp.Web/ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + /// + /// The async key lock prevents multiple asynchronous threads acting upon the same object with the given key at the same time. + /// It is designed so that it does not block unique requests allowing a high throughput. + /// + internal sealed class AsyncKeyLock + { + /// + /// A collection of doorman counters used for tracking references to the same key. + /// + private static readonly Dictionary Keys = new Dictionary(); + + /// + /// A pool of unused doorman counters that can be re-used to avoid allocations. + /// + private static readonly Stack Pool = new Stack(MaxPoolSize); + + /// + /// Maximum size of the doorman pool. If the pool is already full when releasing + /// a doorman, it is simply left for garbage collection. + /// + private const int MaxPoolSize = 20; + + /// + /// SpinLock used to protect access to the Keys and Pool collections. + /// + private static SpinLock _spinLock = new SpinLock(false); + + /// + /// Locks the current thread in read mode asynchronously. + /// + /// The key identifying the specific object to lock against. + /// + /// The that will release the lock. + /// + public async Task ReaderLockAsync(string key) + { + AsyncKeyLockDoorman doorman = GetDoorman(key); + + return await doorman.ReaderLockAsync().ConfigureAwait(false); + } + + /// + /// Locks the current thread in write mode asynchronously. + /// + /// The key identifying the specific object to lock against. + /// + /// The that will release the lock. + /// + public async Task WriterLockAsync(string key) + { + AsyncKeyLockDoorman doorman = GetDoorman(key); + + return await doorman.WriterLockAsync().ConfigureAwait(false); + } + + /// + /// Gets the doorman for the specified key. If no such doorman exists, an unused doorman + /// is obtained from the pool (or a new one is allocated if the pool is empty), and it's + /// assigned to the requested key. + /// + /// The key for the desired doorman. + /// The . + private static AsyncKeyLockDoorman GetDoorman(string key) + { + AsyncKeyLockDoorman doorman; + bool lockTaken = false; + try + { + _spinLock.Enter(ref lockTaken); + + if (!Keys.TryGetValue(key, out doorman)) + { + doorman = (Pool.Count > 0) ? Pool.Pop() : new AsyncKeyLockDoorman(ReleaseDoorman); + doorman.Key = key; + Keys.Add(key, doorman); + } + + doorman.RefCount++; + } + finally + { + if (lockTaken) + { + _spinLock.Exit(); + } + } + + return doorman; + } + + /// + /// Releases a reference to a doorman. If the ref-count hits zero, then the doorman is + /// returned to the pool (or is simply left for the garbage collector to cleanup if the + /// pool is already full). + /// + /// The . + private static void ReleaseDoorman(AsyncKeyLockDoorman doorman) + { + bool lockTaken = false; + try + { + _spinLock.Enter(ref lockTaken); + + if (--doorman.RefCount == 0) + { + Keys.Remove(doorman.Key); + if (Pool.Count < MaxPoolSize) + { + doorman.Key = null; + Pool.Push(doorman); + } + } + } + finally + { + if (lockTaken) + { + _spinLock.Exit(); + } + } + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLockDoorman.cs b/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLockDoorman.cs new file mode 100644 index 00000000..84a8040e --- /dev/null +++ b/src/AspNetCoreRateLimit/AsyncKeyLock/AsyncKeyLockDoorman.cs @@ -0,0 +1,171 @@ +// Copyright (c) Six Labors and contributors. +// Licensed under the Apache License, Version 2.0. +// Thanks to https://github.com/SixLabors/ImageSharp.Web/ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace AspNetCoreRateLimit +{ + /// + /// An asynchronous locker that provides read and write locking policies. + /// + internal sealed class AsyncKeyLockDoorman + { + private readonly Queue> _waitingWriters; + private readonly Task _readerReleaser; + private readonly Task _writerReleaser; + private readonly Action _reset; + private TaskCompletionSource _waitingReader; + private int _readersWaiting; + private int _status; + + /// + /// Initializes a new instance of the class. + /// + /// The reset action. + public AsyncKeyLockDoorman(Action reset) + { + _waitingWriters = new Queue>(); + _waitingReader = new TaskCompletionSource(); + _status = 0; + + _readerReleaser = Task.FromResult(new Releaser(this, false)); + _writerReleaser = Task.FromResult(new Releaser(this, true)); + _reset = reset; + } + + /// + /// Gets or sets the key that this doorman is mapped to. + /// + public string Key { get; set; } + + /// + /// Gets or sets the current reference count on this doorman. + /// + public int RefCount { get; set; } + + /// + /// Locks the current thread in read mode asynchronously. + /// + /// The . + public Task ReaderLockAsync() + { + lock (_waitingWriters) + { + if (_status >= 0 && _waitingWriters.Count == 0) + { + ++_status; + return _readerReleaser; + } + else + { + ++_readersWaiting; + return _waitingReader.Task.ContinueWith(t => t.Result); + } + } + } + + /// + /// Locks the current thread in write mode asynchronously. + /// + /// The . + public Task WriterLockAsync() + { + lock (_waitingWriters) + { + if (_status == 0) + { + _status = -1; + return _writerReleaser; + } + else + { + var waiter = new TaskCompletionSource(); + _waitingWriters.Enqueue(waiter); + return waiter.Task; + } + } + } + + private void ReaderRelease() + { + TaskCompletionSource toWake = null; + + lock (_waitingWriters) + { + --_status; + + if (_status == 0) + { + if (_waitingWriters.Count > 0) + { + _status = -1; + toWake = _waitingWriters.Dequeue(); + } + } + } + + _reset(this); + + toWake?.SetResult(new Releaser(this, true)); + } + + private void WriterRelease() + { + TaskCompletionSource toWake = null; + bool toWakeIsWriter = false; + + lock (_waitingWriters) + { + if (_waitingWriters.Count > 0) + { + toWake = _waitingWriters.Dequeue(); + toWakeIsWriter = true; + } + else if (_readersWaiting > 0) + { + toWake = _waitingReader; + _status = _readersWaiting; + _readersWaiting = 0; + _waitingReader = new TaskCompletionSource(); + } + else + { + _status = 0; + } + } + + _reset(this); + + toWake?.SetResult(new Releaser(this, toWakeIsWriter)); + } + + public readonly struct Releaser : IDisposable + { + private readonly AsyncKeyLockDoorman toRelease; + private readonly bool writer; + + internal Releaser(AsyncKeyLockDoorman toRelease, bool writer) + { + this.toRelease = toRelease; + this.writer = writer; + } + + public void Dispose() + { + if (toRelease != null) + { + if (writer) + { + toRelease.WriterRelease(); + } + else + { + toRelease.ReaderRelease(); + } + } + } + } + } +} \ No newline at end of file diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 695d536a..453e63fe 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -26,7 +26,9 @@ protected RateLimitProcessor( _config = config; } - private static readonly SemaphoreSlim Semaphore = new SemaphoreSlim(1); + /// The key-lock used for limiting requests. + /// + private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock(); public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) { @@ -55,10 +57,8 @@ public async Task ProcessRequestAsync(ClientRequestIdentity re var counterId = BuildCounterKey(requestIdentity, rule); - // serial reads and writes - await Semaphore.WaitAsync(cancellationToken); - - try + // serial reads and writes on same key + using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false)) { var entry = await _counterStore.GetAsync(counterId, cancellationToken); @@ -82,10 +82,6 @@ public async Task ProcessRequestAsync(ClientRequestIdentity re // stores: id (string) - timestamp (datetime) - total_requests (long) await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken); } - finally - { - Semaphore.Release(); - } return counter; } diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs index a5a21074..f8ff4790 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheClientPolicyStore.cs @@ -25,7 +25,7 @@ public async Task SeedAsync() { foreach (var rule in _policies.ClientRules) { - await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }).ConfigureAwait(false); } } } diff --git a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs index 51094fe7..565a7d72 100644 --- a/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/DistributedCacheIpPolicyStore.cs @@ -23,7 +23,7 @@ public async Task SeedAsync() // on startup, save the IP rules defined in appsettings if (_options != null && _policies != null) { - await SetAsync($"{_options.IpPolicyPrefix}", _policies); + await SetAsync($"{_options.IpPolicyPrefix}", _policies).ConfigureAwait(false); } } } diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs index da20685d..37f87284 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheClientPolicyStore.cs @@ -25,7 +25,7 @@ public async Task SeedAsync() { foreach (var rule in _policies.ClientRules) { - await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }); + await SetAsync($"{_options.ClientPolicyPrefix}_{rule.ClientId}", new ClientRateLimitPolicy { ClientId = rule.ClientId, Rules = rule.Rules }).ConfigureAwait(false); } } } diff --git a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs index da12bafe..69972f62 100644 --- a/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs +++ b/src/AspNetCoreRateLimit/Store/MemoryCacheIpPolicyStore.cs @@ -23,7 +23,7 @@ public async Task SeedAsync() // on startup, save the IP rules defined in appsettings if (_options != null && _policies != null) { - await SetAsync($"{_options.IpPolicyPrefix}", _policies); + await SetAsync($"{_options.IpPolicyPrefix}", _policies).ConfigureAwait(false); } } } diff --git a/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj b/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj index b445d1d4..9e922fdc 100644 --- a/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj +++ b/test/AspNetCoreRateLimit.Demo/AspNetCoreRateLimit.Demo.csproj @@ -4,10 +4,11 @@ netcoreapp3.0 true true - 7.1 + 7.3 + diff --git a/test/AspNetCoreRateLimit.Demo/Startup.cs b/test/AspNetCoreRateLimit.Demo/Startup.cs index 87a141c9..1ba70f8b 100644 --- a/test/AspNetCoreRateLimit.Demo/Startup.cs +++ b/test/AspNetCoreRateLimit.Demo/Startup.cs @@ -1,3 +1,4 @@ +using Ben.Diagnostics; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -53,6 +54,8 @@ public void ConfigureServices(IServiceCollection services) // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. public void Configure(IApplicationBuilder app, IHostingEnvironment env) { + app.UseBlockingDetection(); + app.UseIpRateLimiting(); app.UseClientRateLimiting(); diff --git a/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj b/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj index d590625b..a795ed00 100644 --- a/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj +++ b/test/AspNetCoreRateLimit.Tests/AspNetCoreRateLimit.Tests.csproj @@ -5,7 +5,7 @@ netcoreapp3.0 true true - 7.1 + 7.3 From 40e0d0d99b1e87ce8e605844608d5acb20e02ce4 Mon Sep 17 00:00:00 2001 From: Cristi Pufu Date: Thu, 21 Feb 2019 20:40:29 +0200 Subject: [PATCH 13/13] Fix endpoint counter key builder --- src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs | 4 ++-- .../CounterKeyBuilders/EndpointCounterKeyBuilder.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs index 453e63fe..cb778447 100644 --- a/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs +++ b/src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs @@ -47,7 +47,7 @@ public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity) return false; } - public async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) + public virtual async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) { var counter = new RateLimitCounter { @@ -86,7 +86,7 @@ public async Task ProcessRequestAsync(ClientRequestIdentity re return counter; } - public async Task GetRateLimitHeadersAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) + public virtual async Task GetRateLimitHeadersAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default) { var headers = new RateLimitHeaders(); var counterId = BuildCounterKey(requestIdentity, rule); diff --git a/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs b/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs index 5a25b521..895d7209 100644 --- a/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs +++ b/src/AspNetCoreRateLimit/CounterKeyBuilders/EndpointCounterKeyBuilder.cs @@ -5,7 +5,7 @@ public class EndpointCounterKeyBuilder : ICounterKeyBuilder public string Build(ClientRequestIdentity requestIdentity, RateLimitRule rule) { // This will allow to rate limit /api/values/1 and api/values/2 under same counter - return $"_{requestIdentity.HttpVerb}_{requestIdentity.Path}_{rule.Endpoint}"; + return $"_{rule.Endpoint}"; } } } \ No newline at end of file