Fix race condition when the redis filter is destroyed. #11466

Merged: 14 commits, Jun 26, 2020
6 changes: 6 additions & 0 deletions source/extensions/clusters/redis/redis_cluster.cc
@@ -149,6 +149,9 @@ RedisCluster::DnsDiscoveryResolveTarget::~DnsDiscoveryResolveTarget() {
if (active_query_) {
active_query_->cancel();
}
if (resolve_timer_) {
resolve_timer_->disableTimer();
}
}

void RedisCluster::DnsDiscoveryResolveTarget::startResolveDns() {
@@ -226,6 +229,9 @@ RedisCluster::RedisDiscoverySession::~RedisDiscoverySession() {
current_request_->cancel();
current_request_ = nullptr;
}
if (resolve_timer_) {
resolve_timer_->disableTimer();
}
Comment on lines +236 to +238
Member: Is this needed? If so comment?
Contributor Author: I'll add a comment. Sorry I missed this one.


while (!client_map_.empty()) {
client_map_.begin()->second->client_->close();
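For context, a minimal standalone sketch of the hazard the explicit disableTimer() calls above guard against (illustrative types only, not Envoy's Event::Timer API): if a timer whose callback captures `this` could still fire after its owner is destroyed, the callback would dereference a dangling pointer, so the destructor cancels the pending work first.

#include <functional>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for an event-loop timer handle.
class FakeTimer {
public:
  void enable(std::function<void()> cb) { cb_ = std::move(cb); }
  void disableTimer() { cb_.reset(); }          // Cancel: the callback will never fire.
  void fireIfEnabled() { if (cb_) (*cb_)(); }   // Driven by the "event loop".
private:
  std::optional<std::function<void()>> cb_;
};

class ResolveTarget {
public:
  explicit ResolveTarget(FakeTimer& timer) : timer_(timer) {
    timer_.enable([this]() { startResolve(); }); // The callback captures `this`.
  }
  ~ResolveTarget() {
    // Without this, a later timer fire would touch a destroyed object.
    timer_.disableTimer();
  }
  void startResolve() { std::cout << "resolving " << name_ << "\n"; }
private:
  FakeTimer& timer_;
  std::string name_{"redis-cluster"};
};

int main() {
  FakeTimer timer;
  {
    ResolveTarget target(timer); // Schedules work, then goes out of scope.
  }
  timer.fireIfEnabled(); // Safe: the destructor already cancelled the callback.
}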
12 changes: 6 additions & 6 deletions source/extensions/filters/network/redis_proxy/config.cc
@@ -75,12 +75,12 @@ Network::FilterFactoryCb RedisProxyFilterConfigFactory::createFilterFactoryFromP
for (auto& cluster : unique_clusters) {
Stats::ScopePtr stats_scope =
context.scope().createScope(fmt::format("cluster.{}.redis_cluster", cluster));

upstreams.emplace(cluster, std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(),
Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(),
std::move(stats_scope), redis_command_stats, refresh_manager));
auto conn_pool_ptr = std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(), Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(), std::move(stats_scope),
redis_command_stats, refresh_manager);
conn_pool_ptr->init();
upstreams.emplace(cluster, conn_pool_ptr);
}

auto router =
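The reworked block above is why InstanceImpl now exposes a separate init() step. A sketch of the underlying constraint, using only standard C++ with illustrative names rather than the Envoy types: shared_from_this() is only valid once a std::shared_ptr already owns the object, so the weak pointer used by the TLS callback must be captured after std::make_shared, not in the constructor.

#include <iostream>
#include <memory>

struct Pool : std::enable_shared_from_this<Pool> {
  Pool() {
    try {
      shared_from_this(); // No owning shared_ptr yet; throws std::bad_weak_ptr (since C++17).
    } catch (const std::bad_weak_ptr&) {
      std::cout << "constructor: cannot capture weak_ptr here\n";
    }
  }
  void init() {
    weak_self_ = shared_from_this(); // Safe: std::make_shared has already run.
    std::cout << "init: weak_ptr captured, expired=" << weak_self_.expired() << "\n";
  }
  std::weak_ptr<Pool> weak_self_;
};

int main() {
  auto pool = std::make_shared<Pool>(); // Mirrors std::make_shared<ConnPool::InstanceImpl>(...).
  pool->init();                         // Mirrors conn_pool_ptr->init() above.
}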
86 changes: 63 additions & 23 deletions source/extensions/filters/network/redis_proxy/conn_pool_impl.cc
@@ -47,10 +47,18 @@ InstanceImpl::InstanceImpl(
tls_(tls.allocateSlot()), config_(config), api_(api), stats_scope_(std::move(stats_scope)),
redis_command_stats_(redis_command_stats), redis_cluster_stats_{REDIS_CLUSTER_STATS(
POOL_COUNTER(*stats_scope_))},
refresh_manager_(std::move(refresh_manager)) {
tls_->set([this, cluster_name](
refresh_manager_(std::move(refresh_manager)) {}

void InstanceImpl::init() {
// Note: `this` and `cluster_name` have the lifetime of the filter.
// That may be shorter than the lifetime of the tls callback if the listener is torn down shortly after it is created.
// We use a weak pointer so the tls callbacks can safely detect when this object has been destroyed.
auto this_shared_ptr = this->shared_from_this();
std::weak_ptr<InstanceImpl> this_weak_ptr = this_shared_ptr;
auto cluster_name = this_shared_ptr->cluster_name_;
tls_->set([this_weak_ptr, cluster_name](
Event::Dispatcher& dispatcher) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<ThreadLocalPool>(*this, dispatcher, cluster_name);
return std::make_shared<ThreadLocalPool>(this_weak_ptr, dispatcher, cluster_name);
});
}

@@ -66,23 +74,30 @@ InstanceImpl::makeRequestToHost(const std::string& host_address,
return tls_->getTyped<ThreadLocalPool>().makeRequestToHost(host_address, request, callbacks);
}

InstanceImpl::ThreadLocalPool::ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher,
InstanceImpl::ThreadLocalPool::ThreadLocalPool(std::weak_ptr<InstanceImpl> parent,
Event::Dispatcher& dispatcher,
std::string cluster_name)
: parent_(parent), dispatcher_(dispatcher), cluster_name_(std::move(cluster_name)),
drain_timer_(dispatcher.createTimer([this]() -> void { drainClients(); })),
is_redis_cluster_(false) {
cluster_update_handle_ = parent_.cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = parent_.cm_.get(cluster_name_);
if (cluster != nullptr) {
auth_password_ = ProtocolOptionsConfigImpl::authPassword(cluster->info(), parent_.api_);
onClusterAddOrUpdateNonVirtual(*cluster);
if (auto shared_parent = parent_.lock()) {
cluster_update_handle_ = shared_parent->cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = shared_parent->cm_.get(cluster_name_);
if (cluster != nullptr) {
auth_password_ =
ProtocolOptionsConfigImpl::authPassword(cluster->info(), shared_parent->api_);
onClusterAddOrUpdateNonVirtual(*cluster);
}
}
}

InstanceImpl::ThreadLocalPool::~ThreadLocalPool() {
if (host_set_member_update_cb_handle_ != nullptr) {
host_set_member_update_cb_handle_->remove();
}
if (drain_timer_) {
drain_timer_->disableTimer();
}
while (!pending_requests_.empty()) {
pending_requests_.pop_front();
}
@@ -99,6 +114,11 @@ void InstanceImpl::ThreadLocalPool::onClusterAddOrUpdateNonVirtual(
if (cluster.info()->name() != cluster_name_) {
return;
}
// Ensure the filter still exists.
Member: nit: start with capital, end with period, etc. Same elsewhere.

auto shared_parent = parent_.lock();
if (!shared_parent) {
return;
}

if (cluster_ != nullptr) {
// Treat an update as a removal followed by an add.
@@ -212,12 +232,14 @@ InstanceImpl::ThreadLocalActiveClientPtr&
InstanceImpl::ThreadLocalPool::threadLocalActiveClient(Upstream::HostConstSharedPtr host) {
ThreadLocalActiveClientPtr& client = client_map_[host];
if (!client) {
client = std::make_unique<ThreadLocalActiveClient>(*this);
client->host_ = host;
client->redis_client_ = parent_.client_factory_.create(host, dispatcher_, parent_.config_,
parent_.redis_command_stats_,
*parent_.stats_scope_, auth_password_);
client->redis_client_->addConnectionCallbacks(*client);
if (auto shared_parent = parent_.lock()) {
client = std::make_unique<ThreadLocalActiveClient>(*this);
client->host_ = host;
client->redis_client_ = shared_parent->client_factory_.create(
host, dispatcher_, shared_parent->config_, shared_parent->redis_command_stats_,
*(shared_parent->stats_scope_), auth_password_);
client->redis_client_->addConnectionCallbacks(*client);
}
}
return client;
}
@@ -230,10 +252,15 @@ InstanceImpl::ThreadLocalPool::makeRequest(const std::string& key, RespVariant&&
ASSERT(host_set_member_update_cb_handle_ == nullptr);
return nullptr;
}
// Ensure the filter has not been removed.
auto shared_parent = parent_.lock();
if (!shared_parent) {
return nullptr;
}

Clusters::Redis::RedisLoadBalancerContextImpl lb_context(key, parent_.config_.enableHashtagging(),
is_redis_cluster_, getRequest(request),
parent_.config_.readPolicy());
Clusters::Redis::RedisLoadBalancerContextImpl lb_context(
key, shared_parent->config_.enableHashtagging(), is_redis_cluster_, getRequest(request),
shared_parent->config_.readPolicy());
Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context);
if (!host) {
return nullptr;
@@ -265,6 +292,12 @@ Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestTo
return nullptr;
}

// Ensure the filter has not been removed.
Member: I'm not convinced you need all of these locks everywhere. How can we be making a request with a removed filter? I'm pretty sure you only need it in the construction path for the original set call?
Contributor Author: The problem is that when the connection pool instance is destroyed, the TLS slot is not destroyed immediately. The code that cancels pending requests and closes connections is therefore delayed as well. During this delay, a redirection from a pending request would cause a new request to be sent to the new host.
Member: OK. Please add more comments around this and also fix format.

auto shared_parent = parent_.lock();
if (!shared_parent) {
return nullptr;
}

const std::string ip_address = host_address.substr(0, colon_pos);
const bool ipv6 = (ip_address.find(':') != std::string::npos);
std::string host_address_map_key;
@@ -289,9 +322,10 @@ Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestTo
auto it = host_address_map_.find(host_address_map_key);
if (it == host_address_map_.end()) {
// This host is not known to the cluster manager. Create a new host and insert it into the map.
if (created_via_redirect_hosts_.size() == parent_.config_.maxUpstreamUnknownConnections()) {
if (created_via_redirect_hosts_.size() ==
shared_parent->config_.maxUpstreamUnknownConnections()) {
// Too many upstream connections to unknown hosts have been created.
parent_.redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
shared_parent->redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
return nullptr;
}
if (!ipv6) {
@@ -343,7 +377,9 @@ void InstanceImpl::ThreadLocalActiveClient::onEvent(Network::ConnectionEvent eve
it++) {
if ((*it).get() == this) {
if (!redis_client_->active()) {
parent_.parent_.redis_cluster_stats_.upstream_cx_drained_.inc();
if (auto shared_conn_pool = parent_.parent_.lock()) {
shared_conn_pool->redis_cluster_stats_.upstream_cx_drained_.inc();
}
}
parent_.dispatcher_.deferredDelete(std::move(redis_client_));
parent_.clients_to_drain_.erase(it);
@@ -379,7 +415,9 @@ void InstanceImpl::PendingRequest::onResponse(Common::Redis::RespValuePtr&& resp
void InstanceImpl::PendingRequest::onFailure() {
request_handler_ = nullptr;
pool_callbacks_.onFailure();
parent_.parent_.onFailure();
if (auto shared_conn_pool = parent_.parent_.lock()) {
Member: Same comment here, if the filter goes away we should be cancelling things. I don't think these are all needed.

shared_conn_pool->onFailure();
}
parent_.onRequestCompleted();
}

@@ -402,7 +440,9 @@ bool InstanceImpl::PendingRequest::onRedirection(Common::Redis::RespValuePtr&& v
onResponse(std::move(value));
return false;
} else {
parent_.parent_.onRedirection();
if (auto shared_conn_pool = parent_.parent_.lock()) {
shared_conn_pool->onRedirection();
}
return true;
}
}
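Tying the lock() calls above back to the author's explanation of the race: because the TLS slot can briefly outlive the connection pool, each thread-local code path re-checks the weak pointer before touching the pool. A minimal sketch of that guard, with hypothetical stand-in types rather than the real Envoy classes:

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct ConnPool {
  void onRedirection() { std::cout << "refresh triggered for " << cluster_ << "\n"; }
  std::string cluster_{"redis_cluster"};
};

struct ThreadLocalPool {
  explicit ThreadLocalPool(std::weak_ptr<ConnPool> parent) : parent_(std::move(parent)) {}

  void handleRedirect() {
    // Ensure the filter (and its connection pool) still exists before touching it.
    auto shared_parent = parent_.lock();
    if (!shared_parent) {
      return; // The pool was already destroyed; drop the work instead of crashing.
    }
    shared_parent->onRedirection();
  }

  std::weak_ptr<ConnPool> parent_;
};

int main() {
  auto pool = std::make_shared<ConnPool>();
  ThreadLocalPool tls(pool); // The TLS object may outlive the pool for a short window.
  tls.handleRedirect();      // Pool alive: the refresh runs.
  pool.reset();              // Filter torn down; the TLS slot has not been flushed yet.
  tls.handleRedirect();      // Pool gone: lock() fails and the call is a no-op.
}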
source/extensions/filters/network/redis_proxy/conn_pool_impl.h
@@ -52,7 +52,7 @@ class DoNothingPoolCallbacks : public PoolCallbacks {
void onFailure() override{};
};

class InstanceImpl : public Instance {
class InstanceImpl : public Instance, public std::enable_shared_from_this<InstanceImpl> {
public:
InstanceImpl(
const std::string& cluster_name, Upstream::ClusterManager& cm,
@@ -83,6 +83,8 @@ class InstanceImpl : public Instance {
bool onFailure() { return refresh_manager_->onFailure(cluster_name_); }
bool onHostDegraded() { return refresh_manager_->onHostDegraded(cluster_name_); }

void init();

// Allow the unit test to have access to private members.
friend class RedisConnPoolImplTest;

@@ -127,7 +129,8 @@

struct ThreadLocalPool : public ThreadLocal::ThreadLocalObject,
public Upstream::ClusterUpdateCallbacks {
ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher, std::string cluster_name);
ThreadLocalPool(std::weak_ptr<InstanceImpl> parent, Event::Dispatcher& dispatcher,
std::string cluster_name);
~ThreadLocalPool() override;
ThreadLocalActiveClientPtr& threadLocalActiveClient(Upstream::HostConstSharedPtr host);
Common::Redis::Client::PoolRequest* makeRequest(const std::string& key, RespVariant&& request,
@@ -149,7 +152,7 @@

void onRequestCompleted();

InstanceImpl& parent_;
std::weak_ptr<InstanceImpl> parent_;
Event::Dispatcher& dispatcher_;
const std::string cluster_name_;
Upstream::ClusterUpdateCallbacksHandlePtr cluster_update_handle_;
test/extensions/filters/network/redis_proxy/conn_pool_impl_test.cc
@@ -78,11 +78,12 @@ class RedisConnPoolImplTest : public testing::Test, public Common::Redis::Client
std::make_shared<NiceMock<Extensions::Common::Redis::MockClusterRefreshManager>>();
auto redis_command_stats =
Common::Redis::RedisCommandStats::createRedisCommandStats(store->symbolTable());
std::unique_ptr<InstanceImpl> conn_pool_impl = std::make_unique<InstanceImpl>(
std::shared_ptr<InstanceImpl> conn_pool_impl = std::make_shared<InstanceImpl>(
cluster_name_, cm_, *this, tls_,
Common::Redis::Client::createConnPoolSettings(20, hashtagging, true, max_unknown_conns,
read_policy_),
api_, std::move(store), redis_command_stats, cluster_refresh_manager_);
conn_pool_impl->init();
// Set the authentication password for this connection pool.
conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().auth_password_ = auth_password_;
conn_pool_ = std::move(conn_pool_impl);