Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition when the redis filter is destroyed. #11466

Merged
merged 14 commits into from
Jun 26, 2020
7 changes: 7 additions & 0 deletions source/extensions/clusters/redis/redis_cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ RedisCluster::DnsDiscoveryResolveTarget::~DnsDiscoveryResolveTarget() {
if (active_query_) {
active_query_->cancel();
}
// Disable timer for mock tests.
if (resolve_timer_) {
resolve_timer_->disableTimer();
}
mattklein123 marked this conversation as resolved.
Show resolved Hide resolved
}

void RedisCluster::DnsDiscoveryResolveTarget::startResolveDns() {
Expand Down Expand Up @@ -228,6 +232,9 @@ RedisCluster::RedisDiscoverySession::~RedisDiscoverySession() {
current_request_->cancel();
current_request_ = nullptr;
}
if (resolve_timer_) {
resolve_timer_->disableTimer();
}
Comment on lines +236 to +238
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? If so comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment. Sorry I missed this one.


while (!client_map_.empty()) {
client_map_.begin()->second->client_->close();
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/filters/network/common/redis/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class Config {
virtual ReadPolicy readPolicy() const PURE;
};

using ConfigSharedPtr = std::shared_ptr<Config>;

/**
* A factory for individual redis client connections.
*/
Expand Down
12 changes: 6 additions & 6 deletions source/extensions/filters/network/redis_proxy/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ Network::FilterFactoryCb RedisProxyFilterConfigFactory::createFilterFactoryFromP
for (auto& cluster : unique_clusters) {
Stats::ScopePtr stats_scope =
context.scope().createScope(fmt::format("cluster.{}.redis_cluster", cluster));

upstreams.emplace(cluster, std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(),
Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(),
std::move(stats_scope), redis_command_stats, refresh_manager));
auto conn_pool_ptr = std::make_shared<ConnPool::InstanceImpl>(
cluster, context.clusterManager(), Common::Redis::Client::ClientFactoryImpl::instance_,
context.threadLocal(), proto_config.settings(), context.api(), std::move(stats_scope),
redis_command_stats, refresh_manager);
conn_pool_ptr->init();
upstreams.emplace(cluster, conn_pool_ptr);
}

auto router =
Expand Down
8 changes: 0 additions & 8 deletions source/extensions/filters/network/redis_proxy/conn_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ class Instance {
*/
virtual Common::Redis::Client::PoolRequest*
makeRequest(const std::string& hash_key, RespVariant&& request, PoolCallbacks& callbacks) PURE;

/**
* Notify the redirection manager singleton that a redirection error has been received from an
* upstream server associated with the pool's associated cluster.
* @return bool true if a cluster's registered callback with the redirection manager is scheduled
* to be called from the main thread dispatcher, false otherwise.
*/
virtual bool onRedirection() PURE;
};

using InstanceSharedPtr = std::shared_ptr<Instance>;
Expand Down
68 changes: 46 additions & 22 deletions source/extensions/filters/network/redis_proxy/conn_pool_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,57 @@ InstanceImpl::InstanceImpl(
const Common::Redis::RedisCommandStatsSharedPtr& redis_command_stats,
Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager)
: cluster_name_(cluster_name), cm_(cm), client_factory_(client_factory),
tls_(tls.allocateSlot()), config_(config), api_(api), stats_scope_(std::move(stats_scope)),
tls_(tls.allocateSlot()), config_(new Common::Redis::Client::ConfigImpl(config)), api_(api),
stats_scope_(std::move(stats_scope)),
redis_command_stats_(redis_command_stats), redis_cluster_stats_{REDIS_CLUSTER_STATS(
POOL_COUNTER(*stats_scope_))},
refresh_manager_(std::move(refresh_manager)) {
tls_->set([this, cluster_name](
Event::Dispatcher& dispatcher) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<ThreadLocalPool>(*this, dispatcher, cluster_name);
});
refresh_manager_(std::move(refresh_manager)) {}

void InstanceImpl::init() {
// Note: `this` and `cluster_name` have a a lifetime of the filter.
// That may be shorter than the tls callback if the listener is torn down shortly after it is
// created. We use a weak pointer to make sure this object outlives the tls callbacks.
std::weak_ptr<InstanceImpl> this_weak_ptr = this->shared_from_this();
tls_->set(
[this_weak_ptr](Event::Dispatcher& dispatcher) -> ThreadLocal::ThreadLocalObjectSharedPtr {
if (auto this_shared_ptr = this_weak_ptr.lock()) {
return std::make_shared<ThreadLocalPool>(this_shared_ptr, dispatcher,
this_shared_ptr->cluster_name_);
}
return nullptr;
});
}

// This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped
// failing due to InstanceImpl going away.
Common::Redis::Client::PoolRequest*
InstanceImpl::makeRequest(const std::string& key, RespVariant&& request, PoolCallbacks& callbacks) {
return tls_->getTyped<ThreadLocalPool>().makeRequest(key, std::move(request), callbacks);
}

// This method is always called from a InstanceSharedPtr we don't have to worry about tls_->getTyped
// failing due to InstanceImpl going away.
Common::Redis::Client::PoolRequest*
InstanceImpl::makeRequestToHost(const std::string& host_address,
const Common::Redis::RespValue& request,
Common::Redis::Client::ClientCallbacks& callbacks) {
return tls_->getTyped<ThreadLocalPool>().makeRequestToHost(host_address, request, callbacks);
}

InstanceImpl::ThreadLocalPool::ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher,
InstanceImpl::ThreadLocalPool::ThreadLocalPool(std::shared_ptr<InstanceImpl> parent,
Event::Dispatcher& dispatcher,
std::string cluster_name)
: parent_(parent), dispatcher_(dispatcher), cluster_name_(std::move(cluster_name)),
drain_timer_(dispatcher.createTimer([this]() -> void { drainClients(); })),
is_redis_cluster_(false) {
cluster_update_handle_ = parent_.cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = parent_.cm_.get(cluster_name_);
is_redis_cluster_(false), client_factory_(parent->client_factory_), config_(parent->config_),
stats_scope_(parent->stats_scope_), redis_command_stats_(parent->redis_command_stats_),
redis_cluster_stats_(parent->redis_cluster_stats_),
refresh_manager_(parent->refresh_manager_) {
cluster_update_handle_ = parent->cm_.addThreadLocalClusterUpdateCallbacks(*this);
Upstream::ThreadLocalCluster* cluster = parent->cm_.get(cluster_name_);
if (cluster != nullptr) {
auth_username_ = ProtocolOptionsConfigImpl::authUsername(cluster->info(), parent_.api_);
auth_password_ = ProtocolOptionsConfigImpl::authPassword(cluster->info(), parent_.api_);
auth_username_ = ProtocolOptionsConfigImpl::authUsername(cluster->info(), parent->api_);
auth_password_ = ProtocolOptionsConfigImpl::authPassword(cluster->info(), parent->api_);
onClusterAddOrUpdateNonVirtual(*cluster);
}
}
Expand All @@ -100,6 +119,11 @@ void InstanceImpl::ThreadLocalPool::onClusterAddOrUpdateNonVirtual(
if (cluster.info()->name() != cluster_name_) {
return;
}
// Ensure the filter is not deleted in the main thread during this method.
auto shared_parent = parent_.lock();
if (!shared_parent) {
return;
}

if (cluster_ != nullptr) {
// Treat an update as a removal followed by an add.
Expand Down Expand Up @@ -215,9 +239,9 @@ InstanceImpl::ThreadLocalPool::threadLocalActiveClient(Upstream::HostConstShared
if (!client) {
client = std::make_unique<ThreadLocalActiveClient>(*this);
client->host_ = host;
client->redis_client_ = parent_.client_factory_.create(
host, dispatcher_, parent_.config_, parent_.redis_command_stats_, *parent_.stats_scope_,
auth_username_, auth_password_);
client->redis_client_ =
client_factory_.create(host, dispatcher_, *config_, redis_command_stats_, *(stats_scope_),
auth_username_, auth_password_);
client->redis_client_->addConnectionCallbacks(*client);
}
return client;
Expand All @@ -232,9 +256,9 @@ InstanceImpl::ThreadLocalPool::makeRequest(const std::string& key, RespVariant&&
return nullptr;
}

Clusters::Redis::RedisLoadBalancerContextImpl lb_context(key, parent_.config_.enableHashtagging(),
Clusters::Redis::RedisLoadBalancerContextImpl lb_context(key, config_->enableHashtagging(),
is_redis_cluster_, getRequest(request),
parent_.config_.readPolicy());
config_->readPolicy());
Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context);
if (!host) {
return nullptr;
Expand Down Expand Up @@ -290,9 +314,9 @@ Common::Redis::Client::PoolRequest* InstanceImpl::ThreadLocalPool::makeRequestTo
auto it = host_address_map_.find(host_address_map_key);
if (it == host_address_map_.end()) {
// This host is not known to the cluster manager. Create a new host and insert it into the map.
if (created_via_redirect_hosts_.size() == parent_.config_.maxUpstreamUnknownConnections()) {
if (created_via_redirect_hosts_.size() == config_->maxUpstreamUnknownConnections()) {
// Too many upstream connections to unknown hosts have been created.
parent_.redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
redis_cluster_stats_.max_upstream_unknown_connections_reached_.inc();
return nullptr;
}
if (!ipv6) {
Expand Down Expand Up @@ -344,7 +368,7 @@ void InstanceImpl::ThreadLocalActiveClient::onEvent(Network::ConnectionEvent eve
it++) {
if ((*it).get() == this) {
if (!redis_client_->active()) {
parent_.parent_.redis_cluster_stats_.upstream_cx_drained_.inc();
parent_.redis_cluster_stats_.upstream_cx_drained_.inc();
}
parent_.dispatcher_.deferredDelete(std::move(redis_client_));
parent_.clients_to_drain_.erase(it);
Expand Down Expand Up @@ -380,7 +404,7 @@ void InstanceImpl::PendingRequest::onResponse(Common::Redis::RespValuePtr&& resp
void InstanceImpl::PendingRequest::onFailure() {
request_handler_ = nullptr;
pool_callbacks_.onFailure();
parent_.parent_.onFailure();
parent_.refresh_manager_->onFailure(parent_.cluster_name_);
parent_.onRequestCompleted();
}

Expand All @@ -403,7 +427,7 @@ bool InstanceImpl::PendingRequest::onRedirection(Common::Redis::RespValuePtr&& v
onResponse(std::move(value));
return false;
} else {
parent_.parent_.onRedirection();
parent_.refresh_manager_->onRedirection(parent_.cluster_name_);
return true;
}
}
Expand Down
22 changes: 14 additions & 8 deletions source/extensions/filters/network/redis_proxy/conn_pool_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "source/extensions/clusters/redis/redis_cluster_lb.h"

#include "extensions/common/redis/cluster_refresh_manager.h"
#include "extensions/filters/network/common/redis/client.h"
#include "extensions/filters/network/common/redis/client_impl.h"
#include "extensions/filters/network/common/redis/codec_impl.h"
#include "extensions/filters/network/common/redis/utility.h"
Expand Down Expand Up @@ -52,7 +53,7 @@ class DoNothingPoolCallbacks : public PoolCallbacks {
void onFailure() override{};
};

class InstanceImpl : public Instance {
class InstanceImpl : public Instance, public std::enable_shared_from_this<InstanceImpl> {
public:
InstanceImpl(
const std::string& cluster_name, Upstream::ClusterManager& cm,
Expand All @@ -79,9 +80,7 @@ class InstanceImpl : public Instance {
makeRequestToHost(const std::string& host_address, const Common::Redis::RespValue& request,
Common::Redis::Client::ClientCallbacks& callbacks);

bool onRedirection() override { return refresh_manager_->onRedirection(cluster_name_); }
bool onFailure() { return refresh_manager_->onFailure(cluster_name_); }
bool onHostDegraded() { return refresh_manager_->onHostDegraded(cluster_name_); }
void init();

// Allow the unit test to have access to private members.
friend class RedisConnPoolImplTest;
Expand Down Expand Up @@ -127,7 +126,8 @@ class InstanceImpl : public Instance {

struct ThreadLocalPool : public ThreadLocal::ThreadLocalObject,
public Upstream::ClusterUpdateCallbacks {
ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher, std::string cluster_name);
ThreadLocalPool(std::shared_ptr<InstanceImpl> parent, Event::Dispatcher& dispatcher,
std::string cluster_name);
~ThreadLocalPool() override;
ThreadLocalActiveClientPtr& threadLocalActiveClient(Upstream::HostConstSharedPtr host);
Common::Redis::Client::PoolRequest* makeRequest(const std::string& key, RespVariant&& request,
Expand All @@ -149,7 +149,7 @@ class InstanceImpl : public Instance {

void onRequestCompleted();

InstanceImpl& parent_;
std::weak_ptr<InstanceImpl> parent_;
Event::Dispatcher& dispatcher_;
const std::string cluster_name_;
Upstream::ClusterUpdateCallbacksHandlePtr cluster_update_handle_;
Expand All @@ -171,15 +171,21 @@ class InstanceImpl : public Instance {
*/
Event::TimerPtr drain_timer_;
bool is_redis_cluster_;
Common::Redis::Client::ClientFactory& client_factory_;
Common::Redis::Client::ConfigSharedPtr config_;
Stats::ScopeSharedPtr stats_scope_;
Common::Redis::RedisCommandStatsSharedPtr redis_command_stats_;
RedisClusterStats redis_cluster_stats_;
const Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager_;
};

const std::string cluster_name_;
Upstream::ClusterManager& cm_;
Common::Redis::Client::ClientFactory& client_factory_;
ThreadLocal::SlotPtr tls_;
Common::Redis::Client::ConfigImpl config_;
Common::Redis::Client::ConfigSharedPtr config_;
Api::Api& api_;
Stats::ScopePtr stats_scope_;
Stats::ScopeSharedPtr stats_scope_;
Common::Redis::RedisCommandStatsSharedPtr redis_command_stats_;
RedisClusterStats redis_cluster_stats_;
const Extensions::Common::Redis::ClusterRefreshManagerSharedPtr refresh_manager_;
Expand Down
1 change: 1 addition & 0 deletions test/extensions/clusters/redis/redis_cluster_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class RedisClusterTest : public testing::Test,

void expectRedisSessionCreated() {
resolve_timer_ = new Event::MockTimer(&dispatcher_);
EXPECT_CALL(*resolve_timer_, disableTimer());
ON_CALL(random_, random()).WillByDefault(Return(0));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ class RedisConnPoolImplTest : public testing::Test, public Common::Redis::Client
std::make_shared<NiceMock<Extensions::Common::Redis::MockClusterRefreshManager>>();
auto redis_command_stats =
Common::Redis::RedisCommandStats::createRedisCommandStats(store->symbolTable());
std::unique_ptr<InstanceImpl> conn_pool_impl = std::make_unique<InstanceImpl>(
std::shared_ptr<InstanceImpl> conn_pool_impl = std::make_shared<InstanceImpl>(
cluster_name_, cm_, *this, tls_,
Common::Redis::Client::createConnPoolSettings(20, hashtagging, true, max_unknown_conns,
read_policy_),
api_, std::move(store), redis_command_stats, cluster_refresh_manager_);
conn_pool_impl->init();
// Set the authentication password for this connection pool.
conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().auth_username_ = auth_username_;
conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().auth_password_ = auth_password_;
Expand Down Expand Up @@ -176,6 +177,11 @@ class RedisConnPoolImplTest : public testing::Test, public Common::Redis::Client
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().clients_to_drain_;
}

InstanceImpl::ThreadLocalPool& threadLocalPool() {
InstanceImpl* conn_pool_impl = dynamic_cast<InstanceImpl*>(conn_pool_.get());
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>();
}

Event::TimerPtr& drainTimer() {
InstanceImpl* conn_pool_impl = dynamic_cast<InstanceImpl*>(conn_pool_.get());
return conn_pool_impl->tls_->getTyped<InstanceImpl::ThreadLocalPool>().drain_timer_;
Expand Down Expand Up @@ -1156,6 +1162,61 @@ TEST_F(RedisConnPoolImplTest, AskRedirectionFailure) {
tls_.shutdownThread();
}

TEST_F(RedisConnPoolImplTest, MakeRequestAndRedirectFollowedByDelete) {
tls_.defer_delete = true;
std::unique_ptr<NiceMock<Stats::MockStore>> store =
std::make_unique<NiceMock<Stats::MockStore>>();
cluster_refresh_manager_ =
std::make_shared<NiceMock<Extensions::Common::Redis::MockClusterRefreshManager>>();
auto redis_command_stats =
Common::Redis::RedisCommandStats::createRedisCommandStats(store->symbolTable());
conn_pool_ = std::make_shared<InstanceImpl>(
cluster_name_, cm_, *this, tls_,
Common::Redis::Client::createConnPoolSettings(20, true, true, 100, read_policy_), api_,
std::move(store), redis_command_stats, cluster_refresh_manager_);
conn_pool_->init();

auto& local_pool = threadLocalPool();
conn_pool_.reset();

// Request
Common::Redis::Client::MockClient* client = new NiceMock<Common::Redis::Client::MockClient>();
Common::Redis::RespValueSharedPtr value = std::make_shared<Common::Redis::RespValue>();
Common::Redis::Client::MockPoolRequest active_request;
MockPoolCallbacks callbacks;
EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_))
.WillOnce(Invoke([&](Upstream::LoadBalancerContext* context) -> Upstream::HostConstSharedPtr {
EXPECT_EQ(context->computeHashKey().value(), MurmurHash::murmurHash2_64("hash_key"));
EXPECT_EQ(context->metadataMatchCriteria(), nullptr);
EXPECT_EQ(context->downstreamConnection(), nullptr);
return this->cm_.thread_local_cluster_.lb_.host_;
}));
EXPECT_CALL(*this, create_(_)).WillOnce(Return(client));
EXPECT_CALL(*cm_.thread_local_cluster_.lb_.host_, address())
.WillRepeatedly(Return(this->test_address_));
EXPECT_CALL(*client, makeRequest_(Ref(*value), _)).WillOnce(Return(&active_request));
EXPECT_NE(nullptr, local_pool.makeRequest("hash_key", value, callbacks));

// Move redirection.
Common::Redis::Client::MockPoolRequest active_request2;
Common::Redis::Client::MockClient* client2 = new NiceMock<Common::Redis::Client::MockClient>();
Upstream::HostConstSharedPtr host1;
Common::Redis::RespValuePtr moved_response{new Common::Redis::RespValue()};
moved_response->type(Common::Redis::RespType::Error);
moved_response->asString() = "MOVED 1111 10.1.2.3:4000";

EXPECT_CALL(*this, create_(_)).WillOnce(DoAll(SaveArg<0>(&host1), Return(client2)));
EXPECT_CALL(*client2, makeRequest_(Ref(*value), _)).WillOnce(Return(&active_request2));
EXPECT_TRUE(client->client_callbacks_.back()->onRedirection(std::move(moved_response),
"10.1.2.3:4000", false));
EXPECT_EQ(host1->address()->asString(), "10.1.2.3:4000");
EXPECT_CALL(callbacks, onResponse_(_));
client2->client_callbacks_.back()->onResponse(std::make_unique<Common::Redis::RespValue>());

EXPECT_CALL(*client, close());
tls_.shutdownThread();
}

} // namespace ConnPool
} // namespace RedisProxy
} // namespace NetworkFilters
Expand Down
4 changes: 3 additions & 1 deletion test/mocks/thread_local/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class MockInstance : public Instance {

~SlotImpl() override {
// Do not actually clear slot data during shutdown. This mimics the production code.
if (!parent_.shutdown_) {
// The defer_delete mimics the recycle() code with Bookkeeper.
if (!parent_.shutdown_ && !parent_.defer_delete) {
EXPECT_LT(index_, parent_.data_.size());
parent_.data_[index_].reset();
}
Expand Down Expand Up @@ -98,6 +99,7 @@ class MockInstance : public Instance {
bool defer_data{};
bool shutdown_{};
bool registered_{true};
bool defer_delete{};
};

} // namespace ThreadLocal
Expand Down