diff --git a/api/envoy/config/route/v3/route_components.proto b/api/envoy/config/route/v3/route_components.proto index 53b351b8d3aa..9538a417eb7c 100644 --- a/api/envoy/config/route/v3/route_components.proto +++ b/api/envoy/config/route/v3/route_components.proto @@ -1541,6 +1541,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; diff --git a/api/envoy/config/route/v4alpha/route_components.proto b/api/envoy/config/route/v4alpha/route_components.proto index 577282595d84..319c65c6e2ff 100644 --- a/api/envoy/config/route/v4alpha/route_components.proto +++ b/api/envoy/config/route/v4alpha/route_components.proto @@ -1490,6 +1490,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.config.route.v3.RateLimit"; diff --git a/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto b/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto index 30efa6026218..6bb771d25af9 100644 --- a/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto +++ b/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.common.ratelimit.v3; import "envoy/type/v3/ratelimit_unit.proto"; +import "envoy/type/v3/token_bucket.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -92,3 +93,11 @@ message RateLimitDescriptor { // Optional rate limit override to supply to the ratelimit service. RateLimitOverride limit = 2; } + +message LocalRateLimitDescriptor { + // Descriptor entries. + repeated v3.RateLimitDescriptor.Entry entries = 1 [(validate.rules).repeated = {min_items: 1}]; + + // Token Bucket algorithm for local ratelimiting. + type.v3.TokenBucket token_bucket = 2 [(validate.rules).message = {required: true}]; +} diff --git a/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD b/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD index ad2fc9a9a84f..6c58a43e4ff6 100644 --- a/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD +++ b/api/envoy/extensions/filters/http/local_ratelimit/v3/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ "//envoy/config/core/v3:pkg", + "//envoy/extensions/common/ratelimit/v3:pkg", "//envoy/type/v3:pkg", "@com_github_cncf_udpa//udpa/annotations:pkg", ], diff --git a/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto b/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto index 94f21edd3eed..546ff26bac79 100644 --- a/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +++ b/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.filters.http.local_ratelimit.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/extensions/common/ratelimit/v3/ratelimit.proto"; import "envoy/type/v3/http_status.proto"; import "envoy/type/v3/token_bucket.proto"; @@ -19,7 +20,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Local Rate limit :ref:`configuration overview `. // [#extension: envoy.filters.http.local_ratelimit] -// [#next-free-field: 7] +// [#next-free-field: 10] message LocalRateLimit { // The human readable prefix to use when emitting stats. string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; @@ -67,4 +68,28 @@ message LocalRateLimit { // have been rate limited. repeated config.core.v3.HeaderValueOption response_headers_to_add = 6 [(validate.rules).repeated = {max_items: 10}]; + + // The rate limit descriptor list to use in the local rate limit to override + // on. The rate limit descriptor is selected by the first full match from the + // request descriptors. + // + // Example on how to use ::ref:`this ` + // + // .. note:: + // + // In the current implementation the descriptor's token bucket :ref:`fill_interval + // ` must be a multiple + // global :ref:`token bucket's` fill interval. + // + // The descriptors must match verbatim for rate limiting to apply. There is no partial + // match by a subset of descriptor entries in the current implementation. + repeated common.ratelimit.v3.LocalRateLimitDescriptor descriptors = 8; + + // Specifies the rate limit configurations to be applied with the same + // stage number. If not set, the default stage number is 0. + // + // .. note:: + // + // The filter supports a range of 0 - 10 inclusively for stage numbers. + uint32 stage = 9 [(validate.rules).uint32 = {lte: 10}]; } diff --git a/docs/root/api-v3/config/common/common.rst b/docs/root/api-v3/config/common/common.rst index bb6965a5f149..f286ba06c4e9 100644 --- a/docs/root/api-v3/config/common/common.rst +++ b/docs/root/api-v3/config/common/common.rst @@ -8,3 +8,4 @@ Common matcher/v3/* ../../extensions/common/dynamic_forward_proxy/v3/* ../../extensions/common/tap/v3/* + ../../extensions/common/ratelimit/v3/* diff --git a/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst b/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst index 78bbc806a78e..3903eefe8b33 100644 --- a/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst +++ b/docs/root/configuration/http/http_filters/local_rate_limit_filter.rst @@ -103,6 +103,93 @@ The route specific configuration: Note that if this filter is configured as globally disabled and there are no virtual host or route level token buckets, no rate limiting will be applied. +.. _config_http_filters_local_rate_limit_descriptors: + +Using rate limit descriptors for local rate limiting +---------------------------------------------------- + +Rate limit descriptors can be used to override local per-route rate limiting. +A route's :ref:`rate limit action ` +is used to match up a :ref:`local descriptor +` in +the filter config descriptor list. The local descriptor's token bucket +settings are then used to decide if the request should be rate limited or not +depending on whether the local descriptor's entries match the route's rate +limit actions descriptor entries. If there is no matching descriptor entries, +the default token bucket is used. + +Example filter configuration using descriptors: + +.. validated-code-block:: yaml + :type-name: envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + + route_config: + name: local_route + virtual_hosts: + - name: local_service + domains: ["*"] + routes: + - match: { prefix: "/foo" } + route: { cluster: service_protected_by_rate_limit } + typed_per_filter_config: + envoy.filters.http.local_ratelimit: + "@type": type.googleapis.com/envoy.extensions.filters.http.local_ratelimit.v3.LocalRateLimit + stat_prefix: test + token_bucket: + max_tokens: 1000 + tokens_per_fill: 1000 + fill_interval: 60s + filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED + filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED + response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' + descriptors: + - entries: + - key: client_cluster + value: foo + - key: path + value: /foo/bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s + - entries: + - key: client_cluster + value: foo + - key: path + value: /foo/bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 60s + - match: { prefix: "/" } + route: { cluster: default_service } + rate_limits: + - actions: # any actions in here + - request_headers: + header_name: x-envoy-downstream-service-cluster + descriptor_key: client_cluster + - request_headers: + header_name: ":path" + descriptor_key: path + +In this example, requests are rate-limited for routes prefixed with "/foo" as +follow. If requests come from a downstream service cluster "foo" for "/foo/bar" +path, then 10 req/min are allowed. But if they come from a downstream service +cluster "foo" for "/foo/bar2" path, then 100 req/min are allowed. Otherwise, +1000 req/min are allowed. + Statistics ---------- diff --git a/generated_api_shadow/envoy/config/route/v3/route_components.proto b/generated_api_shadow/envoy/config/route/v3/route_components.proto index bfb296e47836..afdd534094e3 100644 --- a/generated_api_shadow/envoy/config/route/v3/route_components.proto +++ b/generated_api_shadow/envoy/config/route/v3/route_components.proto @@ -1553,6 +1553,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; diff --git a/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto b/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto index 586527865d5b..5b904936572a 100644 --- a/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto +++ b/generated_api_shadow/envoy/config/route/v4alpha/route_components.proto @@ -1557,6 +1557,7 @@ message VirtualCluster { } // Global rate limiting :ref:`architecture overview `. +// Also applies to Local rate limiting :ref:`using descriptors `. message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.config.route.v3.RateLimit"; diff --git a/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto b/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto index 30efa6026218..6bb771d25af9 100644 --- a/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto +++ b/generated_api_shadow/envoy/extensions/common/ratelimit/v3/ratelimit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.common.ratelimit.v3; import "envoy/type/v3/ratelimit_unit.proto"; +import "envoy/type/v3/token_bucket.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -92,3 +93,11 @@ message RateLimitDescriptor { // Optional rate limit override to supply to the ratelimit service. RateLimitOverride limit = 2; } + +message LocalRateLimitDescriptor { + // Descriptor entries. + repeated v3.RateLimitDescriptor.Entry entries = 1 [(validate.rules).repeated = {min_items: 1}]; + + // Token Bucket algorithm for local ratelimiting. + type.v3.TokenBucket token_bucket = 2 [(validate.rules).message = {required: true}]; +} diff --git a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD index ad2fc9a9a84f..6c58a43e4ff6 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD +++ b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ "//envoy/config/core/v3:pkg", + "//envoy/extensions/common/ratelimit/v3:pkg", "//envoy/type/v3:pkg", "@com_github_cncf_udpa//udpa/annotations:pkg", ], diff --git a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto index 94f21edd3eed..546ff26bac79 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +++ b/generated_api_shadow/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.extensions.filters.http.local_ratelimit.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/extensions/common/ratelimit/v3/ratelimit.proto"; import "envoy/type/v3/http_status.proto"; import "envoy/type/v3/token_bucket.proto"; @@ -19,7 +20,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Local Rate limit :ref:`configuration overview `. // [#extension: envoy.filters.http.local_ratelimit] -// [#next-free-field: 7] +// [#next-free-field: 10] message LocalRateLimit { // The human readable prefix to use when emitting stats. string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; @@ -67,4 +68,28 @@ message LocalRateLimit { // have been rate limited. repeated config.core.v3.HeaderValueOption response_headers_to_add = 6 [(validate.rules).repeated = {max_items: 10}]; + + // The rate limit descriptor list to use in the local rate limit to override + // on. The rate limit descriptor is selected by the first full match from the + // request descriptors. + // + // Example on how to use ::ref:`this ` + // + // .. note:: + // + // In the current implementation the descriptor's token bucket :ref:`fill_interval + // ` must be a multiple + // global :ref:`token bucket's` fill interval. + // + // The descriptors must match verbatim for rate limiting to apply. There is no partial + // match by a subset of descriptor entries in the current implementation. + repeated common.ratelimit.v3.LocalRateLimitDescriptor descriptors = 8; + + // Specifies the rate limit configurations to be applied with the same + // stage number. If not set, the default stage number is 0. + // + // .. note:: + // + // The filter supports a range of 0 - 10 inclusively for stage numbers. + uint32 stage = 9 [(validate.rules).uint32 = {lte: 10}]; } diff --git a/include/envoy/ratelimit/ratelimit.h b/include/envoy/ratelimit/ratelimit.h index e10641400310..00646eda753e 100644 --- a/include/envoy/ratelimit/ratelimit.h +++ b/include/envoy/ratelimit/ratelimit.h @@ -9,6 +9,7 @@ #include "envoy/stream_info/stream_info.h" #include "envoy/type/v3/ratelimit_unit.pb.h" +#include "absl/time/time.h" #include "absl/types/optional.h" namespace Envoy { @@ -28,6 +29,15 @@ struct RateLimitOverride { struct DescriptorEntry { std::string key_; std::string value_; + + friend bool operator==(const DescriptorEntry& lhs, const DescriptorEntry& rhs) { + return lhs.key_ == rhs.key_ && lhs.value_ == rhs.value_; + } + template + friend H AbslHashValue(H h, // NOLINT(readability-identifier-naming) + const DescriptorEntry& entry) { + return H::combine(std::move(h), entry.key_, entry.value_); + } }; /** @@ -39,6 +49,25 @@ struct Descriptor { }; /** + * A single token bucket. See token_bucket.proto. + */ +struct TokenBucket { + uint32_t max_tokens_; + uint32_t tokens_per_fill_; + absl::Duration fill_interval_; +}; + +/** + * A single rate limit request descriptor. See ratelimit.proto. + */ +struct LocalDescriptor { + std::vector entries_; + friend bool operator==(const LocalDescriptor& lhs, const LocalDescriptor& rhs) { + return lhs.entries_ == rhs.entries_; + } +}; + +/* * Base interface for generic rate limit descriptor producer. */ class DescriptorProducer { @@ -46,14 +75,15 @@ class DescriptorProducer { virtual ~DescriptorProducer() = default; /** - * Potentially append a descriptor entry to the end of descriptor. - * @param descriptor supplies the descriptor to optionally fill. + * Potentially fill a descriptor entry to the end of descriptor. + * @param descriptor_entry supplies the descriptor entry to optionally fill. * @param local_service_cluster supplies the name of the local service cluster. * @param headers supplies the header for the request. * @param info stream info associated with the request * @return true if the producer populated the descriptor. */ - virtual bool populateDescriptor(Descriptor& descriptor, const std::string& local_service_cluster, + virtual bool populateDescriptor(DescriptorEntry& descriptor_entry, + const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const PURE; }; @@ -73,8 +103,8 @@ class DescriptorProducerFactory : public Config::TypedFactory { * * @param config supplies the configuration for the descriptor extension. * @param validator configuration validation visitor. - * @return DescriptorProducerPtr the rate limit descriptor producer which will be used to populate - * rate limit descriptors. + * @return DescriptorProducerPtr the rate limit descriptor producer which will be used to + * populate rate limit descriptors. */ virtual DescriptorProducerPtr createDescriptorProducerFromProto(const Protobuf::Message& config, diff --git a/include/envoy/router/router_ratelimit.h b/include/envoy/router/router_ratelimit.h index 8a4ef25f6d6d..199b3248130e 100644 --- a/include/envoy/router/router_ratelimit.h +++ b/include/envoy/router/router_ratelimit.h @@ -59,6 +59,18 @@ class RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const PURE; + + /** + * Potentially populate the local descriptor array with new descriptors to query. + * @param descriptors supplies the descriptor array to optionally fill. + * @param local_service_cluster supplies the name of the local service cluster. + * @param headers supplies the header for the request. + * @param info stream info associated with the request + */ + virtual void populateLocalDescriptors(std::vector& descriptors, + const std::string& local_service_cluster, + const Http::RequestHeaderMap& headers, + const StreamInfo::StreamInfo& info) const PURE; }; /** diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index cfe882a60f42..6a194da1df9f 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -17,6 +17,27 @@ namespace Envoy { namespace Router { +namespace { +bool populateDescriptor(const std::vector& actions, + std::vector& descriptor_entries, + const std::string& local_service_cluster, + const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) { + bool result = true; + for (const RateLimit::DescriptorProducerPtr& action : actions) { + RateLimit::DescriptorEntry descriptor_entry; + result = result && + action->populateDescriptor(descriptor_entry, local_service_cluster, headers, info); + if (!result) { + break; + } + if (!descriptor_entry.key_.empty()) { + descriptor_entries.push_back(descriptor_entry); + } + } + return result; +} +} // namespace + const uint64_t RateLimitPolicyImpl::MAX_STAGE_NUMBER = 10UL; bool DynamicMetadataRateLimitOverride::populateOverride( @@ -44,22 +65,23 @@ bool DynamicMetadataRateLimitOverride::populateOverride( return false; } -bool SourceClusterAction::populateDescriptor(RateLimit::Descriptor& descriptor, +bool SourceClusterAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo&) const { - descriptor.entries_.push_back({"source_cluster", local_service_cluster}); + descriptor_entry = {"source_cluster", local_service_cluster}; return true; } -bool DestinationClusterAction::populateDescriptor(RateLimit::Descriptor& descriptor, +bool DestinationClusterAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string&, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo& info) const { - descriptor.entries_.push_back({"destination_cluster", info.routeEntry()->clusterName()}); + descriptor_entry = {"destination_cluster", info.routeEntry()->clusterName()}; return true; } -bool RequestHeadersAction::populateDescriptor(RateLimit::Descriptor& descriptor, const std::string&, +bool RequestHeadersAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo&) const { const auto header_value = headers.get(header_name_); @@ -71,13 +93,12 @@ bool RequestHeadersAction::populateDescriptor(RateLimit::Descriptor& descriptor, return skip_if_absent_; } // TODO(https://github.com/envoyproxy/envoy/issues/13454): Potentially populate all header values. - descriptor.entries_.push_back( - {descriptor_key_, std::string(header_value[0]->value().getStringView())}); + descriptor_entry = {descriptor_key_, std::string(header_value[0]->value().getStringView())}; return true; } -bool RemoteAddressAction::populateDescriptor(RateLimit::Descriptor& descriptor, const std::string&, - const Http::RequestHeaderMap&, +bool RemoteAddressAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo& info) const { const Network::Address::InstanceConstSharedPtr& remote_address = info.downstreamAddressProvider().remoteAddress(); @@ -85,14 +106,14 @@ bool RemoteAddressAction::populateDescriptor(RateLimit::Descriptor& descriptor, return false; } - descriptor.entries_.push_back({"remote_address", remote_address->ip()->addressAsString()}); + descriptor_entry = {"remote_address", remote_address->ip()->addressAsString()}; return true; } -bool GenericKeyAction::populateDescriptor(RateLimit::Descriptor& descriptor, const std::string&, - const Http::RequestHeaderMap&, +bool GenericKeyAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo&) const { - descriptor.entries_.push_back({descriptor_key_, descriptor_value_}); + descriptor_entry = {descriptor_key_, descriptor_value_}; return true; } @@ -106,8 +127,8 @@ MetaDataAction::MetaDataAction( default_value_(action.default_value()), source_(envoy::config::route::v3::RateLimit::Action::MetaData::DYNAMIC) {} -bool MetaDataAction::populateDescriptor(RateLimit::Descriptor& descriptor, const std::string&, - const Http::RequestHeaderMap&, +bool MetaDataAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, + const std::string&, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo& info) const { const envoy::config::core::v3::Metadata* metadata_source; @@ -126,10 +147,10 @@ bool MetaDataAction::populateDescriptor(RateLimit::Descriptor& descriptor, const Envoy::Config::Metadata::metadataValue(metadata_source, metadata_key_).string_value(); if (!metadata_string_value.empty()) { - descriptor.entries_.push_back({descriptor_key_, metadata_string_value}); + descriptor_entry = {descriptor_key_, metadata_string_value}; return true; } else if (metadata_string_value.empty() && !default_value_.empty()) { - descriptor.entries_.push_back({descriptor_key_, default_value_}); + descriptor_entry = {descriptor_key_, default_value_}; return true; } @@ -142,12 +163,12 @@ HeaderValueMatchAction::HeaderValueMatchAction( expect_match_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(action, expect_match, true)), action_headers_(Http::HeaderUtility::buildHeaderDataVector(action.headers())) {} -bool HeaderValueMatchAction::populateDescriptor(RateLimit::Descriptor& descriptor, +bool HeaderValueMatchAction::populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string&, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo&) const { if (expect_match_ == Http::HeaderUtility::matchHeaders(headers, action_headers_)) { - descriptor.entries_.push_back({"header_match", descriptor_value_}); + descriptor_entry = {"header_match", descriptor_value_}; return true; } else { return false; @@ -225,13 +246,8 @@ void RateLimitPolicyEntryImpl::populateDescriptors(std::vectorpopulateDescriptor(descriptor, local_service_cluster, headers, info); - if (!result) { - break; - } - } + bool result = + populateDescriptor(actions_, descriptor.entries_, local_service_cluster, headers, info); if (limit_override_) { limit_override_.value()->populateOverride(descriptor, &info.dynamicMetadata()); @@ -242,6 +258,18 @@ void RateLimitPolicyEntryImpl::populateDescriptors(std::vector& descriptors, + const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, + const StreamInfo::StreamInfo& info) const { + RateLimit::LocalDescriptor descriptor({}); + bool result = + populateDescriptor(actions_, descriptor.entries_, local_service_cluster, headers, info); + if (result) { + descriptors.emplace_back(descriptor); + } +} + RateLimitPolicyImpl::RateLimitPolicyImpl( const Protobuf::RepeatedPtrField& rate_limits, ProtobufMessage::ValidationVisitor& validator) diff --git a/source/common/router/router_ratelimit.h b/source/common/router/router_ratelimit.h index 7aa0d6ca0401..b7d592f32974 100644 --- a/source/common/router/router_ratelimit.h +++ b/source/common/router/router_ratelimit.h @@ -42,7 +42,7 @@ class DynamicMetadataRateLimitOverride : public RateLimitOverrideAction { class SourceClusterAction : public RateLimit::DescriptorProducer { public: // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -54,7 +54,7 @@ class SourceClusterAction : public RateLimit::DescriptorProducer { class DestinationClusterAction : public RateLimit::DescriptorProducer { public: // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -70,7 +70,7 @@ class RequestHeadersAction : public RateLimit::DescriptorProducer { skip_if_absent_(action.skip_if_absent()) {} // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -87,7 +87,7 @@ class RequestHeadersAction : public RateLimit::DescriptorProducer { class RemoteAddressAction : public RateLimit::DescriptorProducer { public: // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -104,7 +104,7 @@ class GenericKeyAction : public RateLimit::DescriptorProducer { : "generic_key") {} // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -123,7 +123,7 @@ class MetaDataAction : public RateLimit::DescriptorProducer { // for maintaining backward compatibility with the deprecated DynamicMetaData action MetaDataAction(const envoy::config::route::v3::RateLimit::Action::DynamicMetaData& action); // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -144,7 +144,7 @@ class HeaderValueMatchAction : public RateLimit::DescriptorProducer { const envoy::config::route::v3::RateLimit::Action::HeaderValueMatch& action); // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override; @@ -169,6 +169,10 @@ class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry { void populateDescriptors(std::vector& descriptors, const std::string& local_service_cluster, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo& info) const override; + void populateLocalDescriptors(std::vector& descriptors, + const std::string& local_service_cluster, + const Http::RequestHeaderMap&, + const StreamInfo::StreamInfo& info) const override; private: const std::string disable_key_; diff --git a/source/extensions/filters/common/local_ratelimit/BUILD b/source/extensions/filters/common/local_ratelimit/BUILD index 1a201025ca3f..0234c335c3e6 100644 --- a/source/extensions/filters/common/local_ratelimit/BUILD +++ b/source/extensions/filters/common/local_ratelimit/BUILD @@ -15,6 +15,9 @@ envoy_cc_library( deps = [ "//include/envoy/event:dispatcher_interface", "//include/envoy/event:timer_interface", + "//include/envoy/ratelimit:ratelimit_interface", "//source/common/common:thread_synchronizer_lib", + "//source/common/protobuf:utility_lib", + "@envoy_api//envoy/extensions/common/ratelimit/v3:pkg_cc_proto", ], ) diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc index 2adee384673e..ab3100e50391 100644 --- a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc +++ b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc @@ -1,27 +1,62 @@ #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" +#include "common/protobuf/utility.h" + namespace Envoy { namespace Extensions { namespace Filters { namespace Common { namespace LocalRateLimit { -LocalRateLimiterImpl::LocalRateLimiterImpl(const std::chrono::milliseconds fill_interval, - const uint32_t max_tokens, - const uint32_t tokens_per_fill, - Event::Dispatcher& dispatcher) - : fill_interval_(fill_interval), max_tokens_(max_tokens), tokens_per_fill_(tokens_per_fill), - fill_timer_(fill_interval_ > std::chrono::milliseconds(0) +LocalRateLimiterImpl::LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors) + : fill_timer_(fill_interval > std::chrono::milliseconds(0) ? dispatcher.createTimer([this] { onFillTimer(); }) - : nullptr) { - if (fill_timer_ && fill_interval_ < std::chrono::milliseconds(50)) { + : nullptr), + time_source_(dispatcher.timeSource()) { + if (fill_timer_ && fill_interval < std::chrono::milliseconds(50)) { throw EnvoyException("local rate limit token bucket fill timer must be >= 50ms"); } - tokens_ = max_tokens; + token_bucket_.max_tokens_ = max_tokens; + token_bucket_.tokens_per_fill_ = tokens_per_fill; + token_bucket_.fill_interval_ = absl::FromChrono(fill_interval); + tokens_.tokens_ = max_tokens; if (fill_timer_) { - fill_timer_->enableTimer(fill_interval_); + fill_timer_->enableTimer(fill_interval); + } + + for (const auto& descriptor : descriptors) { + LocalDescriptorImpl new_descriptor; + for (const auto& entry : descriptor.entries()) { + new_descriptor.entries_.push_back({entry.key(), entry.value()}); + } + RateLimit::TokenBucket token_bucket; + token_bucket.fill_interval_ = + absl::Milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(descriptor.token_bucket(), fill_interval, 0)); + if (token_bucket.fill_interval_ % token_bucket_.fill_interval_ != absl::ZeroDuration()) { + throw EnvoyException( + "local rate descriptor limit is not a multiple of token bucket fill timer"); + } + token_bucket.max_tokens_ = descriptor.token_bucket().max_tokens(); + token_bucket.tokens_per_fill_ = + PROTOBUF_GET_WRAPPED_OR_DEFAULT(descriptor.token_bucket(), tokens_per_fill, 1); + new_descriptor.token_bucket_ = token_bucket; + + auto token_state = std::make_unique(); + token_state->tokens_ = token_bucket.max_tokens_; + token_state->fill_time_ = time_source_.monotonicTime(); + new_descriptor.token_state_ = std::move(token_state); + + auto result = descriptors_.emplace(std::move(new_descriptor)); + if (!result.second) { + throw EnvoyException(absl::StrCat("duplicate descriptor in the local rate descriptor: ", + result.first->toString())); + } } } @@ -32,28 +67,45 @@ LocalRateLimiterImpl::~LocalRateLimiterImpl() { } void LocalRateLimiterImpl::onFillTimer() { + onFillTimerHelper(tokens_, token_bucket_); + onFillTimerDescriptorHelper(); + fill_timer_->enableTimer(absl::ToChronoMilliseconds(token_bucket_.fill_interval_)); +} + +void LocalRateLimiterImpl::onFillTimerHelper(const TokenState& tokens, + const RateLimit::TokenBucket& bucket) { // Relaxed consistency is used for all operations because we don't care about ordering, just the // final atomic correctness. - uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); + uint32_t expected_tokens = tokens.tokens_.load(std::memory_order_relaxed); uint32_t new_tokens_value; do { // expected_tokens is either initialized above or reloaded during the CAS failure below. - new_tokens_value = std::min(max_tokens_, expected_tokens + tokens_per_fill_); + new_tokens_value = std::min(bucket.max_tokens_, expected_tokens + bucket.tokens_per_fill_); // Testing hook. synchronizer_.syncPoint("on_fill_timer_pre_cas"); // Loop while the weak CAS fails trying to update the tokens value. - } while ( - !tokens_.compare_exchange_weak(expected_tokens, new_tokens_value, std::memory_order_relaxed)); + } while (!tokens.tokens_.compare_exchange_weak(expected_tokens, new_tokens_value, + std::memory_order_relaxed)); +} - fill_timer_->enableTimer(fill_interval_); +void LocalRateLimiterImpl::onFillTimerDescriptorHelper() { + auto current_time = time_source_.monotonicTime(); + for (const auto& descriptor : descriptors_) { + if (std::chrono::duration_cast( + current_time - descriptor.token_state_->fill_time_) >= + absl::ToChronoMilliseconds(descriptor.token_bucket_.fill_interval_)) { + onFillTimerHelper(*descriptor.token_state_, descriptor.token_bucket_); + descriptor.token_state_->fill_time_ = current_time; + } + } } -bool LocalRateLimiterImpl::requestAllowed() const { +bool LocalRateLimiterImpl::requestAllowedHelper(const TokenState& tokens) const { // Relaxed consistency is used for all operations because we don't care about ordering, just the // final atomic correctness. - uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); + uint32_t expected_tokens = tokens.tokens_.load(std::memory_order_relaxed); do { // expected_tokens is either initialized above or reloaded during the CAS failure below. if (expected_tokens == 0) { @@ -64,13 +116,26 @@ bool LocalRateLimiterImpl::requestAllowed() const { synchronizer_.syncPoint("allowed_pre_cas"); // Loop while the weak CAS fails trying to subtract 1 from expected. - } while (!tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1, - std::memory_order_relaxed)); + } while (!tokens.tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1, + std::memory_order_relaxed)); // We successfully decremented the counter by 1. return true; } +bool LocalRateLimiterImpl::requestAllowed( + absl::Span request_descriptors) const { + if (!descriptors_.empty() && !request_descriptors.empty()) { + for (const auto& request_descriptor : request_descriptors) { + auto it = descriptors_.find(request_descriptor); + if (it != descriptors_.end()) { + return requestAllowedHelper(*it->token_state_); + } + } + } + return requestAllowedHelper(tokens_); +} + } // namespace LocalRateLimit } // namespace Common } // namespace Filters diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h index 2e35dc5b0ef4..953fb612daf8 100644 --- a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h +++ b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h @@ -4,8 +4,11 @@ #include "envoy/event/dispatcher.h" #include "envoy/event/timer.h" +#include "envoy/extensions/common/ratelimit/v3/ratelimit.pb.h" +#include "envoy/ratelimit/ratelimit.h" #include "common/common/thread_synchronizer.h" +#include "common/protobuf/protobuf.h" namespace Envoy { namespace Extensions { @@ -15,20 +18,56 @@ namespace LocalRateLimit { class LocalRateLimiterImpl { public: - LocalRateLimiterImpl(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, - const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher); + LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors); ~LocalRateLimiterImpl(); - bool requestAllowed() const; + bool requestAllowed(absl::Span request_descriptors) const; private: + struct TokenState { + mutable std::atomic tokens_; + MonotonicTime fill_time_; + }; + struct LocalDescriptorImpl : public RateLimit::LocalDescriptor { + std::unique_ptr token_state_; + RateLimit::TokenBucket token_bucket_; + std::string toString() const { + std::vector entries; + entries.reserve(entries_.size()); + for (const auto& entry : entries_) { + entries.push_back(absl::StrCat(entry.key_, "=", entry.value_)); + } + return absl::StrJoin(entries, ", "); + } + }; + struct LocalDescriptorHash { + using is_transparent = void; // NOLINT(readability-identifier-naming)t + size_t operator()(const RateLimit::LocalDescriptor& d) const { + return absl::Hash>()(d.entries_); + } + }; + struct LocalDescriptorEqual { + using is_transparent = void; // NOLINT(readability-identifier-naming) + size_t operator()(const RateLimit::LocalDescriptor& a, + const RateLimit::LocalDescriptor& b) const { + return a.entries_ == b.entries_; + } + }; + void onFillTimer(); + void onFillTimerHelper(const TokenState& state, const RateLimit::TokenBucket& bucket); + void onFillTimerDescriptorHelper(); + bool requestAllowedHelper(const TokenState& tokens) const; - const std::chrono::milliseconds fill_interval_; - const uint32_t max_tokens_; - const uint32_t tokens_per_fill_; + RateLimit::TokenBucket token_bucket_; const Event::TimerPtr fill_timer_; - mutable std::atomic tokens_; + TimeSource& time_source_; + TokenState tokens_; + absl::flat_hash_set descriptors_; mutable Thread::ThreadSynchronizer synchronizer_; // Used for testing only. friend class LocalRateLimiterImplTest; diff --git a/source/extensions/filters/http/local_ratelimit/BUILD b/source/extensions/filters/http/local_ratelimit/BUILD index 048d7d4ed4e0..91493ff13f66 100644 --- a/source/extensions/filters/http/local_ratelimit/BUILD +++ b/source/extensions/filters/http/local_ratelimit/BUILD @@ -26,6 +26,7 @@ envoy_cc_library( "//source/common/router:header_parser_lib", "//source/common/runtime:runtime_lib", "//source/extensions/filters/common/local_ratelimit:local_ratelimit_lib", + "//source/extensions/filters/common/ratelimit:ratelimit_lib", "//source/extensions/filters/http/common:pass_through_filter_lib", "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", ], diff --git a/source/extensions/filters/http/local_ratelimit/config.cc b/source/extensions/filters/http/local_ratelimit/config.cc index 529fd0dd2977..a5e629d88055 100644 --- a/source/extensions/filters/http/local_ratelimit/config.cc +++ b/source/extensions/filters/http/local_ratelimit/config.cc @@ -17,7 +17,7 @@ Http::FilterFactoryCb LocalRateLimitFilterConfig::createFilterFactoryFromProtoTy const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& proto_config, const std::string&, Server::Configuration::FactoryContext& context) { FilterConfigSharedPtr filter_config = std::make_shared( - proto_config, context.dispatcher(), context.scope(), context.runtime()); + proto_config, context.localInfo(), context.dispatcher(), context.scope(), context.runtime()); return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { callbacks.addStreamFilter(std::make_shared(filter_config)); }; @@ -27,7 +27,8 @@ Router::RouteSpecificFilterConfigConstSharedPtr LocalRateLimitFilterConfig::createRouteSpecificFilterConfigTyped( const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& proto_config, Server::Configuration::ServerFactoryContext& context, ProtobufMessage::ValidationVisitor&) { - return std::make_shared(proto_config, context.dispatcher(), context.scope(), + return std::make_shared(proto_config, context.localInfo(), + context.dispatcher(), context.scope(), context.runtime(), true); } diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc index 3b13bfa374ac..263d77849dc2 100644 --- a/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc +++ b/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc @@ -6,6 +6,7 @@ #include "envoy/http/codes.h" #include "common/http/utility.h" +#include "common/router/config_impl.h" namespace Envoy { namespace Extensions { @@ -14,16 +15,17 @@ namespace LocalRateLimitFilter { FilterConfig::FilterConfig( const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, - Event::Dispatcher& dispatcher, Stats::Scope& scope, Runtime::Loader& runtime, - const bool per_route) + const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Stats::Scope& scope, + Runtime::Loader& runtime, const bool per_route) : status_(toErrorCode(config.status().code())), stats_(generateStats(config.stat_prefix(), scope)), rate_limiter_(Filters::Common::LocalRateLimit::LocalRateLimiterImpl( std::chrono::milliseconds( PROTOBUF_GET_MS_OR_DEFAULT(config.token_bucket(), fill_interval, 0)), config.token_bucket().max_tokens(), - PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.token_bucket(), tokens_per_fill, 1), dispatcher)), - runtime_(runtime), + PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.token_bucket(), tokens_per_fill, 1), dispatcher, + config.descriptors())), + local_info_(local_info), runtime_(runtime), filter_enabled_( config.has_filter_enabled() ? absl::optional( @@ -35,7 +37,9 @@ FilterConfig::FilterConfig( Envoy::Runtime::FractionalPercent(config.filter_enforced(), runtime_)) : absl::nullopt), response_headers_parser_( - Envoy::Router::HeaderParser::configure(config.response_headers_to_add())) { + Envoy::Router::HeaderParser::configure(config.response_headers_to_add())), + stage_(static_cast(config.stage())), + has_descriptors_(!config.descriptors().empty()) { // Note: no token bucket is fine for the global config, which would be the case for enabling // the filter globally but disabled and then applying limits at the virtual host or // route level. At the virtual or route level, it makes no sense to have an no token @@ -46,7 +50,10 @@ FilterConfig::FilterConfig( } } -bool FilterConfig::requestAllowed() const { return rate_limiter_.requestAllowed(); } +bool FilterConfig::requestAllowed( + absl::Span request_descriptors) const { + return rate_limiter_.requestAllowed(request_descriptors); +} LocalRateLimitStats FilterConfig::generateStats(const std::string& prefix, Stats::Scope& scope) { const std::string final_prefix = prefix + ".http_local_rate_limit"; @@ -61,7 +68,7 @@ bool FilterConfig::enforced() const { return filter_enforced_.has_value() ? filter_enforced_->enabled() : false; } -Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { +Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { const auto* config = getConfig(); if (!config->enabled()) { @@ -70,7 +77,12 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { config->stats().enabled_.inc(); - if (config->requestAllowed()) { + std::vector descriptors; + if (config->hasDescriptors()) { + populateDescriptors(descriptors, headers); + } + + if (config->requestAllowed(descriptors)) { config->stats().ok_.inc(); return Http::FilterHeadersStatus::Continue; } @@ -94,6 +106,28 @@ Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap&, bool) { return Http::FilterHeadersStatus::StopIteration; } +void Filter::populateDescriptors(std::vector& descriptors, + Http::RequestHeaderMap& headers) { + Router::RouteConstSharedPtr route = decoder_callbacks_->route(); + if (!route || !route->routeEntry()) { + return; + } + + const Router::RouteEntry* route_entry = route->routeEntry(); + // Get all applicable rate limit policy entries for the route. + const auto* config = getConfig(); + for (const Router::RateLimitPolicyEntry& rate_limit : + route_entry->rateLimitPolicy().getApplicableRateLimit(config->stage())) { + const std::string& disable_key = rate_limit.disableKey(); + + if (!disable_key.empty()) { + continue; + } + rate_limit.populateLocalDescriptors(descriptors, config->localInfo().clusterName(), headers, + decoder_callbacks_->streamInfo()); + } +} + const FilterConfig* Filter::getConfig() const { const auto* config = Http::Utility::resolveMostSpecificPerFilterConfig( "envoy.filters.http.local_ratelimit", decoder_callbacks_->route()); diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.h b/source/extensions/filters/http/local_ratelimit/local_ratelimit.h index 6549094d07c3..cffbc399c05d 100644 --- a/source/extensions/filters/http/local_ratelimit/local_ratelimit.h +++ b/source/extensions/filters/http/local_ratelimit/local_ratelimit.h @@ -7,6 +7,7 @@ #include "envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.pb.h" #include "envoy/http/filter.h" +#include "envoy/local_info/local_info.h" #include "envoy/runtime/runtime.h" #include "envoy/stats/scope.h" #include "envoy/stats/stats_macros.h" @@ -17,6 +18,7 @@ #include "common/runtime/runtime_protos.h" #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" +#include "extensions/filters/common/ratelimit/ratelimit.h" #include "extensions/filters/http/common/pass_through_filter.h" namespace Envoy { @@ -43,19 +45,22 @@ struct LocalRateLimitStats { /** * Global configuration for the HTTP local rate limit filter. */ -class FilterConfig : public ::Envoy::Router::RouteSpecificFilterConfig { +class FilterConfig : public Router::RouteSpecificFilterConfig { public: FilterConfig(const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, - Event::Dispatcher& dispatcher, Stats::Scope& scope, Runtime::Loader& runtime, - bool per_route = false); + const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, + Stats::Scope& scope, Runtime::Loader& runtime, bool per_route = false); ~FilterConfig() override = default; + const LocalInfo::LocalInfo& localInfo() const { return local_info_; } Runtime::Loader& runtime() { return runtime_; } - bool requestAllowed() const; + bool requestAllowed(absl::Span request_descriptors) const; bool enabled() const; bool enforced() const; LocalRateLimitStats& stats() const { return stats_; } const Router::HeaderParser& responseHeadersParser() const { return *response_headers_parser_; } Http::Code status() const { return status_; } + uint64_t stage() const { return stage_; } + bool hasDescriptors() const { return has_descriptors_; } private: friend class FilterTest; @@ -73,10 +78,13 @@ class FilterConfig : public ::Envoy::Router::RouteSpecificFilterConfig { const Http::Code status_; mutable LocalRateLimitStats stats_; Filters::Common::LocalRateLimit::LocalRateLimiterImpl rate_limiter_; + const LocalInfo::LocalInfo& local_info_; Runtime::Loader& runtime_; const absl::optional filter_enabled_; const absl::optional filter_enforced_; Router::HeaderParserPtr response_headers_parser_; + const uint64_t stage_; + const bool has_descriptors_; }; using FilterConfigSharedPtr = std::shared_ptr; @@ -96,8 +104,10 @@ class Filter : public Http::PassThroughFilter { private: friend class FilterTest; - const FilterConfig* getConfig() const; + void populateDescriptors(std::vector& descriptors, + Http::RequestHeaderMap& headers); + const FilterConfig* getConfig() const; FilterConfigSharedPtr config_; }; diff --git a/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc index 773daf175139..287679763927 100644 --- a/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc +++ b/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc @@ -18,7 +18,9 @@ Config::Config( PROTOBUF_GET_MS_REQUIRED(proto_config.token_bucket(), fill_interval)), proto_config.token_bucket().max_tokens(), PROTOBUF_GET_WRAPPED_OR_DEFAULT(proto_config.token_bucket(), tokens_per_fill, 1), - dispatcher)), + dispatcher, + Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>())), enabled_(proto_config.runtime_enabled(), runtime), stats_(generateStats(proto_config.stat_prefix(), scope)) {} @@ -27,7 +29,7 @@ LocalRateLimitStats Config::generateStats(const std::string& prefix, Stats::Scop return {ALL_LOCAL_RATE_LIMIT_STATS(POOL_COUNTER_PREFIX(scope, final_prefix))}; } -bool Config::canCreateConnection() { return rate_limiter_.requestAllowed(); } +bool Config::canCreateConnection() { return rate_limiter_.requestAllowed(descriptors_); } Network::FilterStatus Filter::onNewConnection() { if (!config_->enabled()) { diff --git a/source/extensions/filters/network/local_ratelimit/local_ratelimit.h b/source/extensions/filters/network/local_ratelimit/local_ratelimit.h index e1cd52ac1bee..f8ac07272459 100644 --- a/source/extensions/filters/network/local_ratelimit/local_ratelimit.h +++ b/source/extensions/filters/network/local_ratelimit/local_ratelimit.h @@ -49,6 +49,7 @@ class Config : Logger::Loggable { Runtime::FeatureFlag enabled_; LocalRateLimitStats stats_; + std::vector descriptors_; friend class LocalRateLimitTestBase; }; diff --git a/source/extensions/rate_limit_descriptors/expr/config.cc b/source/extensions/rate_limit_descriptors/expr/config.cc index aeebb5f96e23..57aeda76b92e 100644 --- a/source/extensions/rate_limit_descriptors/expr/config.cc +++ b/source/extensions/rate_limit_descriptors/expr/config.cc @@ -30,7 +30,7 @@ class ExpressionDescriptor : public RateLimit::DescriptorProducer { } // Ratelimit::DescriptorProducer - bool populateDescriptor(RateLimit::Descriptor& descriptor, const std::string&, + bool populateDescriptor(RateLimit::DescriptorEntry& descriptor_entry, const std::string&, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info) const override { ProtobufWkt::Arena arena; @@ -42,7 +42,7 @@ class ExpressionDescriptor : public RateLimit::DescriptorProducer { // service. return skip_if_error_; } - descriptor.entries_.push_back({descriptor_key_, Filters::Common::Expr::print(result.value())}); + descriptor_entry = {descriptor_key_, Filters::Common::Expr::print(result.value())}; return true; } diff --git a/test/common/router/router_ratelimit_test.cc b/test/common/router/router_ratelimit_test.cc index 4fab81660417..43a61fdab55f 100644 --- a/test/common/router/router_ratelimit_test.cc +++ b/test/common/router/router_ratelimit_test.cc @@ -209,11 +209,16 @@ TEST_F(RateLimitConfiguration, TestVirtualHost) { EXPECT_EQ(1U, rate_limits.size()); std::vector descriptors; + std::vector local_descriptors; for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(descriptors, "service_cluster", header_, stream_info_); + rate_limit.populateLocalDescriptors(local_descriptors, "service_cluster", header_, + stream_info_); } EXPECT_THAT(std::vector({{{{"destination_cluster", "www2test"}}}}), testing::ContainerEq(descriptors)); + EXPECT_THAT(std::vector({{"destination_cluster", "www2test"}}), + testing::ContainerEq(local_descriptors.at(0).entries_)); } TEST_F(RateLimitConfiguration, Stages) { @@ -249,24 +254,35 @@ TEST_F(RateLimitConfiguration, Stages) { EXPECT_EQ(2U, rate_limits.size()); std::vector descriptors; + std::vector local_descriptors; for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(descriptors, "service_cluster", header_, stream_info_); + rate_limit.populateLocalDescriptors(local_descriptors, "service_cluster", header_, + stream_info_); } EXPECT_THAT(std::vector( {{{{"destination_cluster", "www2test"}}}, {{{"destination_cluster", "www2test"}, {"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors)); + EXPECT_THAT(std::vector( + {{{{"destination_cluster", "www2test"}}}, + {{{"destination_cluster", "www2test"}, {"source_cluster", "service_cluster"}}}}), + testing::ContainerEq(local_descriptors)); descriptors.clear(); + local_descriptors.clear(); rate_limits = route->rateLimitPolicy().getApplicableRateLimit(1UL); EXPECT_EQ(1U, rate_limits.size()); for (const RateLimitPolicyEntry& rate_limit : rate_limits) { rate_limit.populateDescriptors(descriptors, "service_cluster", header_, stream_info_); + rate_limit.populateLocalDescriptors(local_descriptors, "service_cluster", header_, + stream_info_); } EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), testing::ContainerEq(descriptors)); - + EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), + testing::ContainerEq(local_descriptors)); rate_limits = route->rateLimitPolicy().getApplicableRateLimit(10UL); EXPECT_TRUE(rate_limits.empty()); } @@ -277,6 +293,7 @@ class RateLimitPolicyEntryTest : public testing::Test { rate_limit_entry_ = std::make_unique( parseRateLimitFromV3Yaml(yaml), ProtobufMessage::getStrictValidationVisitor()); descriptors_.clear(); + local_descriptors_.clear(); stream_info_.downstream_address_provider_->setRemoteAddress(default_remote_address_); ON_CALL(Const(stream_info_), routeEntry()).WillByDefault(testing::Return(&route_)); } @@ -285,6 +302,7 @@ class RateLimitPolicyEntryTest : public testing::Test { Http::TestRequestHeaderMapImpl header_; NiceMock route_; std::vector descriptors_; + std::vector local_descriptors_; Network::Address::InstanceConstSharedPtr default_remote_address_{ new Network::Address::Ipv4Instance("10.0.0.1")}; NiceMock stream_info_; @@ -313,8 +331,11 @@ TEST_F(RateLimitPolicyEntryTest, RemoteAddress) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"remote_address", "10.0.0.1"}}}}), + testing::ContainerEq(local_descriptors_)); } // Verify no descriptor is emitted if remote is a pipe. @@ -329,7 +350,9 @@ TEST_F(RateLimitPolicyEntryTest, PipeAddress) { stream_info_.downstream_address_provider_->setRemoteAddress( std::make_shared("/hello")); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, SourceService) { @@ -341,9 +364,14 @@ TEST_F(RateLimitPolicyEntryTest, SourceService) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header_, + stream_info_); EXPECT_THAT( std::vector({{{{"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"source_cluster", "service_cluster"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, DestinationService) { @@ -355,9 +383,14 @@ TEST_F(RateLimitPolicyEntryTest, DestinationService) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header_, + stream_info_); EXPECT_THAT( std::vector({{{{"destination_cluster", "fake_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"destination_cluster", "fake_cluster"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, RequestHeaders) { @@ -372,8 +405,13 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeaders) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header, + stream_info_); EXPECT_THAT(std::vector({{{{"my_header_name", "test_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"my_header_name", "test_value"}}}}), + testing::ContainerEq(local_descriptors_)); } // Validate that a descriptor is added if the missing request header @@ -395,8 +433,13 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersWithSkipIfAbsent) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header, + stream_info_); EXPECT_THAT(std::vector({{{{"my_header_name", "test_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector({{{{"my_header_name", "test_value"}}}}), + testing::ContainerEq(local_descriptors_)); } // Tests if the descriptors are added if one of the headers is missing @@ -418,7 +461,10 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersWithDefaultSkipIfAbsent) { Http::TestRequestHeaderMapImpl header{{"x-header-test", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header, + stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, RequestHeadersNoMatch) { @@ -433,7 +479,10 @@ TEST_F(RateLimitPolicyEntryTest, RequestHeadersNoMatch) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header, + stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, RateLimitKey) { @@ -446,8 +495,11 @@ TEST_F(RateLimitPolicyEntryTest, RateLimitKey) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"generic_key", "fake_key"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"generic_key", "fake_key"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, GenericKeyWithSetDescriptorKey) { @@ -461,8 +513,11 @@ TEST_F(RateLimitPolicyEntryTest, GenericKeyWithSetDescriptorKey) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, GenericKeyWithEmptyDescriptorKey) { @@ -476,8 +531,11 @@ TEST_F(RateLimitPolicyEntryTest, GenericKeyWithEmptyDescriptorKey) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"generic_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"generic_key", "fake_value"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, DEPRECATED_FEATURE_TEST(DynamicMetaDataMatch)) { @@ -504,9 +562,12 @@ TEST_F(RateLimitPolicyEntryTest, DEPRECATED_FEATURE_TEST(DynamicMetaDataMatch)) TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSourceByDefault) { @@ -533,9 +594,12 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSourceByDefault) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSource) { @@ -563,9 +627,12 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchDynamicSource) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataMatchRouteEntrySource) { @@ -594,9 +661,12 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataMatchRouteEntrySource) { TestUtility::loadFromYaml(metadata_yaml, route_.metadata_); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "foo"}}}}), + testing::ContainerEq(local_descriptors_)); } // Tests that the default_value is used in the descriptor when the metadata_key is empty. @@ -624,9 +694,12 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatchWithDefaultValue) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"fake_key", "fake_value"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatch) { @@ -652,8 +725,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNoMatch) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, MetaDataEmptyValue) { @@ -679,8 +754,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataEmptyValue) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } // Tests that no descriptor is generated when both the metadata_key and default_value are empty. TEST_F(RateLimitPolicyEntryTest, MetaDataAndDefaultValueEmpty) { @@ -707,8 +784,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataAndDefaultValueEmpty) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, MetaDataNonStringNoMatch) { @@ -735,8 +814,10 @@ TEST_F(RateLimitPolicyEntryTest, MetaDataNonStringNoMatch) { TestUtility::loadFromYaml(metadata_yaml, stream_info_.dynamicMetadata()); rate_limit_entry_->populateDescriptors(descriptors_, "", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header_, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatch) { @@ -753,8 +834,11 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatch) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header, stream_info_); EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchNoMatch) { @@ -771,7 +855,9 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchNoMatch) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "not_same_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersNotPresent) { @@ -789,8 +875,11 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersNotPresent) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "not_same_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header, stream_info_); EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT(std::vector({{{{"header_match", "fake_value"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersPresent) { @@ -808,7 +897,9 @@ TEST_F(RateLimitPolicyEntryTest, HeaderValueMatchHeadersPresent) { Http::TestRequestHeaderMapImpl header{{"x-header-name", "test_value"}}; rate_limit_entry_->populateDescriptors(descriptors_, "", header, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "", header, stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, CompoundActions) { @@ -821,10 +912,16 @@ TEST_F(RateLimitPolicyEntryTest, CompoundActions) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header_, + stream_info_); EXPECT_THAT( std::vector( {{{{"destination_cluster", "fake_cluster"}, {"source_cluster", "service_cluster"}}}}), testing::ContainerEq(descriptors_)); + EXPECT_THAT( + std::vector( + {{{{"destination_cluster", "fake_cluster"}, {"source_cluster", "service_cluster"}}}}), + testing::ContainerEq(local_descriptors_)); } TEST_F(RateLimitPolicyEntryTest, CompoundActionsNoDescriptor) { @@ -841,7 +938,10 @@ TEST_F(RateLimitPolicyEntryTest, CompoundActionsNoDescriptor) { setupTest(yaml); rate_limit_entry_->populateDescriptors(descriptors_, "service_cluster", header_, stream_info_); + rate_limit_entry_->populateLocalDescriptors(local_descriptors_, "service_cluster", header_, + stream_info_); EXPECT_TRUE(descriptors_.empty()); + EXPECT_TRUE(local_descriptors_.empty()); } TEST_F(RateLimitPolicyEntryTest, DynamicMetadataRateLimitOverride) { diff --git a/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc b/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc index a6142dfb16aa..f35cf5b325b5 100644 --- a/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc +++ b/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc @@ -1,6 +1,7 @@ #include "extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" #include "test/mocks/event/mocks.h" +#include "test/test_common/utility.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -16,19 +17,27 @@ namespace LocalRateLimit { class LocalRateLimiterImplTest : public testing::Test { public: - void initialize(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, - const uint32_t tokens_per_fill) { - + void initializeTimer() { fill_timer_ = new Event::MockTimer(&dispatcher_); EXPECT_CALL(*fill_timer_, enableTimer(_, nullptr)); EXPECT_CALL(*fill_timer_, disableTimer()); + } + + void initialize(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill) { - rate_limiter_ = std::make_shared(fill_interval, max_tokens, - tokens_per_fill, dispatcher_); + initializeTimer(); + + rate_limiter_ = std::make_shared( + fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); } Thread::ThreadSynchronizer& synchronizer() { return rate_limiter_->synchronizer_; } + Envoy::Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor> + descriptors_; + std::vector route_descriptors_; NiceMock dispatcher_; Event::MockTimer* fill_timer_{}; std::shared_ptr rate_limiter_; @@ -37,8 +46,8 @@ class LocalRateLimiterImplTest : public testing::Test { // Make sure we fail with a fill rate this is too fast. TEST_F(LocalRateLimiterImplTest, TooFastFillRate) { EXPECT_THROW_WITH_MESSAGE( - LocalRateLimiterImpl(std::chrono::milliseconds(49), 100, 1, dispatcher_), EnvoyException, - "local rate limit token bucket fill timer must be >= 50ms"); + LocalRateLimiterImpl(std::chrono::milliseconds(49), 100, 1, dispatcher_, descriptors_), + EnvoyException, "local rate limit token bucket fill timer must be >= 50ms"); } // Verify various token bucket CAS edge cases. @@ -59,15 +68,15 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { synchronizer().barrierOn("on_fill_timer_pre_cas"); // This should succeed. - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. synchronizer().signal("on_fill_timer_pre_cas"); t1.join(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // This tests the case in which two allowed checks race. @@ -78,12 +87,12 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { // Start a thread and see if we are under limit. This will wait pre-CAS. synchronizer().waitOn("allowed_pre_cas"); - std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed()); }); + std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); }); // Wait until the thread is actually waiting. synchronizer().barrierOn("allowed_pre_cas"); // Consume a token on this thread, which should cause the CAS to fail on the other thread. - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); synchronizer().signal("allowed_pre_cas"); t1.join(); } @@ -94,17 +103,17 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { initialize(std::chrono::milliseconds(200), 1, 1); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); @@ -115,8 +124,8 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // Verify token bucket functionality with max tokens and tokens per fill > 1. @@ -124,25 +133,25 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMultipleTokensPerFill) { initialize(std::chrono::milliseconds(200), 2, 2); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 2 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 2 -> 1 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); // 1 -> 2 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); } // Verify token bucket functionality with max tokens > tokens per fill. @@ -150,17 +159,239 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMaxTokensGreaterThanTokensPerFill) { initialize(std::chrono::milliseconds(200), 2, 1); // 2 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); // 0 -> 1 tokens EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); fill_timer_->invokeCallback(); // 1 -> 0 tokens - EXPECT_TRUE(rate_limiter_->requestAllowed()); - EXPECT_FALSE(rate_limiter_->requestAllowed()); + EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); +} + +class LocalRateLimiterDescriptorImplTest : public LocalRateLimiterImplTest { +public: + void initializeWithDescriptor(const std::chrono::milliseconds fill_interval, + const uint32_t max_tokens, const uint32_t tokens_per_fill) { + + initializeTimer(); + + rate_limiter_ = std::make_shared( + fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); + } + const std::string single_descriptor_config_yaml = R"( + entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: {} + tokens_per_fill: {} + fill_interval: {} + )"; + + const std::string multiple_descriptor_config_yaml = R"( + entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 0.05s + )"; + + // Default token bucket + std::vector descriptor_{{{{"foo2", "bar2"}}}}; + std::vector descriptor2_{{{{"hello", "world"}, {"foo", "bar"}}}}; +}; + +// Verify descriptor rate limit time interval is multiple of token bucket fill interval. +TEST_F(LocalRateLimiterDescriptorImplTest, DescriptorRateLimitDivisibleByTokenFillInterval) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 10, 10, "60s"), + *descriptors_.Add()); + + EXPECT_THROW_WITH_MESSAGE( + LocalRateLimiterImpl(std::chrono::milliseconds(59000), 2, 1, dispatcher_, descriptors_), + EnvoyException, "local rate descriptor limit is not a multiple of token bucket fill timer"); +} + +TEST_F(LocalRateLimiterDescriptorImplTest, DuplicateDescriptor) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + + EXPECT_THROW_WITH_MESSAGE( + LocalRateLimiterImpl(std::chrono::milliseconds(50), 1, 1, dispatcher_, descriptors_), + EnvoyException, "duplicate descriptor in the local rate descriptor: foo2=bar2"); +} + +// Verify no exception for per route config without descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, DescriptorRateLimitNoExceptionWithoutDescriptor) { + VERBOSE_EXPECT_NO_THROW( + LocalRateLimiterImpl(std::chrono::milliseconds(59000), 2, 1, dispatcher_, descriptors_)); +} + +// Verify various token bucket CAS edge cases for descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, CasEdgeCasesDescriptor) { + // This tests the case in which an allowed check races with the fill timer. + { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + synchronizer().enable(); + + // Start a thread and start the fill callback. This will wait pre-CAS. + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + synchronizer().waitOn("on_fill_timer_pre_cas"); + std::thread t1([&] { + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("on_fill_timer_pre_cas"); + + // This should succeed. + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + + // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. + synchronizer().signal("on_fill_timer_pre_cas"); + t1.join(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + } + + // This tests the case in which two allowed checks race. + { + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + synchronizer().enable(); + + // Start a thread and see if we are under limit. This will wait pre-CAS. + synchronizer().waitOn("allowed_pre_cas"); + std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("allowed_pre_cas"); + + // Consume a token on this thread, which should cause the CAS to fail on the other thread. + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + synchronizer().signal("allowed_pre_cas"); + t1.join(); + } +} + +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor2) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +} + +// Verify token bucket functionality with a single token. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 1 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +} + +// Verify token bucket functionality with request per unit > 1. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketMultipleTokensPerFillDescriptor) { + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "0.1s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 2, 2); + + // 2 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 2 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 1 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + + // 1 -> 2 tokens + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 0 tokens + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +} + +// Verify token bucket functionality with multiple descriptors. +TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDifferentDescriptorDifferentRateLimits) { + TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "1000s"), + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 2, 1); + + // 1 -> 0 tokens for descriptor_ and descriptor2_ + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); + + // 0 -> 1 tokens for descriptor2_ + dispatcher_.time_system_.advanceTimeAndRun(std::chrono::milliseconds(50), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens for descriptor2_ and 0 only for descriptor_ + EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); + EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); } } // Namespace LocalRateLimit diff --git a/test/extensions/filters/http/local_ratelimit/BUILD b/test/extensions/filters/http/local_ratelimit/BUILD index 38cd85098ee7..0aeca92a857a 100644 --- a/test/extensions/filters/http/local_ratelimit/BUILD +++ b/test/extensions/filters/http/local_ratelimit/BUILD @@ -19,6 +19,7 @@ envoy_extension_cc_test( "//source/extensions/filters/http/local_ratelimit:local_ratelimit_lib", "//test/common/stream_info:test_util", "//test/mocks/http:http_mocks", + "//test/mocks/local_info:local_info_mocks", "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", ], ) diff --git a/test/extensions/filters/http/local_ratelimit/config_test.cc b/test/extensions/filters/http/local_ratelimit/config_test.cc index 3f48a5830e21..55cc18229969 100644 --- a/test/extensions/filters/http/local_ratelimit/config_test.cc +++ b/test/extensions/filters/http/local_ratelimit/config_test.cc @@ -63,7 +63,7 @@ stat_prefix: test const auto route_config = factory.createRouteSpecificFilterConfig( *proto_config, context, ProtobufMessage::getNullValidationVisitor()); const auto* config = dynamic_cast(route_config.get()); - EXPECT_TRUE(config->requestAllowed()); + EXPECT_TRUE(config->requestAllowed({})); } TEST(Factory, EnabledEnforcedDisabledByDefault) { @@ -125,6 +125,158 @@ stat_prefix: test EnvoyException); } +TEST(Factory, RouteSpecificFilterConfigWithDescriptorsWithNoTokenBucket) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 1000s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar +- entries: + - key: foo2 + value: bar2 + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)).Times(0); + EXPECT_THROW(factory.createRouteSpecificFilterConfig(*proto_config, context, + ProtobufMessage::getNullValidationVisitor()), + EnvoyException); +} + +TEST(Factory, RouteSpecificFilterConfigWithDescriptors) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 60s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 3600s + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)); + const auto route_config = factory.createRouteSpecificFilterConfig( + *proto_config, context, ProtobufMessage::getNullValidationVisitor()); + const auto* config = dynamic_cast(route_config.get()); + EXPECT_TRUE(config->requestAllowed({})); +} + +TEST(Factory, RouteSpecificFilterConfigWithDescriptorsTimerNotDivisible) { + const std::string config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: 1 + tokens_per_fill: 1 + fill_interval: 100s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 1s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: 100 + tokens_per_fill: 100 + fill_interval: 86400s + )"; + + LocalRateLimitFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); + TestUtility::loadFromYaml(config_yaml, *proto_config); + + NiceMock context; + + EXPECT_CALL(context.dispatcher_, createTimer_(_)); + EXPECT_THROW(factory.createRouteSpecificFilterConfig(*proto_config, context, + ProtobufMessage::getNullValidationVisitor()), + EnvoyException); +} + } // namespace LocalRateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/local_ratelimit/filter_test.cc b/test/extensions/filters/http/local_ratelimit/filter_test.cc index 9662f9a783e1..0c9f2507e016 100644 --- a/test/extensions/filters/http/local_ratelimit/filter_test.cc +++ b/test/extensions/filters/http/local_ratelimit/filter_test.cc @@ -3,6 +3,7 @@ #include "extensions/filters/http/local_ratelimit/local_ratelimit.h" #include "test/mocks/http/mocks.h" +#include "test/mocks/local_info/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -39,7 +40,8 @@ class FilterTest : public testing::Test { public: FilterTest() = default; - void setup(const std::string& yaml, const bool enabled = true, const bool enforced = true) { + void setupPerRoute(const std::string& yaml, const bool enabled = true, const bool enforced = true, + const bool per_route = false) { EXPECT_CALL( runtime_.snapshot_, featureEnabled(absl::string_view("test_enabled"), @@ -53,10 +55,14 @@ class FilterTest : public testing::Test { envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit config; TestUtility::loadFromYaml(yaml, config); - config_ = std::make_shared(config, dispatcher_, stats_, runtime_); + config_ = std::make_shared(config, local_info_, dispatcher_, stats_, runtime_, + per_route); filter_ = std::make_shared(config_); filter_->setDecoderFilterCallbacks(decoder_callbacks_); } + void setup(const std::string& yaml, const bool enabled = true, const bool enforced = true) { + setupPerRoute(yaml, enabled, enforced); + } uint64_t findCounter(const std::string& name) { const auto counter = TestUtility::findCounter(stats_, name); @@ -69,6 +75,7 @@ class FilterTest : public testing::Test { testing::NiceMock decoder_callbacks_; NiceMock dispatcher_; NiceMock runtime_; + NiceMock local_info_; std::shared_ptr config_; std::shared_ptr filter_; }; @@ -140,6 +147,188 @@ TEST_F(FilterTest, RequestRateLimitedButNotEnforced) { EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); } +static const std::string descriptor_config_yaml = R"( +stat_prefix: test +token_bucket: + max_tokens: {} + tokens_per_fill: 1 + fill_interval: 60s +filter_enabled: + runtime_key: test_enabled + default_value: + numerator: 100 + denominator: HUNDRED +filter_enforced: + runtime_key: test_enforced + default_value: + numerator: 100 + denominator: HUNDRED +response_headers_to_add: + - append: false + header: + key: x-test-rate-limit + value: 'true' +descriptors: +- entries: + - key: hello + value: world + - key: foo + value: bar + token_bucket: + max_tokens: 10 + tokens_per_fill: 10 + fill_interval: 60s +- entries: + - key: foo2 + value: bar2 + token_bucket: + max_tokens: {} + tokens_per_fill: 1 + fill_interval: 60s +stage: {} + )"; + +class DescriptorFilterTest : public FilterTest { +public: + DescriptorFilterTest() = default; + + void setUpTest(const std::string& yaml) { + setupPerRoute(yaml, true, true, true); + decoder_callbacks_.route_->route_entry_.rate_limit_policy_.rate_limit_policy_entry_.clear(); + decoder_callbacks_.route_->route_entry_.rate_limit_policy_.rate_limit_policy_entry_ + .emplace_back(route_rate_limit_); + } + + std::vector descriptor_{{{{"foo2", "bar2"}}}}; + std::vector descriptor_first_match_{{ + {{ + {"hello", "world"}, + {"foo", "bar"}, + }}, + {{{"foo2", "bar2"}}}, + }}; + std::vector descriptor_not_found_{{{{"foo", "bar"}}}}; + NiceMock route_rate_limit_; +}; + +TEST_F(DescriptorFilterTest, NoRouteEntry) { + setupPerRoute(fmt::format(descriptor_config_yaml, "1", "1", "0"), true, true, true); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, NoCluster) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_, clusterInfo()).WillRepeatedly(testing::Return(nullptr)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, DisabledInRoute) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + route_rate_limit_.disable_key_ = "disabled"; + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorRequestOk) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) + .WillOnce(testing::SetArgReferee<0>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorRequestRatelimited) { + setUpTest(fmt::format(descriptor_config_yaml, "0", "0", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) + .WillOnce(testing::SetArgReferee<0>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorNotFound) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) + .WillOnce(testing::SetArgReferee<0>(descriptor_not_found_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorFirstMatch) { + // Request should not be rate limited as it should match first descriptor with 10 req/min + setUpTest(fmt::format(descriptor_config_yaml, "0", "0", "0")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) + .WillOnce(testing::SetArgReferee<0>(descriptor_first_match_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.rate_limited")); +} + +TEST_F(DescriptorFilterTest, RouteDescriptorWithStageConfig) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "1", "1")); + + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(1)); + + EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) + .WillOnce(testing::SetArgReferee<0>(descriptor_)); + + auto headers = Http::TestRequestHeaderMapImpl(); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(headers, false)); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); + EXPECT_EQ(0U, findCounter("test.http_local_rate_limit.enforced")); + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); +} + } // namespace LocalRateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/test/mocks/ratelimit/mocks.h b/test/mocks/ratelimit/mocks.h index 7f983beabbca..2cb8fbef7471 100644 --- a/test/mocks/ratelimit/mocks.h +++ b/test/mocks/ratelimit/mocks.h @@ -14,10 +14,6 @@ inline bool operator==(const RateLimitOverride& lhs, const RateLimitOverride& rh return lhs.requests_per_unit_ == rhs.requests_per_unit_ && lhs.unit_ == rhs.unit_; } -inline bool operator==(const DescriptorEntry& lhs, const DescriptorEntry& rhs) { - return lhs.key_ == rhs.key_ && lhs.value_ == rhs.value_; -} - inline bool operator==(const Descriptor& lhs, const Descriptor& rhs) { return lhs.entries_ == rhs.entries_ && lhs.limit_ == rhs.limit_; } diff --git a/test/mocks/router/mocks.h b/test/mocks/router/mocks.h index 6d665cd26252..ea320f794828 100644 --- a/test/mocks/router/mocks.h +++ b/test/mocks/router/mocks.h @@ -194,6 +194,11 @@ class MockRateLimitPolicyEntry : public RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info), (const)); + MOCK_METHOD(void, populateLocalDescriptors, + (std::vector & descriptors, + const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, + const StreamInfo::StreamInfo& info), + (const)); uint64_t stage_{}; std::string disable_key_;