Skip to content

Commit

Permalink
Remote Address rate limiting (#107)
Browse files Browse the repository at this point in the history
* Remote Address rate limiting

* Fix broken documentation link and add descriptor for route key rate limiting.

* Rename method and variables from address to downstreamAddress/downstream_address

* Preserve include alphabetization

* Refactor getLastAddressFromXff into Http::Utility
  • Loading branch information
ccaraman authored Sep 30, 2016
1 parent f095ab5 commit 89de442
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 8 deletions.
21 changes: 20 additions & 1 deletion docs/configuration/http_filters/rate_limit_filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Actions
type
*(required, string)* The type of rate limit action to perform. The currently supported action
types are *service_to_service* and *request_headers*.
types are *service_to_service* , *request_headers* and *remote_address*.

Service to service
^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -87,6 +87,25 @@ descriptor is sent as well:

* ("route_key", "<route_key>"), ("<descriptor_key>", "<header_value_queried_from_header>")

Remote Address
^^^^^^^^^^^^^^

.. code-block:: json
{
"type": "remote_address"
}
The following descriptor is sent using the trusted address from :ref:`x-forwarded-for <config_http_conn_man_headers_x-forwarded-for>`:

* ("remote_address", "<:ref:`trusted address from x-forwarded-for <config_http_conn_man_headers_x-forwarded-for>`>")

If *route_key* is set in the :ref:`route <config_http_conn_man_route_table_route_rate_limit>`, the following
descriptor is sent as well:

* ("route_key", "<route_key>"),
("remote_address", "<:ref:`trusted address from x-forwarded-for <config_http_conn_man_headers_x-forwarded-for>`>")

Statistics
----------

Expand Down
5 changes: 5 additions & 0 deletions include/envoy/http/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class StreamFilterCallbacks {
* put into the access log.
*/
virtual AccessLog::RequestInfo& requestInfo() PURE;

/**
* @return the trusted downstream address for the connection.
*/
virtual const std::string& downstreamAddress() PURE;
};

/**
Expand Down
1 change: 1 addition & 0 deletions source/common/http/async_client_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class AsyncRequestImpl final : public AsyncClient::Request,
const Router::StableRouteTable& routeTable() { return *this; }
uint64_t streamId() override { return stream_id_; }
AccessLog::RequestInfo& requestInfo() override { return request_info_; }
const std::string& downstreamAddress() override { return EMPTY_STRING; }
void continueDecoding() override { NOT_IMPLEMENTED; }
const Buffer::Instance* decodingBuffer() override { return request_->body(); }
void encodeHeaders(HeaderMapPtr&& headers, bool end_stream) override;
Expand Down
6 changes: 6 additions & 0 deletions source/common/http/conn_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ void ConnectionManagerImpl::ActiveStream::decodeHeaders(HeaderMapPtr&& headers,
connection_manager_.config_, connection_manager_.random_generator_,
connection_manager_.runtime_);

// Set the trusted address for the connection by taking the last address in XFF.
downstream_address_ = Utility::getLastAddressFromXFF(*request_headers_);
decodeHeaders(nullptr, *request_headers_, end_stream);
}

Expand Down Expand Up @@ -774,4 +776,8 @@ void ConnectionManagerImpl::ActiveStreamFilterBase::resetStream() {

uint64_t ConnectionManagerImpl::ActiveStreamFilterBase::streamId() { return parent_.stream_id_; }

const std::string& ConnectionManagerImpl::ActiveStreamFilterBase::downstreamAddress() {
return parent_.downstream_address_;
}

} // Http
2 changes: 2 additions & 0 deletions source/common/http/conn_manager_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class ConnectionManagerImpl : Logger::Loggable<Logger::Id::http>,
const Router::StableRouteTable& routeTable() override { return *this; }
uint64_t streamId() override;
AccessLog::RequestInfo& requestInfo() override;
const std::string& downstreamAddress() override;

// Router::StableRouteTable
const Router::RedirectEntry* redirectRequest(const HeaderMap& headers) const {
Expand Down Expand Up @@ -366,6 +367,7 @@ class ConnectionManagerImpl : Logger::Loggable<Logger::Id::http>,
std::list<std::function<void()>> reset_callbacks_;
State state_;
AccessLog::RequestInfoImpl request_info_;
std::string downstream_address_;
};

typedef std::unique_ptr<ActiveStream> ActiveStreamPtr;
Expand Down
31 changes: 27 additions & 4 deletions source/common/http/filter/ratelimit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ const Http::HeaderMapImpl Filter::TOO_MANY_REQUESTS_HEADER{

void ServiceToServiceAction::populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors,
FilterConfig& config, const HeaderMap&) {
FilterConfig& config, const HeaderMap&,
StreamDecoderFilterCallbacks&) {
// We limit on 2 dimensions.
// 1) All calls to the given cluster.
// 2) Calls to the given cluster and from this cluster.
Expand All @@ -28,8 +29,9 @@ void ServiceToServiceAction::populateDescriptors(const Router::RouteEntry& route

void RequestHeadersAction::populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors,
FilterConfig&, const HeaderMap& headers) {
std::string header_value = headers.get(header_name_);
FilterConfig&, const HeaderMap& headers,
StreamDecoderFilterCallbacks&) {
const std::string& header_value = headers.get(header_name_);
if (header_value.empty()) {
return;
}
Expand All @@ -44,6 +46,25 @@ void RequestHeadersAction::populateDescriptors(const Router::RouteEntry& route,
descriptors.push_back({{{"route_key", route_key}, {descriptor_key_, header_value}}});
}

void RemoteAddressAction::populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors,
FilterConfig&, const HeaderMap&,
StreamDecoderFilterCallbacks& callbacks) {
const std::string& remote_address = callbacks.downstreamAddress();
if (remote_address.empty()) {
return;
}

descriptors.push_back({{{"remote_address", remote_address}}});

const std::string& route_key = route.rateLimitPolicy().routeKey();
if (route_key.empty()) {
return;
}

descriptors.push_back({{{"route_key", route_key}, {"remote_address", remote_address}}});
}

FilterConfig::FilterConfig(const Json::Object& config, const std::string& local_service_cluster,
Stats::Store& stats_store, Runtime::Loader& runtime)
: domain_(config.getString("domain")), local_service_cluster_(local_service_cluster),
Expand All @@ -54,6 +75,8 @@ FilterConfig::FilterConfig(const Json::Object& config, const std::string& local_
actions_.emplace_back(new ServiceToServiceAction());
} else if (type == "request_headers") {
actions_.emplace_back(new RequestHeadersAction(action));
} else if (type == "remote_address") {
actions_.emplace_back(new RemoteAddressAction());
} else {
throw EnvoyException(fmt::format("unknown http rate limit filter action '{}'", type));
}
Expand All @@ -69,7 +92,7 @@ FilterHeadersStatus Filter::decodeHeaders(HeaderMap& headers, bool) {
if (route && route->rateLimitPolicy().doGlobalLimiting()) {
std::vector<::RateLimit::Descriptor> descriptors;
for (const ActionPtr& action : config_->actions()) {
action->populateDescriptors(*route, descriptors, *config_, headers);
action->populateDescriptors(*route, descriptors, *config_, headers, *callbacks_);
}

if (!descriptors.empty()) {
Expand Down
19 changes: 16 additions & 3 deletions source/common/http/filter/ratelimit.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class Action {
*/
virtual void populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors,
FilterConfig& config, const HeaderMap& headers) PURE;
FilterConfig& config, const HeaderMap& headers,
StreamDecoderFilterCallbacks& callbacks) PURE;
};

typedef std::unique_ptr<Action> ActionPtr;
Expand All @@ -40,7 +41,7 @@ class ServiceToServiceAction : public Action {
// Action
void populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors, FilterConfig& config,
const HeaderMap&) override;
const HeaderMap&, StreamDecoderFilterCallbacks&) override;
};

/**
Expand All @@ -54,12 +55,24 @@ class RequestHeadersAction : public Action {
// Action
void populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors, FilterConfig& config,
const HeaderMap& headers) override;
const HeaderMap& headers, StreamDecoderFilterCallbacks&) override;

private:
const LowerCaseString header_name_;
const std::string descriptor_key_;
};

/**
* Action for remote address rate limiting.
*/
class RemoteAddressAction : public Action {
public:
// Action
void populateDescriptors(const Router::RouteEntry& route,
std::vector<::RateLimit::Descriptor>& descriptors, FilterConfig&,
const HeaderMap&, StreamDecoderFilterCallbacks& callbacks) override;
};

/**
* Global configuration for the HTTP rate limit filter.
*/
Expand Down
11 changes: 11 additions & 0 deletions source/common/http/utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "common/buffer/buffer_impl.h"
#include "common/common/assert.h"
#include "common/common/empty_string.h"
#include "common/common/enum_to_int.h"
#include "common/common/utility.h"
#include "common/http/exception.h"
Expand Down Expand Up @@ -128,4 +129,14 @@ void Utility::sendRedirect(StreamDecoderFilterCallbacks& callbacks, const std::s
callbacks.encodeHeaders(std::move(response_headers), true);
}

std::string Utility::getLastAddressFromXFF(const Http::HeaderMap& request_headers) {
std::vector<std::string> xff_address_list =
StringUtil::split(request_headers.get(Headers::get().ForwardedFor), ',');

if (xff_address_list.empty()) {
return EMPTY_STRING;
}
return xff_address_list.back();
}

} // Http
7 changes: 7 additions & 0 deletions source/common/http/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class Utility {
* @param new_path supplies the redirect target.
*/
static void sendRedirect(StreamDecoderFilterCallbacks& callbacks, const std::string& new_path);

/**
* Retrieves the last address in x-forwarded-for header. If it isn't set, returns empty string.
* @param request_headers
* @return last_address_in_xff
*/
static std::string getLastAddressFromXFF(const Http::HeaderMap& request_headers);
};

} // Http
69 changes: 69 additions & 0 deletions test/common/http/filter/ratelimit_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "common/buffer/buffer_impl.h"
#include "common/common/empty_string.h"
#include "common/http/filter/ratelimit.h"
#include "common/stats/stats_impl.h"

Expand All @@ -11,6 +12,7 @@ using testing::InSequence;
using testing::Invoke;
using testing::NiceMock;
using testing::Return;
using testing::ReturnRef;
using testing::WithArgs;

namespace Http {
Expand Down Expand Up @@ -91,6 +93,15 @@ class HttpRateLimitFilterTest : public testing::Test {
}
)EOF";

const std::string address_json = R"EOF(
{
"domain": "foo",
"actions": [
{"type": "remote_address"}
]
}
)EOF";

FilterConfigPtr config_;
::RateLimit::MockClient* client_;
std::unique_ptr<Filter> filter_;
Expand Down Expand Up @@ -319,5 +330,63 @@ TEST_F(HttpRateLimitFilterTest, NoRateLimitHeaderMatch) {
EXPECT_EQ(FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_));
}

TEST_F(HttpRateLimitFilterTest, AddressRateLimiting) {
SetUpTest(address_json);
filter_callbacks_.route_table_.route_entry_.rate_limit_policy_.do_global_limiting_ = true;

std::string address = "10.0.0.1";
EXPECT_CALL(filter_callbacks_, downstreamAddress()).WillOnce(ReturnRef(address));
EXPECT_CALL(*client_, limit(_, "foo", testing::ContainerEq(std::vector<::RateLimit::Descriptor>{
{{{"remote_address", address}}}})))
.WillOnce(WithArgs<0>(Invoke([&](::RateLimit::RequestCallbacks& callbacks)
-> void { request_callbacks_ = &callbacks; })));

EXPECT_EQ(FilterHeadersStatus::StopIteration, filter_->decodeHeaders(request_headers_, false));
EXPECT_EQ(FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data_, false));
EXPECT_EQ(FilterTrailersStatus::StopIteration, filter_->decodeTrailers(request_headers_));

EXPECT_CALL(filter_callbacks_, continueDecoding());
request_callbacks_->complete(::RateLimit::LimitStatus::OK);

EXPECT_EQ(1U, stats_store_.counter("cluster.fake_cluster.ratelimit.ok").value());
}

TEST_F(HttpRateLimitFilterTest, RouteAddressRateLimiting) {
SetUpTest(address_json);
filter_callbacks_.route_table_.route_entry_.rate_limit_policy_.do_global_limiting_ = true;
filter_callbacks_.route_table_.route_entry_.rate_limit_policy_.route_key_ = "test_key";

std::string address = "10.0.0.1";
EXPECT_CALL(filter_callbacks_, downstreamAddress()).WillOnce(ReturnRef(address));
EXPECT_CALL(*client_,
limit(_, "foo", testing::ContainerEq(std::vector<::RateLimit::Descriptor>{
{{{"remote_address", address}}},
{{{"route_key", "test_key"}, {"remote_address", address}}}})))
.WillOnce(WithArgs<0>(Invoke([&](::RateLimit::RequestCallbacks& callbacks)
-> void { request_callbacks_ = &callbacks; })));

EXPECT_EQ(FilterHeadersStatus::StopIteration, filter_->decodeHeaders(request_headers_, false));
EXPECT_EQ(FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data_, false));
EXPECT_EQ(FilterTrailersStatus::StopIteration, filter_->decodeTrailers(request_headers_));

EXPECT_CALL(filter_callbacks_, continueDecoding());
request_callbacks_->complete(::RateLimit::LimitStatus::OK);

EXPECT_EQ(1U, stats_store_.counter("cluster.fake_cluster.ratelimit.ok").value());
}

TEST_F(HttpRateLimitFilterTest, NoAddressRateLimiting) {
SetUpTest(address_json);
filter_callbacks_.route_table_.route_entry_.rate_limit_policy_.do_global_limiting_ = true;

EXPECT_CALL(filter_callbacks_, downstreamAddress()).WillOnce(ReturnRef(EMPTY_STRING));

EXPECT_CALL(*client_, limit(_, _, _)).Times(0);

EXPECT_EQ(FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false));
EXPECT_EQ(FilterDataStatus::Continue, filter_->decodeData(data_, false));
EXPECT_EQ(FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_));
}

} // RateLimit
} // Http
19 changes: 19 additions & 0 deletions test/common/http/utility_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,23 @@ TEST(HttpUtility, parseCodecOptions) {
}
}

TEST(HttpUtility, TwoAddressesInXFF) {
const std::string first_address = "34.0.0.1";
const std::string second_address = "10.0.0.1";
HeaderMapImpl request_headers{
{"x-forwarded-for", fmt::format("{0},{1}", first_address, second_address)}};
EXPECT_EQ(second_address, Utility::getLastAddressFromXFF(request_headers));
}

TEST(HttpUtility, EmptyXFF) {
HeaderMapImpl request_headers;
EXPECT_EQ("", Utility::getLastAddressFromXFF(request_headers));
}

TEST(HttpUtility, OneAddressInXFF) {
const std::string first_address = "34.0.0.1";
HeaderMapImpl request_headers{{"x-forwarded-for", first_address}};
EXPECT_EQ(first_address, Utility::getLastAddressFromXFF(request_headers));
}

} // Http
2 changes: 2 additions & 0 deletions test/mocks/http/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class MockStreamDecoderFilterCallbacks : public StreamDecoderFilterCallbacks,
MOCK_METHOD0(routeTable, Router::StableRouteTable&());
MOCK_METHOD0(streamId, uint64_t());
MOCK_METHOD0(requestInfo, Http::AccessLog::RequestInfo&());
MOCK_METHOD0(downstreamAddress, const std::string&());

// Http::StreamDecoderFilterCallbacks
void encodeHeaders(HeaderMapPtr&& headers, bool end_stream) override {
Expand Down Expand Up @@ -240,6 +241,7 @@ class MockStreamEncoderFilterCallbacks : public StreamEncoderFilterCallbacks,
MOCK_METHOD0(routeTable, Router::StableRouteTable&());
MOCK_METHOD0(streamId, uint64_t());
MOCK_METHOD0(requestInfo, Http::AccessLog::RequestInfo&());
MOCK_METHOD0(downstreamAddress, const std::string&());

// Http::StreamEncoderFilterCallbacks
MOCK_METHOD0(continueEncoding, void());
Expand Down

0 comments on commit 89de442

Please sign in to comment.