diff --git a/api/envoy/config/route/v3/route_components.proto b/api/envoy/config/route/v3/route_components.proto index 909a7305f123e..a3d4d009f0ebb 100644 --- a/api/envoy/config/route/v3/route_components.proto +++ b/api/envoy/config/route/v3/route_components.proto @@ -1868,7 +1868,7 @@ message VirtualCluster { // Global rate limiting :ref:`architecture overview `. // Also applies to Local rate limiting :ref:`using descriptors `. -// [#next-free-field: 6] +// [#next-free-field: 7] message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; @@ -2245,6 +2245,23 @@ message RateLimit { // :ref:`VirtualHost.typed_per_filter_config` or // :ref:`Route.typed_per_filter_config`, etc. HitsAddend hits_addend = 5; + + // If true, the rate limit request will be applied when the stream completes. The default value is false. + // This is useful when the rate limit budget needs to reflect the response context that is not available + // on the request path. + // + // For example, let's say the upstream service calculates the usage statistics and returns them in the response body + // and we want to utilize these numbers to apply the rate limit action for the subsequent requests. + // Combined with another filter that can set the desired addend based on the response (e.g. Lua filter), + // this can be used to subtract the usage statistics from the rate limit budget. + // + // A rate limit applied on the stream completion is "fire-and-forget" by nature, and rate limit is not enforced by this config. + // In other words, the current request won't be blocked when this is true, but the budget will be updated for the subsequent + // requests based on the action with this field set to true. Users should ensure that the rate limit is enforced by the actions + // applied on the request path, i.e. the ones with this field set to false. + // + // Currently, this is only supported by the HTTP global rate filter. + bool apply_on_stream_done = 6; } // .. attention:: diff --git a/envoy/router/router_ratelimit.h b/envoy/router/router_ratelimit.h index 199b3248130e6..7b3dcdbedc72f 100644 --- a/envoy/router/router_ratelimit.h +++ b/envoy/router/router_ratelimit.h @@ -48,6 +48,11 @@ class RateLimitPolicyEntry { */ virtual const std::string& disableKey() const PURE; + /** + * @return true if this rate limit policy should be applied on stream done. + */ + virtual bool applyOnStreamDone() const PURE; + /** * Potentially populate the descriptor array with new descriptors to query. * @param descriptors supplies the descriptor array to optionally fill. diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index 77c285a85a24d..386374d397254 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -267,7 +267,8 @@ RateLimitPolicyEntryImpl::RateLimitPolicyEntryImpl( const envoy::config::route::v3::RateLimit& config, Server::Configuration::CommonFactoryContext& context, absl::Status& creation_status) : disable_key_(config.disable_key()), - stage_(static_cast(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, stage, 0))) { + stage_(static_cast(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, stage, 0))), + apply_on_stream_done_(config.apply_on_stream_done()) { for (const auto& action : config.actions()) { switch (action.action_specifier_case()) { case envoy::config::route::v3::RateLimit::Action::ActionSpecifierCase::kSourceCluster: diff --git a/source/common/router/router_ratelimit.h b/source/common/router/router_ratelimit.h index 3fb5149a4cc25..686f15a6938d9 100644 --- a/source/common/router/router_ratelimit.h +++ b/source/common/router/router_ratelimit.h @@ -256,12 +256,14 @@ class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::RequestHeaderMap&, const StreamInfo::StreamInfo& info) const override; + bool applyOnStreamDone() const override { return apply_on_stream_done_; } private: const std::string disable_key_; uint64_t stage_; std::vector actions_; absl::optional limit_override_ = absl::nullopt; + const bool apply_on_stream_done_ = false; }; /** diff --git a/source/extensions/filters/common/ratelimit/ratelimit.h b/source/extensions/filters/common/ratelimit/ratelimit.h index 11267cc7db03b..44a29d13da6ac 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit.h +++ b/source/extensions/filters/common/ratelimit/ratelimit.h @@ -90,7 +90,7 @@ class Client { */ virtual void limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend) PURE; }; diff --git a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc index 3350e132562a5..7f1b07e8f6f39 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit_impl.cc +++ b/source/extensions/filters/common/ratelimit/ratelimit_impl.cc @@ -59,18 +59,19 @@ void GrpcClientImpl::createRequest(envoy::service::ratelimit::v3::RateLimitReque void GrpcClientImpl::limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, - uint32_t hits_addend) { + Tracing::Span& parent_span, + OptRef stream_info, uint32_t hits_addend) { ASSERT(callbacks_ == nullptr); callbacks_ = &callbacks; envoy::service::ratelimit::v3::RateLimitRequest request; createRequest(request, domain, descriptors, hits_addend); - request_ = - async_client_->send(service_method_, request, *this, parent_span, - Http::AsyncClient::RequestOptions().setTimeout(timeout_).setParentContext( - Http::AsyncClient::ParentContext{&stream_info})); + auto options = Http::AsyncClient::RequestOptions().setTimeout(timeout_); + if (stream_info.has_value()) { + options.setParentContext(Http::AsyncClient::ParentContext{stream_info.ptr()}); + } + request_ = async_client_->send(service_method_, request, *this, parent_span, options); } void GrpcClientImpl::onSuccess( @@ -107,10 +108,13 @@ void GrpcClientImpl::onSuccess( response->has_dynamic_metadata() ? std::make_unique(response->dynamic_metadata()) : nullptr; - callbacks_->complete(status, std::move(descriptor_statuses), std::move(response_headers_to_add), + // The rate limit requests applied on stream-done will destroy the client inside the complete + // callback, so we release the callback here to make the destructor happy. + auto call_backs = callbacks_; + callbacks_ = nullptr; + call_backs->complete(status, std::move(descriptor_statuses), std::move(response_headers_to_add), std::move(request_headers_to_add), response->raw_body(), std::move(dynamic_metadata)); - callbacks_ = nullptr; } void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string& msg, @@ -118,8 +122,11 @@ void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::strin ASSERT(status != Grpc::Status::WellKnownGrpcStatus::Ok); ENVOY_LOG_TO_LOGGER(Logger::Registry::getLog(Logger::Id::filter), debug, "rate limit fail, status={} msg={}", status, msg); - callbacks_->complete(LimitStatus::Error, nullptr, nullptr, nullptr, EMPTY_STRING, nullptr); + // The rate limit requests applied on stream-done will destroy the client inside the complete + // callback, so we release the callback here to make the destructor happy. + auto call_backs = callbacks_; callbacks_ = nullptr; + call_backs->complete(LimitStatus::Error, nullptr, nullptr, nullptr, EMPTY_STRING, nullptr); } ClientPtr rateLimitClient(Server::Configuration::FactoryContext& context, diff --git a/source/extensions/filters/common/ratelimit/ratelimit_impl.h b/source/extensions/filters/common/ratelimit/ratelimit_impl.h index 79502ec2ef787..61a6c1c5ec880 100644 --- a/source/extensions/filters/common/ratelimit/ratelimit_impl.h +++ b/source/extensions/filters/common/ratelimit/ratelimit_impl.h @@ -57,7 +57,7 @@ class GrpcClientImpl : public Client, void cancel() override; void limit(RequestCallbacks& callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend = 0) override; // Grpc::AsyncRequestCallbacks diff --git a/source/extensions/filters/http/ratelimit/ratelimit.cc b/source/extensions/filters/http/ratelimit/ratelimit.cc index 7052f8f793edf..858b1548810a6 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.cc +++ b/source/extensions/filters/http/ratelimit/ratelimit.cc @@ -55,37 +55,58 @@ void Filter::initiateCall(const Http::RequestHeaderMap& headers) { return; } - Router::RouteConstSharedPtr route = callbacks_->route(); - if (!route || !route->routeEntry()) { - return; + std::vector descriptors; + populateRateLimitDescriptors(descriptors, headers, false); + if (!descriptors.empty()) { + state_ = State::Calling; + initiating_call_ = true; + client_->limit(*this, getDomain(), descriptors, callbacks_->activeSpan(), + callbacks_->streamInfo(), getHitAddend()); + initiating_call_ = false; } +} - cluster_ = callbacks_->clusterInfo(); - if (!cluster_) { +void Filter::populateRateLimitDescriptors(std::vector& descriptors, + const Http::RequestHeaderMap& headers, + bool on_stream_done) { + if (!on_stream_done) { + // To use the exact same context for both request and on_stream_done rate limiting descriptors, + // we save the route and per-route configuration here and use them later. + route_ = callbacks_->route(); + cluster_ = callbacks_->clusterInfo(); + } + if (!route_ || !cluster_) { return; } - std::vector descriptors; - - const Router::RouteEntry* route_entry = route->routeEntry(); + const Router::RouteEntry* route_entry = route_->routeEntry(); + if (!route_entry) { + return; + } + if (!on_stream_done) { + initializeVirtualHostRateLimitOption(route_entry); + } // Get all applicable rate limit policy entries for the route. - populateRateLimitDescriptors(route_entry->rateLimitPolicy(), descriptors, headers); + populateRateLimitDescriptorsForPolicy(route_entry->rateLimitPolicy(), descriptors, headers, + on_stream_done); - VhRateLimitOptions vh_rate_limit_option = getVirtualHostRateLimitOption(route); - - switch (vh_rate_limit_option) { + switch (vh_rate_limits_) { case VhRateLimitOptions::Ignore: break; case VhRateLimitOptions::Include: - populateRateLimitDescriptors(route->virtualHost().rateLimitPolicy(), descriptors, headers); + populateRateLimitDescriptorsForPolicy(route_->virtualHost().rateLimitPolicy(), descriptors, + headers, on_stream_done); break; case VhRateLimitOptions::Override: if (route_entry->rateLimitPolicy().empty()) { - populateRateLimitDescriptors(route->virtualHost().rateLimitPolicy(), descriptors, headers); + populateRateLimitDescriptorsForPolicy(route_->virtualHost().rateLimitPolicy(), descriptors, + headers, on_stream_done); } break; } +} +double Filter::getHitAddend() { const StreamInfo::UInt32Accessor* hits_addend_filter_state = callbacks_->streamInfo().filterState()->getDataReadOnly( HitsAddendFilterStateKey); @@ -93,14 +114,7 @@ void Filter::initiateCall(const Http::RequestHeaderMap& headers) { if (hits_addend_filter_state != nullptr) { hits_addend = hits_addend_filter_state->value(); } - - if (!descriptors.empty()) { - state_ = State::Calling; - initiating_call_ = true; - client_->limit(*this, getDomain(), descriptors, callbacks_->activeSpan(), - callbacks_->streamInfo(), hits_addend); - initiating_call_ = false; - } + return hits_addend; } Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { @@ -161,6 +175,16 @@ void Filter::onDestroy() { if (state_ == State::Calling) { state_ = State::Complete; client_->cancel(); + } else if (client_ != nullptr) { + std::vector descriptors; + populateRateLimitDescriptors(descriptors, *request_headers_, true); + if (!descriptors.empty()) { + // Since this filter is being destroyed, we need to keep the client alive until the request + // is complete by leaking the client with OnStreamDoneCallBack. + auto callback = new OnStreamDoneCallBack(std::move(client_)); + callback->client().limit(*callback, getDomain(), descriptors, Tracing::NullSpan::instance(), + absl::nullopt, getHitAddend()); + } } } @@ -256,9 +280,10 @@ void Filter::complete(Filters::Common::RateLimit::LimitStatus status, } } -void Filter::populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_limit_policy, - std::vector& descriptors, - const Http::RequestHeaderMap& headers) const { +void Filter::populateRateLimitDescriptorsForPolicy(const Router::RateLimitPolicy& rate_limit_policy, + std::vector& descriptors, + const Http::RequestHeaderMap& headers, + bool on_stream_done) { for (const Router::RateLimitPolicyEntry& rate_limit : rate_limit_policy.getApplicableRateLimit(config_->stage())) { const std::string& disable_key = rate_limit.disableKey(); @@ -267,8 +292,11 @@ void Filter::populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_li fmt::format("ratelimit.{}.http_filter_enabled", disable_key), 100)) { continue; } - rate_limit.populateDescriptors(descriptors, config_->localInfo().clusterName(), headers, - callbacks_->streamInfo()); + const bool apply_on_stream_done = rate_limit.applyOnStreamDone(); + if (on_stream_done == apply_on_stream_done) { + rate_limit.populateDescriptors(descriptors, config_->localInfo().clusterName(), headers, + callbacks_->streamInfo()); + } } } @@ -296,8 +324,8 @@ void Filter::appendRequestHeaders(Http::HeaderMapPtr& request_headers_to_add) { } } -VhRateLimitOptions Filter::getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route) { - if (route->routeEntry()->includeVirtualHostRateLimits()) { +void Filter::initializeVirtualHostRateLimitOption(const Router::RouteEntry* route_entry) { + if (route_entry->includeVirtualHostRateLimits()) { vh_rate_limits_ = VhRateLimitOptions::Include; } else { const auto* specific_per_route_config = @@ -318,7 +346,6 @@ VhRateLimitOptions Filter::getVirtualHostRateLimitOption(const Router::RouteCons vh_rate_limits_ = VhRateLimitOptions::Override; } } - return vh_rate_limits_; } std::string Filter::getDomain() { @@ -330,6 +357,14 @@ std::string Filter::getDomain() { return config_->domain(); } +void OnStreamDoneCallBack::complete(Filters::Common::RateLimit::LimitStatus, + Filters::Common::RateLimit::DescriptorStatusListPtr&&, + Http::ResponseHeaderMapPtr&&, Http::RequestHeaderMapPtr&&, + const std::string&, + Filters::Common::RateLimit::DynamicMetadataPtr&&) { + delete this; +} + } // namespace RateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/ratelimit/ratelimit.h b/source/extensions/filters/http/ratelimit/ratelimit.h index 889e79a95957a..60b110a505849 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.h +++ b/source/extensions/filters/http/ratelimit/ratelimit.h @@ -185,12 +185,16 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req private: void initiateCall(const Http::RequestHeaderMap& headers); - void populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_limit_policy, - std::vector& descriptors, - const Http::RequestHeaderMap& headers) const; + void populateRateLimitDescriptors(std::vector& descriptors, + const Http::RequestHeaderMap& headers, bool on_stream_done); + void populateRateLimitDescriptorsForPolicy(const Router::RateLimitPolicy& rate_limit_policy, + std::vector& descriptors, + const Http::RequestHeaderMap& headers, + bool on_stream_done); void populateResponseHeaders(Http::HeaderMap& response_headers, bool from_local_reply); void appendRequestHeaders(Http::HeaderMapPtr& request_headers_to_add); - VhRateLimitOptions getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route); + double getHitAddend(); + void initializeVirtualHostRateLimitOption(const Router::RouteEntry* route_entry); std::string getDomain(); Http::Context& httpContext() { return config_->httpContext(); } @@ -203,11 +207,33 @@ class Filter : public Http::StreamFilter, public Filters::Common::RateLimit::Req State state_{State::NotStarted}; VhRateLimitOptions vh_rate_limits_; Upstream::ClusterInfoConstSharedPtr cluster_; + Router::RouteConstSharedPtr route_ = nullptr; bool initiating_call_{}; Http::ResponseHeaderMapPtr response_headers_to_add_; Http::RequestHeaderMap* request_headers_{}; }; +/** + * This implements the rate limit callback that outlives the filter holding the client. + * On completion, it deletes itself. + */ +class OnStreamDoneCallBack : public Filters::Common::RateLimit::RequestCallbacks { +public: + OnStreamDoneCallBack(Filters::Common::RateLimit::ClientPtr client) : client_(std::move(client)) {} + ~OnStreamDoneCallBack() override = default; + + // RateLimit::RequestCallbacks + void complete(Filters::Common::RateLimit::LimitStatus, + Filters::Common::RateLimit::DescriptorStatusListPtr&&, Http::ResponseHeaderMapPtr&&, + Http::RequestHeaderMapPtr&&, const std::string&, + Filters::Common::RateLimit::DynamicMetadataPtr&&) override; + + Filters::Common::RateLimit::Client& client() { return *client_; } + +private: + Filters::Common::RateLimit::ClientPtr client_; +}; + } // namespace RateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/common/ratelimit/mocks.h b/test/extensions/filters/common/ratelimit/mocks.h index 5155d335f057d..259fa3f08881a 100644 --- a/test/extensions/filters/common/ratelimit/mocks.h +++ b/test/extensions/filters/common/ratelimit/mocks.h @@ -26,7 +26,7 @@ class MockClient : public Client { MOCK_METHOD(void, limit, (RequestCallbacks & callbacks, const std::string& domain, const std::vector& descriptors, - Tracing::Span& parent_span, const StreamInfo::StreamInfo& stream_info, + Tracing::Span& parent_span, OptRef stream_info, uint32_t hits_addend)); }; diff --git a/test/extensions/filters/http/ratelimit/ratelimit_test.cc b/test/extensions/filters/http/ratelimit/ratelimit_test.cc index 6caa81e8eb5f1..63f14004ef3d3 100644 --- a/test/extensions/filters/http/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/http/ratelimit/ratelimit_test.cc @@ -265,15 +265,13 @@ TEST_F(HttpRateLimitFilterTest, OkResponseWithAdditionalHitsAddend) { filter_callbacks_.stream_info_.filter_state_->setData( "envoy.ratelimit.hits_addend", std::make_unique(5), - StreamInfo::FilterState::StateType::ReadOnly); + StreamInfo::FilterState::StateType::Mutable); EXPECT_CALL(filter_callbacks_.route_->route_entry_.rate_limit_policy_, getApplicableRateLimit(0)); EXPECT_CALL(route_rate_limit_, populateDescriptors(_, _, _, _)) .WillOnce(SetArgReferee<0>(descriptor_)); - EXPECT_CALL(filter_callbacks_.route_->virtual_host_.rate_limit_policy_, - getApplicableRateLimit(0)); - + EXPECT_CALL(vh_rate_limit_, applyOnStreamDone()).WillRepeatedly(Return(true)); EXPECT_CALL(*client_, limit(_, "foo", testing::ContainerEq(std::vector{ {{{"descriptor_key", "descriptor_value"}}}}), @@ -304,6 +302,30 @@ TEST_F(HttpRateLimitFilterTest, OkResponseWithAdditionalHitsAddend) { EXPECT_EQ( 1U, filter_callbacks_.clusterInfo()->statsScope().counterFromStatName(ratelimit_ok_).value()); + + // Test the behavior for the apply_on_stream_done flag. + testing::Mock::VerifyAndClearExpectations(client_); + testing::Mock::VerifyAndClearExpectations(&filter_callbacks_); + testing::Mock::VerifyAndClearExpectations( + &filter_callbacks_.route_->route_entry_.rate_limit_policy_); + testing::Mock::VerifyAndClearExpectations(&route_rate_limit_); + testing::Mock::VerifyAndClearExpectations(&vh_rate_limit_); + filter_callbacks_.stream_info_.filter_state_->setData( + // Ensures that addend can be set differently than the request path. + "envoy.ratelimit.hits_addend", std::make_unique(100), + StreamInfo::FilterState::StateType::Mutable); + EXPECT_CALL(filter_callbacks_.route_->route_entry_.rate_limit_policy_, getApplicableRateLimit(0)); + EXPECT_CALL(vh_rate_limit_, applyOnStreamDone()).WillRepeatedly(Return(true)); + EXPECT_CALL(vh_rate_limit_, populateDescriptors(_, _, _, _)) + .WillOnce(SetArgReferee<0>(descriptor_two_)); + EXPECT_CALL(*client_, limit(_, "foo", testing::ContainerEq(descriptor_two_), _, _, 100)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::RateLimit::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + filter_->onDestroy(); + request_callbacks_->complete(Filters::Common::RateLimit::LimitStatus::OK, nullptr, nullptr, + nullptr, "", nullptr); } TEST_F(HttpRateLimitFilterTest, OkResponseWithHeaders) { diff --git a/test/mocks/router/mocks.h b/test/mocks/router/mocks.h index 9c21895dcad9a..e4e3ac4093571 100644 --- a/test/mocks/router/mocks.h +++ b/test/mocks/router/mocks.h @@ -250,6 +250,7 @@ class MockRateLimitPolicyEntry : public RateLimitPolicyEntry { const std::string& local_service_cluster, const Http::RequestHeaderMap& headers, const StreamInfo::StreamInfo& info), (const)); + MOCK_METHOD(bool, applyOnStreamDone, (), (const)); uint64_t stage_{}; std::string disable_key_;