Skip to content

Commit

Permalink
Fix race condition when the redis filter is destroyed. (#11466)
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Yang <[email protected]>
  • Loading branch information
HenryYYang authored Jun 26, 2020
1 parent fc8c79a commit 2705235
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 46 deletions.
8 changes: 8 additions & 0 deletions source/extensions/clusters/redis/redis_cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ RedisCluster::DnsDiscoveryResolveTarget::~DnsDiscoveryResolveTarget() {
if (active_query_) {
active_query_->cancel();
}
// Disable timer for mock tests.
if (resolve_timer_) {
resolve_timer_->disableTimer();
}
}

void RedisCluster::DnsDiscoveryResolveTarget::startResolveDns() {
Expand Down Expand Up @@ -228,6 +232,10 @@ RedisCluster::RedisDiscoverySession::~RedisDiscoverySession() {
current_request_->cancel();
current_request_ = nullptr;
}
// Disable timer for mock tests.
if (resolve_timer_) {
resolve_timer_->disableTimer();
}

while (!client_map_.empty()) {
client_map_.begin()->second->client_->close();
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/filters/network/common/redis/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class Config {
virtual ReadPolicy readPolicy() const PURE;
};

using ConfigSharedPtr = std::shared_ptr<Config>;

/**
* A factory for individual redis client connections.
*/
Expand Down
12 changes: 6 additions & 6 deletions source/extensions/filters/network/redis_proxy/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ Network::FilterFactoryCb RedisProxyFilterConfigFactory::createFilterFactoryFromP
for (auto& cluster : unique_clusters) {
Stats::ScopePtr stats_scope =
context.scope().createScope(fmt::format("cluster.{}.redis_cluster", cluster));

upstreams.emplace(cluster, std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(),
Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(),
std::move(stats_scope), redis_command_stats, refresh_manager));
auto conn_pool_ptr = std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(), Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(), std::move(stats_scope),
redis_command_stats, refresh_manager);
conn_pool_ptr->init();
upstreams.emplace(cluster, conn_pool_ptr);
}

auto router =
Expand Down
8 changes: 0 additions & 8 deletions source/extensions/filters/network/redis_proxy/conn_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ class Instance {
*/
virtual Common::Redis::Client::PoolRequest*
makeRequest(const std::string& hash_key, RespVariant&& request, PoolCallbacks& callbacks) PURE;

/**
* Notify the redirection manager singleton that a redirection error has been received from an
* upstream server associated with the pool's associated cluster.
* @return bool true if a cluster's registered callback with the redirection manager is scheduled
* to be called from the main thread dispatcher, false otherwise.
*/
virtual bool onRedirection() PURE;
};

using InstanceSharedPtr = std::shared_ptr<Instance>;
Expand Down
68 changes: 46 additions & 22 deletions source/extensions/filters/network/redis_proxy/conn_pool_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,57 @@ InstanceImpl::InstanceImpl(
const Common::Redis::RedisCommandStatsSharedPtr& redis_command_stats,
Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager)
: cluster_name_(cluster_name), cm_(cm), client_factory_(client_factory),
tls_(tls.allocateSlot()), config_(config), api_(api), stats_scope_(std::move(stats_scope)),
tls_(tls.allocateSlot()), config_(new Common::Redis::Client::ConfigImpl(config)), api_(api),
stats_scope_(std::move(stats_scope)),
redis_command_stats_(redis_command_stats), redis_cluster_stats_{REDIS_CLUSTER_STATS(
POOL_COUNTER(*stats_scope_))},
refresh_manager_(std::move(refresh_manager)) {
tls_->set([this, cluster_name](
Event::Dispatcher& dispatcher) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<ThreadLocalPool>(*this, dispatcher, cluster_name);
});
refresh_manager_(std::move(refresh_manager)) {}

void InstanceImpl::init() {
// Note: `this` and `cluster_name` have a a lifetime of the filter.
// That may be shorter than the tls callback if the listener is torn down shortly after it is
// created. We use a weak pointer to make sure this object outlives the tls callbacks.
std::weak_ptr<InstanceImpl> this_weak_ptr = this->shared_from_this();
tls_->set(
[this_weak_ptr](Event::Dispatcher& dispatcher) -> ThreadLocal::ThreadLocalObjectSharedPtr {
if (auto this_shared_ptr = this_weak_ptr.lock()) {
return std::make_shared<ThreadLocalPool>(this_shared_ptr, dispatcher,
this_shared_ptr->cluster_name_);
}
return nullptr;
});
}

// This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped
// failing due to InstanceImpl going away.
Common::Redis::Client::PoolRequest*
InstanceImpl::makeRequest(const std::string& key, RespVariant&& request, PoolCallbacks& callbacks) {
return tls_->getTyped<ThreadLocalPool>().makeRequest(key, std::move(request), callbacks);
}

// This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped
// failing due to InstanceImpl going away.
Common::Redis::Client::PoolRequest*
InstanceImpl::makeRequestToHost(const std::string& host_address,
const Common::Redis::RespValue& request,
Common::Redis::Client::ClientCallbacks& callbacks) {
return tls_->getTyped<ThreadLocalPool>().makeRequestToHost(host_address, request, callbacks);
}

InstanceImpl::ThreadLocalPool::ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher,
InstanceImpl::ThreadLocalPool::ThreadLocalPool(std::shared_ptr<InstanceImpl> parent,
Event::Dispatcher& dispatcher,
std::string cluster_name)
: parent_(parent), dispatcher_(dispatcher), cluster_name_(std::move(cluster_name)),
drain_timer_(dispatcher.createTimer([this]() -> void { drainClients(); })),
is_redis_cluster_(false) {
cluster_update_handle_ = parent_.cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = parent_.cm_.get(cluster_name_);
is_redis_cluster_(false), client_factory_(parent->client_factory_), config_(parent->config_),
stats_scope_(parent->stats_scope_), redis_command_stats_(parent->redis_command_stats_),
redis_cluster_stats_(parent->redis_cluster_stats_),
refresh_manager_(parent->refresh_manager_) {
cluster_update_handle_ = parent->cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = parent->cm_.get(cluster_name_);
if (cluster != nullptr) {
auth_username_ = ProtocolOptionsConfigImpl::authUsername(cluster->info(), parent_.api_);
auth_password_ = ProtocolOptionsConfigImpl::authPassword(cluster->info(), parent_.api_);
auth_username_ = ProtocolOptionsConfigImpl::authUsername(cluster->info(), parent->api_);
auth_password_ = ProtocolOptionsConfigImpl::authPassword(cluster->info(), parent->api_);
onClusterAddOrUpdateNonVirtual(*cluster);
}
}
Expand All @@ -100,6 +119,11 @@ void InstanceImpl::ThreadLocalPool::onClusterAddOrUpdateNonVirtual(
if (cluster.info()->name() != cluster_name_) {
return;
}
// Ensure the filter is not deleted in the main thread during this method.
auto shared_parent = parent_.lock();
if (!shared_parent) {
return;
}

if (cluster_ != nullptr) {
// Treat an update as a removal followed by an add.
Expand Down Expand Up @@ -215,9 +239,9 @@ InstanceImpl::ThreadLocalPool::threadLocalActiveClient(Upstream::HostConstShared
if (!client) {
client = std::make_unique<ThreadLocalActiveClient>(*this);
client->host_ = host;
client->redis_client_ = parent_.client_factory_.create(
host, dispatcher_, parent_.config_, parent_.redis_command_stats_, *parent_.stats_scope_,
auth_username_, auth_password_);
client->redis_client_ =
client_factory_.create(host, dispatcher_, *config_, redis_command_stats_, *(stats_scope_),
auth_username_, auth_password_);
client->redis_client_->addConnectionCallbacks(*client);
}
return client;
Expand All @@ -232,9 +256,9 @@ InstanceImpl::ThreadLocalPool::makeRequest(const std::string& key, RespVariant&&
return nullptr;
}

Clusters::Redis::RedisLoadBalancerContextImpl lb_context(key, parent_.config_.enableHashtagging(),
Clusters::Redis::RedisLoadBalancerContextImpl lb_context(key, config_->enableHashtagging(),
is_redis_cluster_, getRequest(request),
parent_.config_.readPolicy());
config_->readPolicy());
Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context);
if (!host) {
return nullptr;
Expand Down Expand Up @@ -290,9 +314,9 @@ Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestTo
auto it = host_address_map_.find(host_address_map_key);
if (it == host_address_map_.end()) {
// This host is not known to the cluster manager. Create a new host and insert it into the map.
if (created_via_redirect_hosts_.size() == parent_.config_.maxUpstreamUnknownConnections()) {
if (created_via_redirect_hosts_.size() == config_->maxUpstreamUnknownConnections()) {
// Too many upstream connections to unknown hosts have been created.
parent_.redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
return nullptr;
}
if (!ipv6) {
Expand Down Expand Up @@ -344,7 +368,7 @@ void InstanceImpl::ThreadLocalActiveClient::onEvent(Network::ConnectionEvent eve
it++) {
if ((*it).get() == this) {
if (!redis_client_->active()) {
parent_.parent_.redis_cluster_stats_.upstream_cx_drained_.inc();
parent_.redis_cluster_stats_.upstream_cx_drained_.inc();
}
parent_.dispatcher_.deferredDelete(std::move(redis_client_));
parent_.clients_to_drain_.erase(it);
Expand Down Expand Up @@ -380,7 +404,7 @@ void InstanceImpl::PendingRequest::onResponse(Common::Redis::RespValuePtr&& resp
void InstanceImpl::PendingRequest::onFailure() {
request_handler_ = nullptr;
pool_callbacks_.onFailure();
parent_.parent_.onFailure();
parent_.refresh_manager_->onFailure(parent_.cluster_name_);
parent_.onRequestCompleted();
}

Expand All @@ -403,7 +427,7 @@ bool InstanceImpl::PendingRequest::onRedirection(Common::Redis::RespValuePtr&& v
onResponse(std::move(value));
return false;
} else {
parent_.parent_.onRedirection();
parent_.refresh_manager_->onRedirection(parent_.cluster_name_);
return true;
}
}
Expand Down
22 changes: 14 additions & 8 deletions source/extensions/filters/network/redis_proxy/conn_pool_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "source/extensions/clusters/redis/redis_cluster_lb.h"

#include "extensions/common/redis/cluster_refresh_manager.h"
#include "extensions/filters/network/common/redis/client.h"
#include "extensions/filters/network/common/redis/client_impl.h"
#include "extensions/filters/network/common/redis/codec_impl.h"
#include "extensions/filters/network/common/redis/utility.h"
Expand Down Expand Up @@ -52,7 +53,7 @@ class DoNothingPoolCallbacks : public PoolCallbacks {
void onFailure() override{};
};

class InstanceImpl : public Instance {
class InstanceImpl : public Instance, public std::enable_shared_from_this<InstanceImpl> {
public:
InstanceImpl(
const std::string& cluster_name, Upstream::ClusterManager& cm,
Expand All @@ -79,9 +80,7 @@ class InstanceImpl : public Instance {
makeRequestToHost(const std::string& host_address, const Common::Redis::RespValue& request,
Common::Redis::Client::ClientCallbacks& callbacks);

bool onRedirection() override { return refresh_manager_->onRedirection(cluster_name_); }
bool onFailure() { return refresh_manager_->onFailure(cluster_name_); }
bool onHostDegraded() { return refresh_manager_->onHostDegraded(cluster_name_); }
void init();

// Allow the unit test to have access to private members.
friend class RedisConnPoolImplTest;
Expand Down Expand Up @@ -127,7 +126,8 @@ class InstanceImpl : public Instance {

struct ThreadLocalPool : public ThreadLocal::ThreadLocalObject,
public Upstream::ClusterUpdateCallbacks {
ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher, std::string cluster_name);
ThreadLocalPool(std::shared_ptr<InstanceImpl> parent, Event::Dispatcher& dispatcher,
std::string cluster_name);
~ThreadLocalPool() override;
ThreadLocalActiveClientPtr& threadLocalActiveClient(Upstream::HostConstSharedPtr host);
Common::Redis::Client::PoolRequest* makeRequest(const std::string& key, RespVariant&& request,
Expand All @@ -149,7 +149,7 @@ class InstanceImpl : public Instance {

void onRequestCompleted();

InstanceImpl& parent_;
std::weak_ptr<InstanceImpl> parent_;
Event::Dispatcher& dispatcher_;
const std::string cluster_name_;
Upstream::ClusterUpdateCallbacksHandlePtr cluster_update_handle_;
Expand All @@ -171,15 +171,21 @@ class InstanceImpl : public Instance {
*/
Event::TimerPtr drain_timer_;
bool is_redis_cluster_;
Common::Redis::Client::ClientFactory& client_factory_;
Common::Redis::Client::ConfigSharedPtr config_;
Stats::ScopeSharedPtr stats_scope_;
Common::Redis::RedisCommandStatsSharedPtr redis_command_stats_;
RedisClusterStats redis_cluster_stats_;
const Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager_;
};

const std::string cluster_name_;
Upstream::ClusterManager& cm_;
Common::Redis::Client::ClientFactory& client_factory_;
ThreadLocal::SlotPtr tls_;
Common::Redis::Client::ConfigImpl config_;
Common::Redis::Client::ConfigSharedPtr config_;
Api::Api& api_;
Stats::ScopePtr stats_scope_;
Stats::ScopeSharedPtr stats_scope_;
Common::Redis::RedisCommandStatsSharedPtr redis_command_stats_;
RedisClusterStats redis_cluster_stats_;
const Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager_;
Expand Down
1 change: 1 addition & 0 deletions test/extensions/clusters/redis/redis_cluster_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class RedisClusterTest : public testing::Test,

void expectRedisSessionCreated() {
resolve_timer_ = new Event::MockTimer(&dispatcher_);
EXPECT_CALL(*resolve_timer_, disableTimer());
ON_CALL(random_, random()).WillByDefault(Return(0));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ class RedisConnPoolImplTest : public testing::Test, public Common::Redis::Client
std::make_shared<NiceMock<Extensions::Common::Redis::MockClusterRefreshManager>>();
auto redis_command_stats =
Common::Redis::RedisCommandStats::createRedisCommandStats(store->symbolTable());
std::unique_ptr<InstanceImpl> conn_pool_impl = std::make_unique<InstanceImpl>(
std::shared_ptr<InstanceImpl> conn_pool_impl = std::make_shared<InstanceImpl>(
cluster_name_, cm_, *this, tls_,
Common::Redis::Client::createConnPoolSettings(20, hashtagging, true, max_unknown_conns,
read_policy_),
api_, std::move(store), redis_command_stats, cluster_refresh_manager_);
conn_pool_impl->init();
// Set the authentication password for this connection pool.
conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().auth_username_ = auth_username_;
conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().auth_password_ = auth_password_;
Expand Down Expand Up @@ -176,6 +177,11 @@ class RedisConnPoolImplTest : public testing::Test, public Common::Redis::Client
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().clients_to_drain_;
}

InstanceImpl::ThreadLocalPool& threadLocalPool() {
InstanceImpl* conn_pool_impl = dynamic_cast<InstanceImpl*>(conn_pool_.get());
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>();
}

Event::TimerPtr& drainTimer() {
InstanceImpl* conn_pool_impl = dynamic_cast<InstanceImpl*>(conn_pool_.get());
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().drain_timer_;
Expand Down Expand Up @@ -1156,6 +1162,61 @@ TEST_F(RedisConnPoolImplTest, AskRedirectionFailure) {
tls_.shutdownThread();
}

TEST_F(RedisConnPoolImplTest, MakeRequestAndRedirectFollowedByDelete) {
tls_.defer_delete = true;
std::unique_ptr<NiceMock<Stats::MockStore>> store =
std::make_unique<NiceMock<Stats::MockStore>>();
cluster_refresh_manager_ =
std::make_shared<NiceMock<Extensions::Common::Redis::MockClusterRefreshManager>>();
auto redis_command_stats =
Common::Redis::RedisCommandStats::createRedisCommandStats(store->symbolTable());
conn_pool_ = std::make_shared<InstanceImpl>(
cluster_name_, cm_, *this, tls_,
Common::Redis::Client::createConnPoolSettings(20, true, true, 100, read_policy_), api_,
std::move(store), redis_command_stats, cluster_refresh_manager_);
conn_pool_->init();

auto& local_pool = threadLocalPool();
conn_pool_.reset();

// Request
Common::Redis::Client::MockClient* client = new NiceMock<Common::Redis::Client::MockClient>();
Common::Redis::RespValueSharedPtr value = std::make_shared<Common::Redis::RespValue>();
Common::Redis::Client::MockPoolRequest active_request;
MockPoolCallbacks callbacks;
EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_))
.WillOnce(Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr {
EXPECT_EQ(context->computeHashKey().value(), MurmurHash::murmurHash2_64("hash_key"));
EXPECT_EQ(context->metadataMatchCriteria(), nullptr);
EXPECT_EQ(context->downstreamConnection(), nullptr);
return this->cm_.thread_local_cluster_.lb_.host_;
}));
EXPECT_CALL(*this, create_(_)).WillOnce(Return(client));
EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address())
.WillRepeatedly(Return(this->test_address_));
EXPECT_CALL(*client, makeRequest_(Ref(*value), _)).WillOnce(Return(&active_request));
EXPECT_NE(nullptr, local_pool.makeRequest("hash_key", value, callbacks));

// Move redirection.
Common::Redis::Client::MockPoolRequest active_request2;
Common::Redis::Client::MockClient* client2 = new NiceMock<Common::Redis::Client::MockClient>();
Upstream::HostConstSharedPtr host1;
Common::Redis::RespValuePtr moved_response{new Common::Redis::RespValue()};
moved_response->type(Common::Redis::RespType::Error);
moved_response->asString() = "MOVED 1111 10.1.2.3:4000";

EXPECT_CALL(*this, create_(_)).WillOnce(DoAll(SaveArg<0>(&host1), Return(client2)));
EXPECT_CALL(*client2, makeRequest_(Ref(*value), _)).WillOnce(Return(&active_request2));
EXPECT_TRUE(client->client_callbacks_.back()->onRedirection(std::move(moved_response),
"10.1.2.3:4000", false));
EXPECT_EQ(host1->address()->asString(), "10.1.2.3:4000");
EXPECT_CALL(callbacks, onResponse_(_));
client2->client_callbacks_.back()->onResponse(std::make_unique<Common::Redis::RespValue>());

EXPECT_CALL(*client, close());
tls_.shutdownThread();
}

} // namespace ConnPool
} // namespace RedisProxy
} // namespace NetworkFilters
Expand Down
4 changes: 3 additions & 1 deletion test/mocks/thread_local/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class MockInstance : public Instance {

~SlotImpl() override {
// Do not actually clear slot data during shutdown. This mimics the production code.
if (!parent_.shutdown_) {
// The defer_delete mimics the recycle() code with Bookkeeper.
if (!parent_.shutdown_ && !parent_.defer_delete) {
EXPECT_LT(index_, parent_.data_.size());
parent_.data_[index_].reset();
}
Expand Down Expand Up @@ -98,6 +99,7 @@ class MockInstance : public Instance {
bool defer_data{};
bool shutdown_{};
bool registered_{true};
bool defer_delete{};
};

} // namespace ThreadLocal
Expand Down

0 comments on commit 2705235

Please sign in to comment.