Skip to content

Commit

Permalink
tls: allow nullptr as slot value (#13883)
Browse files Browse the repository at this point in the history
Commit Message: add support for null data to tls slot typed API, using optional references. To make these less cumbersome at call-sites, add a struct OptRef wrapper for absl::optional<std::reference_wrapper> that allows directly accessing via -> syntax if the caller can guarantee the optional reference is populated.
Additional Description:
Risk Level: low
Testing: //test/...
Docs Changes: n/a
Release Notes: n/a

Signed-off-by: Joshua Marantz <[email protected]>
  • Loading branch information
jmarantz authored Nov 4, 2020
1 parent bf55b57 commit 7620e7d
Show file tree
Hide file tree
Showing 23 changed files with 229 additions and 125 deletions.
1 change: 1 addition & 0 deletions include/envoy/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ envoy_basic_cc_library(
name = "base_includes",
hdrs = [
"exception.h",
"optref.h",
"platform.h",
"pure.h",
],
Expand Down
39 changes: 39 additions & 0 deletions include/envoy/common/optref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include "absl/types/optional.h"

namespace Envoy {

// Helper class to make it easier to work with optional references, allowing:
// foo(OptRef<T> t) {
// if (t.has_value()) {
// t->method();
// }
// }
//
// Using absl::optional directly you must write optref.value().method() which is
// a bit more awkward.
template <class T> struct OptRef : public absl::optional<std::reference_wrapper<T>> {
OptRef(T& t) : absl::optional<std::reference_wrapper<T>>(t) {}
OptRef() = default;

/**
* Helper to call a method on T. The caller is responsible for ensuring
* has_value() is true.
*/
T* operator->() {
T& ref = **this;
return &ref;
}

/**
* Helper to call a const method on T. The caller is responsible for ensuring
* has_value() is true.
*/
const T* operator->() const {
const T& ref = **this;
return &ref;
}
};

} // namespace Envoy
52 changes: 33 additions & 19 deletions include/envoy/thread_local/thread_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <functional>
#include <memory>

#include "envoy/common/optref.h"
#include "envoy/common/pure.h"
#include "envoy/event/dispatcher.h"

Expand Down Expand Up @@ -93,8 +94,6 @@ class Slot {
// Callers must use the TypedSlot API, below.
virtual void runOnAllThreads(const UpdateCb& update_cb) PURE;
virtual void runOnAllThreads(const UpdateCb& update_cb, const Event::PostCb& complete_cb) PURE;
virtual void runOnAllThreads(const Event::PostCb& cb) PURE;
virtual void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) PURE;
};

using SlotPtr = std::unique_ptr<Slot>;
Expand Down Expand Up @@ -157,39 +156,54 @@ template <class T> class TypedSlot {
void set(InitializeCb cb) { slot_->set(cb); }

/**
* @return a reference to the thread local object.
* @return an optional reference to the thread local object.
*/
T& get() { return slot_->getTyped<T>(); }
const T& get() const { return slot_->getTyped<T>(); }
OptRef<T> get() { return getOpt(slot_->get()); }
const OptRef<T> get() const { return getOpt(slot_->get()); }

/**
* Helper function to call methods on T. The caller is responsible
* for ensuring that get().has_value() is true.
*
* @return a pointer to the thread local object.
*/
T* operator->() { return &get(); }
const T* operator->() const { return &get(); }
T* operator->() { return &(slot_->getTyped<T>()); }
const T* operator->() const { return &(slot_->getTyped<T>()); }

/**
* UpdateCb is passed a mutable reference to the current stored data.
* Helper function to get access to a T&. The caller is responsible for
* ensuring that get().has_value() is true.
*
* NOTE: The update callback is not supposed to capture the TypedSlot, or its owner, as the owner
* may be destructed in main thread before the update_cb gets called in a worker thread.
* @return a reference to the thread local object.
*/
using UpdateCb = std::function<void(T& obj)>;
T& operator*() { return slot_->getTyped<T>(); }
const T& operator*() const { return slot_->getTyped<T>(); }

/**
* UpdateCb is passed a mutable pointer to the current stored data. Callers
* can assume that the passed-in OptRef has a value if they have called set(),
* yielding a non-null shared_ptr, prior to runOnAllThreads().
*
* NOTE: The update callback is not supposed to capture the TypedSlot, or its
* owner, as the owner may be destructed in main thread before the update_cb
* gets called in a worker thread.
*/
using UpdateCb = std::function<void(OptRef<T> obj)>;
void runOnAllThreads(const UpdateCb& cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb)); }
void runOnAllThreads(const UpdateCb& cb, const Event::PostCb& complete_cb) {
slot_->runOnAllThreads(makeSlotUpdateCb(cb), complete_cb);
}
void runOnAllThreads(const Event::PostCb& cb) { slot_->runOnAllThreads(cb); }
void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) {
slot_->runOnAllThreads(cb, complete_cb);
}

private:
static OptRef<T> getOpt(ThreadLocalObjectSharedPtr obj) {
if (obj) {
return OptRef<T>(obj->asType<T>());
}
return OptRef<T>();
}

Slot::UpdateCb makeSlotUpdateCb(UpdateCb cb) {
return [cb](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
cb(obj->asType<T>());
return obj;
};
return [cb](ThreadLocalObjectSharedPtr obj) { cb(getOpt(obj)); };
}

const SlotPtr slot_;
Expand Down
4 changes: 2 additions & 2 deletions source/common/config/config_provider_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ ConfigSubscriptionCommonBase::~ConfigSubscriptionCommonBase() {
}

void ConfigSubscriptionCommonBase::applyConfigUpdate(const ConfigUpdateCb& update_fn) {
tls_.runOnAllThreads([update_fn](ThreadLocalConfig& thread_local_config) {
thread_local_config.config_ = update_fn(thread_local_config.config_);
tls_.runOnAllThreads([update_fn](OptRef<ThreadLocalConfig> thread_local_config) {
thread_local_config->config_ = update_fn(thread_local_config->config_);
});
}

Expand Down
6 changes: 3 additions & 3 deletions source/common/filter/http/filter_config_discovery_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ DynamicFilterConfigProviderImpl::~DynamicFilterConfigProviderImpl() {
const std::string& DynamicFilterConfigProviderImpl::name() { return subscription_->name(); }

absl::optional<Envoy::Http::FilterFactoryCb> DynamicFilterConfigProviderImpl::config() {
return tls_.get().config_;
return tls_->config_;
}

void DynamicFilterConfigProviderImpl::validateConfig(
Expand All @@ -53,8 +53,8 @@ void DynamicFilterConfigProviderImpl::onConfigUpdate(Envoy::Http::FilterFactoryC
const std::string&,
Config::ConfigAppliedCb cb) {
tls_.runOnAllThreads(
[config, cb](ThreadLocalConfig& tls) {
tls.config_ = config;
[config, cb](OptRef<ThreadLocalConfig> tls) {
tls->config_ = config;
if (cb) {
cb();
}
Expand Down
2 changes: 1 addition & 1 deletion source/common/router/rds_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { return tls_-
void RdsRouteConfigProviderImpl::onConfigUpdate() {
ConfigConstSharedPtr new_config(new ConfigImpl(config_update_info_->routeConfiguration(),
factory_context_, validator_, false));
tls_.runOnAllThreads([new_config](ThreadLocalConfig& tls) { tls.config_ = new_config; });
tls_.runOnAllThreads([new_config](OptRef<ThreadLocalConfig> tls) { tls->config_ = new_config; });

const auto aliases = config_update_info_->resourceIdsInLastVhdsUpdate();
// Regular (non-VHDS) RDS updates don't populate aliases fields in resources.
Expand Down
18 changes: 9 additions & 9 deletions source/common/stats/thread_local_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) {
ASSERT(!merge_in_progress_);
merge_in_progress_ = true;
tls_cache_->runOnAllThreads(
[](TlsCache& tls_cache) {
for (const auto& id_hist : tls_cache.tls_histogram_cache_) {
[](OptRef<TlsCache> tls_cache) {
for (const auto& id_hist : tls_cache->tls_histogram_cache_) {
const TlsHistogramSharedPtr& tls_hist = id_hist.second;
tls_hist->beginMerge();
}
Expand Down Expand Up @@ -303,7 +303,7 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id,
if (!shutting_down_) {
// Perform a cache flush on all threads.
tls_cache_->runOnAllThreads(
[scope_id](TlsCache& tls_cache) { tls_cache.eraseScope(scope_id); },
[scope_id](OptRef<TlsCache> tls_cache) { tls_cache->eraseScope(scope_id); },
[central_cache]() { /* Holds onto central_cache until all tls caches are clear */ });
}
}
Expand All @@ -320,7 +320,7 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) {
// contains a patch that will implement batching together to clear multiple
// histograms.
tls_cache_->runOnAllThreads(
[histogram_id](TlsCache& tls_cache) { tls_cache.eraseHistogram(histogram_id); });
[histogram_id](OptRef<TlsCache> tls_cache) { tls_cache->eraseHistogram(histogram_id); });
}
}

Expand Down Expand Up @@ -489,7 +489,7 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counterFromStatNameWithTags(
StatRefMap<Counter>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
TlsCacheEntry& entry = parent_.tlsCache().insertScope(this->scope_id_);
tls_cache = &entry.counters_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -541,7 +541,7 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gaugeFromStatNameWithTags(
StatRefMap<Gauge>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
TlsCacheEntry& entry = parent_.tlsCache().scope_cache_[this->scope_id_];
tls_cache = &entry.gauges_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -579,7 +579,7 @@ Histogram& ThreadLocalStoreImpl::ScopeImpl::histogramFromStatNameWithTags(
StatNameHashMap<ParentHistogramSharedPtr>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
TlsCacheEntry& entry = parent_.tlsCache().scope_cache_[this->scope_id_];
tls_cache = &entry.parent_histograms_;
auto iter = tls_cache->find(final_stat_name);
if (iter != tls_cache->end()) {
Expand Down Expand Up @@ -657,7 +657,7 @@ TextReadout& ThreadLocalStoreImpl::ScopeImpl::textReadoutFromStatNameWithTags(
StatRefMap<TextReadout>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
TlsCacheEntry& entry = parent_.tlsCache().insertScope(this->scope_id_);
tls_cache = &entry.text_readouts_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -703,7 +703,7 @@ Histogram& ThreadLocalStoreImpl::tlsHistogram(ParentHistogramImpl& parent, uint6

TlsHistogramSharedPtr* tls_histogram = nullptr;
if (!shutting_down_ && tls_cache_) {
tls_histogram = &tls_cache_->get().tls_histogram_cache_[id];
tls_histogram = &(tlsCache().tls_histogram_cache_[id]);
if (*tls_histogram != nullptr) {
return **tls_histogram;
}
Expand Down
1 change: 1 addition & 0 deletions source/common/stats/thread_local_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ class ThreadLocalStoreImpl : Logger::Loggable<Logger::Id::stats>, public StoreRo
void removeRejectedStats(StatMapClass& map, StatListClass& list);
bool checkAndRememberRejection(StatName name, StatNameStorageSet& central_rejected_stats,
StatNameHashSet* tls_rejected_stats);
TlsCache& tlsCache() { return **tls_cache_; }

Allocator& alloc_;
Event::Dispatcher* main_thread_dispatcher_{};
Expand Down
9 changes: 0 additions & 9 deletions source/common/thread_local/thread_local_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
parent_.runOnAllThreads(dataCallback(cb));
}

void InstanceImpl::SlotImpl::runOnAllThreads(const Event::PostCb& cb,
const Event::PostCb& complete_cb) {
parent_.runOnAllThreads(wrapCallback(cb), complete_cb);
}

void InstanceImpl::SlotImpl::runOnAllThreads(const Event::PostCb& cb) {
parent_.runOnAllThreads(wrapCallback(cb));
}

void InstanceImpl::SlotImpl::set(InitializeCb cb) {
ASSERT(std::this_thread::get_id() == parent_.main_thread_id_);
ASSERT(!parent_.shutdown_);
Expand Down
2 changes: 0 additions & 2 deletions source/common/thread_local/thread_local_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, pub
ThreadLocalObjectSharedPtr get() override;
void runOnAllThreads(const UpdateCb& cb) override;
void runOnAllThreads(const UpdateCb& cb, const Event::PostCb& complete_cb) override;
void runOnAllThreads(const Event::PostCb& cb) override;
void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) override;
bool currentThreadRegistered() override;
void set(InitializeCb cb) override;

Expand Down
Loading

0 comments on commit 7620e7d

Please sign in to comment.