diff --git a/include/envoy/common/BUILD b/include/envoy/common/BUILD index de79fdf0e689..d120b7c0cc54 100644 --- a/include/envoy/common/BUILD +++ b/include/envoy/common/BUILD @@ -13,6 +13,7 @@ envoy_basic_cc_library( name = "base_includes", hdrs = [ "exception.h", + "optref.h", "platform.h", "pure.h", ], diff --git a/include/envoy/common/optref.h b/include/envoy/common/optref.h new file mode 100644 index 000000000000..cf51cdaa52ea --- /dev/null +++ b/include/envoy/common/optref.h @@ -0,0 +1,39 @@ +#pragma once + +#include "absl/types/optional.h" + +namespace Envoy { + +// Helper class to make it easier to work with optional references, allowing: +// foo(OptRef t) { +// if (t.has_value()) { +// t->method(); +// } +// } +// +// Using absl::optional directly you must write optref.value().method() which is +// a bit more awkward. +template struct OptRef : public absl::optional> { + OptRef(T& t) : absl::optional>(t) {} + OptRef() = default; + + /** + * Helper to call a method on T. The caller is responsible for ensuring + * has_value() is true. + */ + T* operator->() { + T& ref = **this; + return &ref; + } + + /** + * Helper to call a const method on T. The caller is responsible for ensuring + * has_value() is true. + */ + const T* operator->() const { + const T& ref = **this; + return &ref; + } +}; + +} // namespace Envoy diff --git a/include/envoy/thread_local/thread_local.h b/include/envoy/thread_local/thread_local.h index fcf7e54e2091..4993b118d1eb 100644 --- a/include/envoy/thread_local/thread_local.h +++ b/include/envoy/thread_local/thread_local.h @@ -4,6 +4,7 @@ #include #include +#include "envoy/common/optref.h" #include "envoy/common/pure.h" #include "envoy/event/dispatcher.h" @@ -93,8 +94,6 @@ class Slot { // Callers must use the TypedSlot API, below. virtual void runOnAllThreads(const UpdateCb& update_cb) PURE; virtual void runOnAllThreads(const UpdateCb& update_cb, const Event::PostCb& complete_cb) PURE; - virtual void runOnAllThreads(const Event::PostCb& cb) PURE; - virtual void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) PURE; }; using SlotPtr = std::unique_ptr; @@ -157,39 +156,54 @@ template class TypedSlot { void set(InitializeCb cb) { slot_->set(cb); } /** - * @return a reference to the thread local object. + * @return an optional reference to the thread local object. */ - T& get() { return slot_->getTyped(); } - const T& get() const { return slot_->getTyped(); } + OptRef get() { return getOpt(slot_->get()); } + const OptRef get() const { return getOpt(slot_->get()); } /** + * Helper function to call methods on T. The caller is responsible + * for ensuring that get().has_value() is true. + * * @return a pointer to the thread local object. */ - T* operator->() { return &get(); } - const T* operator->() const { return &get(); } + T* operator->() { return &(slot_->getTyped()); } + const T* operator->() const { return &(slot_->getTyped()); } /** - * UpdateCb is passed a mutable reference to the current stored data. + * Helper function to get access to a T&. The caller is responsible for + * ensuring that get().has_value() is true. * - * NOTE: The update callback is not supposed to capture the TypedSlot, or its owner, as the owner - * may be destructed in main thread before the update_cb gets called in a worker thread. + * @return a reference to the thread local object. */ - using UpdateCb = std::function; + T& operator*() { return slot_->getTyped(); } + const T& operator*() const { return slot_->getTyped(); } + + /** + * UpdateCb is passed a mutable pointer to the current stored data. Callers + * can assume that the passed-in OptRef has a value if they have called set(), + * yielding a non-null shared_ptr, prior to runOnAllThreads(). + * + * NOTE: The update callback is not supposed to capture the TypedSlot, or its + * owner, as the owner may be destructed in main thread before the update_cb + * gets called in a worker thread. + */ + using UpdateCb = std::function obj)>; void runOnAllThreads(const UpdateCb& cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb)); } void runOnAllThreads(const UpdateCb& cb, const Event::PostCb& complete_cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb), complete_cb); } - void runOnAllThreads(const Event::PostCb& cb) { slot_->runOnAllThreads(cb); } - void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) { - slot_->runOnAllThreads(cb, complete_cb); - } private: + static OptRef getOpt(ThreadLocalObjectSharedPtr obj) { + if (obj) { + return OptRef(obj->asType()); + } + return OptRef(); + } + Slot::UpdateCb makeSlotUpdateCb(UpdateCb cb) { - return [cb](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr { - cb(obj->asType()); - return obj; - }; + return [cb](ThreadLocalObjectSharedPtr obj) { cb(getOpt(obj)); }; } const SlotPtr slot_; diff --git a/source/common/config/config_provider_impl.cc b/source/common/config/config_provider_impl.cc index b0098883d368..78eddb0ffe10 100644 --- a/source/common/config/config_provider_impl.cc +++ b/source/common/config/config_provider_impl.cc @@ -25,8 +25,8 @@ ConfigSubscriptionCommonBase::~ConfigSubscriptionCommonBase() { } void ConfigSubscriptionCommonBase::applyConfigUpdate(const ConfigUpdateCb& update_fn) { - tls_.runOnAllThreads([update_fn](ThreadLocalConfig& thread_local_config) { - thread_local_config.config_ = update_fn(thread_local_config.config_); + tls_.runOnAllThreads([update_fn](OptRef thread_local_config) { + thread_local_config->config_ = update_fn(thread_local_config->config_); }); } diff --git a/source/common/filter/http/filter_config_discovery_impl.cc b/source/common/filter/http/filter_config_discovery_impl.cc index d6c7e1b6bb62..aef0519b4a88 100644 --- a/source/common/filter/http/filter_config_discovery_impl.cc +++ b/source/common/filter/http/filter_config_discovery_impl.cc @@ -37,7 +37,7 @@ DynamicFilterConfigProviderImpl::~DynamicFilterConfigProviderImpl() { const std::string& DynamicFilterConfigProviderImpl::name() { return subscription_->name(); } absl::optional DynamicFilterConfigProviderImpl::config() { - return tls_.get().config_; + return tls_->config_; } void DynamicFilterConfigProviderImpl::validateConfig( @@ -53,8 +53,8 @@ void DynamicFilterConfigProviderImpl::onConfigUpdate(Envoy::Http::FilterFactoryC const std::string&, Config::ConfigAppliedCb cb) { tls_.runOnAllThreads( - [config, cb](ThreadLocalConfig& tls) { - tls.config_ = config; + [config, cb](OptRef tls) { + tls->config_ = config; if (cb) { cb(); } diff --git a/source/common/router/rds_impl.cc b/source/common/router/rds_impl.cc index 62844e63888a..3886e68e01f6 100644 --- a/source/common/router/rds_impl.cc +++ b/source/common/router/rds_impl.cc @@ -255,7 +255,7 @@ Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { return tls_- void RdsRouteConfigProviderImpl::onConfigUpdate() { ConfigConstSharedPtr new_config(new ConfigImpl(config_update_info_->routeConfiguration(), factory_context_, validator_, false)); - tls_.runOnAllThreads([new_config](ThreadLocalConfig& tls) { tls.config_ = new_config; }); + tls_.runOnAllThreads([new_config](OptRef tls) { tls->config_ = new_config; }); const auto aliases = config_update_info_->resourceIdsInLastVhdsUpdate(); // Regular (non-VHDS) RDS updates don't populate aliases fields in resources. diff --git a/source/common/stats/thread_local_store.cc b/source/common/stats/thread_local_store.cc index b0704ff97c19..11b977440837 100644 --- a/source/common/stats/thread_local_store.cc +++ b/source/common/stats/thread_local_store.cc @@ -205,8 +205,8 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) { ASSERT(!merge_in_progress_); merge_in_progress_ = true; tls_cache_->runOnAllThreads( - [](TlsCache& tls_cache) { - for (const auto& id_hist : tls_cache.tls_histogram_cache_) { + [](OptRef tls_cache) { + for (const auto& id_hist : tls_cache->tls_histogram_cache_) { const TlsHistogramSharedPtr& tls_hist = id_hist.second; tls_hist->beginMerge(); } @@ -303,7 +303,7 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id, if (!shutting_down_) { // Perform a cache flush on all threads. tls_cache_->runOnAllThreads( - [scope_id](TlsCache& tls_cache) { tls_cache.eraseScope(scope_id); }, + [scope_id](OptRef tls_cache) { tls_cache->eraseScope(scope_id); }, [central_cache]() { /* Holds onto central_cache until all tls caches are clear */ }); } } @@ -320,7 +320,7 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) { // contains a patch that will implement batching together to clear multiple // histograms. tls_cache_->runOnAllThreads( - [histogram_id](TlsCache& tls_cache) { tls_cache.eraseHistogram(histogram_id); }); + [histogram_id](OptRef tls_cache) { tls_cache->eraseHistogram(histogram_id); }); } } @@ -489,7 +489,7 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counterFromStatNameWithTags( StatRefMap* tls_cache = nullptr; StatNameHashSet* tls_rejected_stats = nullptr; if (!parent_.shutting_down_ && parent_.tls_cache_) { - TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_); + TlsCacheEntry& entry = parent_.tlsCache().insertScope(this->scope_id_); tls_cache = &entry.counters_; tls_rejected_stats = &entry.rejected_stats_; } @@ -541,7 +541,7 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gaugeFromStatNameWithTags( StatRefMap* tls_cache = nullptr; StatNameHashSet* tls_rejected_stats = nullptr; if (!parent_.shutting_down_ && parent_.tls_cache_) { - TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_]; + TlsCacheEntry& entry = parent_.tlsCache().scope_cache_[this->scope_id_]; tls_cache = &entry.gauges_; tls_rejected_stats = &entry.rejected_stats_; } @@ -579,7 +579,7 @@ Histogram& ThreadLocalStoreImpl::ScopeImpl::histogramFromStatNameWithTags( StatNameHashMap* tls_cache = nullptr; StatNameHashSet* tls_rejected_stats = nullptr; if (!parent_.shutting_down_ && parent_.tls_cache_) { - TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_]; + TlsCacheEntry& entry = parent_.tlsCache().scope_cache_[this->scope_id_]; tls_cache = &entry.parent_histograms_; auto iter = tls_cache->find(final_stat_name); if (iter != tls_cache->end()) { @@ -657,7 +657,7 @@ TextReadout& ThreadLocalStoreImpl::ScopeImpl::textReadoutFromStatNameWithTags( StatRefMap* tls_cache = nullptr; StatNameHashSet* tls_rejected_stats = nullptr; if (!parent_.shutting_down_ && parent_.tls_cache_) { - TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_); + TlsCacheEntry& entry = parent_.tlsCache().insertScope(this->scope_id_); tls_cache = &entry.text_readouts_; tls_rejected_stats = &entry.rejected_stats_; } @@ -703,7 +703,7 @@ Histogram& ThreadLocalStoreImpl::tlsHistogram(ParentHistogramImpl& parent, uint6 TlsHistogramSharedPtr* tls_histogram = nullptr; if (!shutting_down_ && tls_cache_) { - tls_histogram = &tls_cache_->get().tls_histogram_cache_[id]; + tls_histogram = &(tlsCache().tls_histogram_cache_[id]); if (*tls_histogram != nullptr) { return **tls_histogram; } diff --git a/source/common/stats/thread_local_store.h b/source/common/stats/thread_local_store.h index 410254afeb33..22707accacbd 100644 --- a/source/common/stats/thread_local_store.h +++ b/source/common/stats/thread_local_store.h @@ -477,6 +477,7 @@ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRo void removeRejectedStats(StatMapClass& map, StatListClass& list); bool checkAndRememberRejection(StatName name, StatNameStorageSet& central_rejected_stats, StatNameHashSet* tls_rejected_stats); + TlsCache& tlsCache() { return **tls_cache_; } Allocator& alloc_; Event::Dispatcher* main_thread_dispatcher_{}; diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc index 6c450449a942..0815236a3195 100644 --- a/source/common/thread_local/thread_local_impl.cc +++ b/source/common/thread_local/thread_local_impl.cc @@ -90,15 +90,6 @@ void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) { parent_.runOnAllThreads(dataCallback(cb)); } -void InstanceImpl::SlotImpl::runOnAllThreads(const Event::PostCb& cb, - const Event::PostCb& complete_cb) { - parent_.runOnAllThreads(wrapCallback(cb), complete_cb); -} - -void InstanceImpl::SlotImpl::runOnAllThreads(const Event::PostCb& cb) { - parent_.runOnAllThreads(wrapCallback(cb)); -} - void InstanceImpl::SlotImpl::set(InitializeCb cb) { ASSERT(std::this_thread::get_id() == parent_.main_thread_id_); ASSERT(!parent_.shutdown_); diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h index 4f5bb1b88125..7abed0499166 100644 --- a/source/common/thread_local/thread_local_impl.h +++ b/source/common/thread_local/thread_local_impl.h @@ -45,8 +45,6 @@ class InstanceImpl : Logger::Loggable, public NonCopyable, pub ThreadLocalObjectSharedPtr get() override; void runOnAllThreads(const UpdateCb& cb) override; void runOnAllThreads(const UpdateCb& cb, const Event::PostCb& complete_cb) override; - void runOnAllThreads(const Event::PostCb& cb) override; - void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& complete_cb) override; bool currentThreadRegistered() override; void set(InitializeCb cb) override; diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index ce6ecf509865..08ca0567025d 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -647,17 +647,17 @@ void ClusterManagerImpl::clusterWarmingToActive(const std::string& cluster_name) void ClusterManagerImpl::createOrUpdateThreadLocalCluster(ClusterData& cluster) { tls_.runOnAllThreads([new_cluster = cluster.cluster_->info(), thread_aware_lb_factory = cluster.loadBalancerFactory()]( - ThreadLocalClusterManagerImpl& cluster_manager) { - if (cluster_manager.thread_local_clusters_.count(new_cluster->name()) > 0) { + OptRef cluster_manager) { + if (cluster_manager->thread_local_clusters_.count(new_cluster->name()) > 0) { ENVOY_LOG(debug, "updating TLS cluster {}", new_cluster->name()); } else { ENVOY_LOG(debug, "adding TLS cluster {}", new_cluster->name()); } auto thread_local_cluster = new ThreadLocalClusterManagerImpl::ClusterEntry( - cluster_manager, new_cluster, thread_aware_lb_factory); - cluster_manager.thread_local_clusters_[new_cluster->name()].reset(thread_local_cluster); - for (auto& cb : cluster_manager.update_callbacks_) { + *cluster_manager, new_cluster, thread_aware_lb_factory); + cluster_manager->thread_local_clusters_[new_cluster->name()].reset(thread_local_cluster); + for (auto& cb : cluster_manager->update_callbacks_) { cb->onClusterAddOrUpdate(*thread_local_cluster); } }); @@ -673,13 +673,13 @@ bool ClusterManagerImpl::removeCluster(const std::string& cluster_name) { active_clusters_.erase(existing_active_cluster); ENVOY_LOG(info, "removing cluster {}", cluster_name); - tls_.runOnAllThreads([cluster_name](ThreadLocalClusterManagerImpl& cluster_manager) { - ASSERT(cluster_manager.thread_local_clusters_.count(cluster_name) == 1); + tls_.runOnAllThreads([cluster_name](OptRef cluster_manager) { + ASSERT(cluster_manager->thread_local_clusters_.count(cluster_name) == 1); ENVOY_LOG(debug, "removing TLS cluster {}", cluster_name); - for (auto& cb : cluster_manager.update_callbacks_) { + for (auto& cb : cluster_manager->update_callbacks_) { cb->onClusterRemoval(cluster_name); } - cluster_manager.thread_local_clusters_.erase(cluster_name); + cluster_manager->thread_local_clusters_.erase(cluster_name); }); } @@ -807,7 +807,7 @@ void ClusterManagerImpl::updateClusterCounts() { } ThreadLocalCluster* ClusterManagerImpl::get(absl::string_view cluster) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; auto entry = cluster_manager.thread_local_clusters_.find(cluster); if (entry != cluster_manager.thread_local_clusters_.end()) { @@ -846,7 +846,7 @@ Http::ConnectionPool::Instance* ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster, ResourcePriority priority, absl::optional protocol, LoadBalancerContext* context) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; auto entry = cluster_manager.thread_local_clusters_.find(cluster); if (entry == cluster_manager.thread_local_clusters_.end()) { @@ -872,7 +872,7 @@ ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster, ResourceP Tcp::ConnectionPool::Instance* ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePriority priority, LoadBalancerContext* context) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; auto entry = cluster_manager.thread_local_clusters_.find(cluster); if (entry == cluster_manager.thread_local_clusters_.end()) { @@ -898,8 +898,8 @@ ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePr void ClusterManagerImpl::postThreadLocalDrainConnections(const Cluster& cluster, const HostVector& hosts_removed) { tls_.runOnAllThreads([name = cluster.info()->name(), - hosts_removed](ThreadLocalClusterManagerImpl& cluster_manager) { - cluster_manager.removeHosts(name, hosts_removed); + hosts_removed](OptRef cluster_manager) { + cluster_manager->removeHosts(name, hosts_removed); }); } @@ -912,21 +912,21 @@ void ClusterManagerImpl::postThreadLocalClusterUpdate(const Cluster& cluster, ui update_params = HostSetImpl::updateHostsParams(*host_set), locality_weights = host_set->localityWeights(), hosts_added, hosts_removed, overprovisioning_factor = host_set->overprovisioningFactor()]( - ThreadLocalClusterManagerImpl& cluster_manager) { - cluster_manager.updateClusterMembership(name, priority, update_params, locality_weights, - hosts_added, hosts_removed, overprovisioning_factor); + OptRef cluster_manager) { + cluster_manager->updateClusterMembership(name, priority, update_params, locality_weights, + hosts_added, hosts_removed, overprovisioning_factor); }); } void ClusterManagerImpl::postThreadLocalHealthFailure(const HostSharedPtr& host) { - tls_.runOnAllThreads([host](ThreadLocalClusterManagerImpl& cluster_manager) { - cluster_manager.onHostHealthFailure(host); + tls_.runOnAllThreads([host](OptRef cluster_manager) { + cluster_manager->onHostHealthFailure(host); }); } Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::string& cluster, LoadBalancerContext* context) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; auto entry = cluster_manager.thread_local_clusters_.find(cluster); if (entry == cluster_manager.thread_local_clusters_.end()) { @@ -954,7 +954,7 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri } Http::AsyncClient& ClusterManagerImpl::httpAsyncClientForCluster(const std::string& cluster) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; auto entry = cluster_manager.thread_local_clusters_.find(cluster); if (entry != cluster_manager.thread_local_clusters_.end()) { return entry->second->http_async_client_; @@ -965,7 +965,7 @@ Http::AsyncClient& ClusterManagerImpl::httpAsyncClientForCluster(const std::stri ClusterUpdateCallbacksHandlePtr ClusterManagerImpl::addThreadLocalClusterUpdateCallbacks(ClusterUpdateCallbacks& cb) { - ThreadLocalClusterManagerImpl& cluster_manager = tls_.get(); + ThreadLocalClusterManagerImpl& cluster_manager = *tls_; return std::make_unique(cb, cluster_manager.update_callbacks_); } diff --git a/source/extensions/clusters/aggregate/cluster.cc b/source/extensions/clusters/aggregate/cluster.cc index 52d99b036f9a..c630a580a49a 100644 --- a/source/extensions/clusters/aggregate/cluster.cc +++ b/source/extensions/clusters/aggregate/cluster.cc @@ -91,7 +91,8 @@ void Cluster::startPreInit() { void Cluster::refresh(const std::function& skip_predicate) { // Post the priority set to worker threads. // TODO(mattklein123): Remove "this" capture. - tls_.runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()]() { + tls_.runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()]( + OptRef) { PriorityContextPtr priority_context = linearizePrioritySet(skip_predicate); Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(cluster_name); ASSERT(cluster != nullptr); diff --git a/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc b/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc index 5d5e39e4b68b..ac09ab220f72 100644 --- a/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc +++ b/source/extensions/common/dynamic_forward_proxy/dns_cache_impl.cc @@ -55,7 +55,7 @@ DnsCacheImpl::LoadDnsCacheEntryResult DnsCacheImpl::loadDnsCacheEntry(absl::string_view host, uint16_t default_port, LoadDnsCacheEntryCallbacks& callbacks) { ENVOY_LOG(debug, "thread local lookup for host '{}'", host); - ThreadLocalHostInfo& tls_host_info = tls_slot_.get(); + ThreadLocalHostInfo& tls_host_info = *tls_slot_; auto tls_host = tls_host_info.host_map_->find(host); if (tls_host != tls_host_info.host_map_->end()) { ENVOY_LOG(debug, "thread local hit for host '{}'", host); @@ -275,8 +275,8 @@ void DnsCacheImpl::updateTlsHostsMap() { } } - tls_slot_.runOnAllThreads([new_host_map](ThreadLocalHostInfo& local_host_info) { - local_host_info.updateHostMap(new_host_map); + tls_slot_.runOnAllThreads([new_host_map](OptRef local_host_info) { + local_host_info->updateHostMap(new_host_map); }); } diff --git a/source/extensions/common/tap/extension_config_base.cc b/source/extensions/common/tap/extension_config_base.cc index d5538bff21fb..04a6fb73d20b 100644 --- a/source/extensions/common/tap/extension_config_base.cc +++ b/source/extensions/common/tap/extension_config_base.cc @@ -58,15 +58,16 @@ const absl::string_view ExtensionConfigBase::adminId() { void ExtensionConfigBase::clearTapConfig() { tls_slot_.runOnAllThreads( - [](TlsFilterConfig& tls_filter_config) { tls_filter_config.config_ = nullptr; }); + [](OptRef tls_filter_config) { tls_filter_config->config_ = nullptr; }); } void ExtensionConfigBase::installNewTap(const envoy::config::tap::v3::TapConfig& proto_config, Sink* admin_streamer) { TapConfigSharedPtr new_config = config_factory_->createConfigFromProto(proto_config, admin_streamer); - tls_slot_.runOnAllThreads( - [new_config](TlsFilterConfig& tls_filter_config) { tls_filter_config.config_ = new_config; }); + tls_slot_.runOnAllThreads([new_config](OptRef tls_filter_config) { + tls_filter_config->config_ = new_config; + }); } void ExtensionConfigBase::newTapConfig(const envoy::config::tap::v3::TapConfig& proto_config, diff --git a/source/extensions/filters/common/lua/lua.cc b/source/extensions/filters/common/lua/lua.cc index e1fea1f96d38..f23f968c9e7f 100644 --- a/source/extensions/filters/common/lua/lua.cc +++ b/source/extensions/filters/common/lua/lua.cc @@ -65,20 +65,20 @@ ThreadLocalState::ThreadLocalState(const std::string& code, ThreadLocal::SlotAll } int ThreadLocalState::getGlobalRef(uint64_t slot) { - LuaThreadLocal& tls = tls_slot_->get(); + LuaThreadLocal& tls = **tls_slot_; ASSERT(tls.global_slots_.size() > slot); return tls.global_slots_[slot]; } uint64_t ThreadLocalState::registerGlobal(const std::string& global) { - tls_slot_->runOnAllThreads([global](LuaThreadLocal& tls) { - lua_getglobal(tls.state_.get(), global.c_str()); - if (lua_isfunction(tls.state_.get(), -1)) { - tls.global_slots_.push_back(luaL_ref(tls.state_.get(), LUA_REGISTRYINDEX)); + tls_slot_->runOnAllThreads([global](OptRef tls) { + lua_getglobal(tls->state_.get(), global.c_str()); + if (lua_isfunction(tls->state_.get(), -1)) { + tls->global_slots_.push_back(luaL_ref(tls->state_.get(), LUA_REGISTRYINDEX)); } else { ENVOY_LOG(debug, "definition for '{}' not found in script", global); - lua_pop(tls.state_.get(), 1); - tls.global_slots_.push_back(LUA_REFNIL); + lua_pop(tls->state_.get(), 1); + tls->global_slots_.push_back(LUA_REFNIL); } }); @@ -86,7 +86,7 @@ uint64_t ThreadLocalState::registerGlobal(const std::string& global) { } CoroutinePtr ThreadLocalState::createCoroutine() { - lua_State* state = tls_slot_->get().state_.get(); + lua_State* state = tlsState().get(); return std::make_unique(std::make_pair(lua_newthread(state), state)); } diff --git a/source/extensions/filters/common/lua/lua.h b/source/extensions/filters/common/lua/lua.h index ac84ac14af7e..6112df91b0de 100644 --- a/source/extensions/filters/common/lua/lua.h +++ b/source/extensions/filters/common/lua/lua.h @@ -386,22 +386,23 @@ class ThreadLocalState : Logger::Loggable { * all threaded workers. */ template void registerType() { - tls_slot_->runOnAllThreads([](LuaThreadLocal& tls) { T::registerType(tls.state_.get()); }); + tls_slot_->runOnAllThreads( + [](OptRef tls) { T::registerType(tls->state_.get()); }); } /** * Return the number of bytes used by the runtime. */ uint64_t runtimeBytesUsed() { - uint64_t bytes_used = lua_gc(tls_slot_->get().state_.get(), LUA_GCCOUNT, 0) * 1024; - bytes_used += lua_gc(tls_slot_->get().state_.get(), LUA_GCCOUNTB, 0); + uint64_t bytes_used = lua_gc(tlsState().get(), LUA_GCCOUNT, 0) * 1024; + bytes_used += lua_gc(tlsState().get(), LUA_GCCOUNTB, 0); return bytes_used; } /** * Force a full runtime GC. */ - void runtimeGC() { lua_gc(tls_slot_->get().state_.get(), LUA_GCCOLLECT, 0); } + void runtimeGC() { lua_gc(tlsState().get(), LUA_GCCOLLECT, 0); } private: struct LuaThreadLocal : public ThreadLocal::ThreadLocalObject { @@ -411,6 +412,8 @@ class ThreadLocalState : Logger::Loggable { std::vector global_slots_; }; + CSmartPtr& tlsState() { return (*tls_slot_)->state_; } + ThreadLocal::TypedSlotPtr tls_slot_; uint64_t current_global_slot_{}; }; diff --git a/source/extensions/filters/http/admission_control/admission_control.h b/source/extensions/filters/http/admission_control/admission_control.h index 79a903d361d1..7b4e83de80c7 100644 --- a/source/extensions/filters/http/admission_control/admission_control.h +++ b/source/extensions/filters/http/admission_control/admission_control.h @@ -59,7 +59,7 @@ class AdmissionControlFilterConfig { std::shared_ptr response_evaluator); virtual ~AdmissionControlFilterConfig() = default; - virtual ThreadLocalController& getController() const { return tls_->get(); } + virtual ThreadLocalController& getController() const { return **tls_; } Random::RandomGenerator& random() const { return random_; } bool filterEnabled() const { return admission_control_feature_.enabled(); } diff --git a/source/server/overload_manager_impl.cc b/source/server/overload_manager_impl.cc index c128caf41d12..db5ced97e41d 100644 --- a/source/server/overload_manager_impl.cc +++ b/source/server/overload_manager_impl.cc @@ -384,7 +384,7 @@ bool OverloadManagerImpl::registerForAction(const std::string& action, return true; } -ThreadLocalOverloadState& OverloadManagerImpl::getThreadLocalOverloadState() { return tls_.get(); } +ThreadLocalOverloadState& OverloadManagerImpl::getThreadLocalOverloadState() { return *tls_; } Event::ScaledRangeTimerManagerPtr OverloadManagerImpl::createScaledRangeTimerManager(Event::Dispatcher& dispatcher) const { @@ -442,9 +442,9 @@ void OverloadManagerImpl::flushResourceUpdates() { std::swap(*shared_updates, state_updates_to_flush_); tls_.runOnAllThreads( - [updates = std::move(shared_updates)](ThreadLocalOverloadStateImpl& overload_state) { + [updates = std::move(shared_updates)](OptRef overload_state) { for (const auto& [action, state] : *updates) { - overload_state.setState(action, state); + overload_state->setState(action, state); } }); } diff --git a/test/common/common/BUILD b/test/common/common/BUILD index 20392d66be19..8f9ec5324dc8 100644 --- a/test/common/common/BUILD +++ b/test/common/common/BUILD @@ -187,6 +187,11 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "optref_test", + srcs = ["optref_test.cc"], +) + envoy_cc_test( name = "random_generator_test", srcs = ["random_generator_test.cc"], diff --git a/test/common/common/optref_test.cc b/test/common/common/optref_test.cc new file mode 100644 index 000000000000..343d4506bba0 --- /dev/null +++ b/test/common/common/optref_test.cc @@ -0,0 +1,37 @@ +#include + +#include "envoy/common/optref.h" + +#include "gtest/gtest.h" + +namespace Envoy { + +// Helper function for returning the string reference from an OptRef. Calling +// value() inline at the EXPECT_EQ callsites does not compile due to template +// specialization ambiguities, that this wrapper resolves. +static std::string& strref(const OptRef optref) { return optref.value(); } + +TEST(OptRefTest, Empty) { + OptRef optref; + EXPECT_FALSE(optref.has_value()); +} + +TEST(OptRefTest, NonConst) { + std::string str("Hello"); + OptRef optref(str); + EXPECT_TRUE(optref.has_value()); + EXPECT_EQ("Hello", strref(optref)); + EXPECT_EQ(5, optref->size()); + optref->append(", World!"); + EXPECT_EQ("Hello, World!", strref(optref)); +} + +TEST(OptRefTest, Const) { + std::string str("Hello"); + const OptRef optref(str); + EXPECT_TRUE(optref.has_value()); + EXPECT_EQ("Hello", strref(optref)); + EXPECT_EQ(5, optref->size()); +} + +} // namespace Envoy diff --git a/test/common/stats/thread_local_store_test.cc b/test/common/stats/thread_local_store_test.cc index 9e97d323d43d..71bf69be9df7 100644 --- a/test/common/stats/thread_local_store_test.cc +++ b/test/common/stats/thread_local_store_test.cc @@ -55,8 +55,8 @@ class ThreadLocalStoreTestingPeer { const std::function& num_tls_hist_cb) { auto num_tls_histograms = std::make_shared>(0); thread_local_store_impl.tls_cache_->runOnAllThreads( - [num_tls_histograms](ThreadLocalStoreImpl::TlsCache& tls_cache) { - *num_tls_histograms += tls_cache.tls_histogram_cache_.size(); + [num_tls_histograms](OptRef tls_cache) { + *num_tls_histograms += tls_cache->tls_histogram_cache_.size(); }, [num_tls_hist_cb, num_tls_histograms]() { num_tls_hist_cb(*num_tls_histograms); }); } diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc index 9aa1d65734be..59bdd6d0080b 100644 --- a/test/common/thread_local/thread_local_impl_test.cc +++ b/test/common/thread_local/thread_local_impl_test.cc @@ -138,32 +138,23 @@ class CallbackNotInvokedAfterDeletionTest : public ThreadLocalInstanceImplTest { ThreadStatus thread_status_; }; -TEST_F(CallbackNotInvokedAfterDeletionTest, WithArg) { +TEST_F(CallbackNotInvokedAfterDeletionTest, WithData) { InSequence s; - slot_->runOnAllThreads([this](ThreadLocalObject&) { + slot_->runOnAllThreads([this](OptRef obj) { + EXPECT_TRUE(obj.has_value()); // Callbacks happen on the main thread but not the workers, so track the total. total_callbacks_++; }); - slot_->runOnAllThreads([this](ThreadLocalObject&) { ++thread_status_.thread_local_calls_; }, - [this]() { - // Callbacks happen on the main thread but not the workers. - EXPECT_EQ(thread_status_.thread_local_calls_, 1); - thread_status_.all_threads_complete_ = true; - }); -} - -TEST_F(CallbackNotInvokedAfterDeletionTest, WithoutArg) { - InSequence s; - slot_->runOnAllThreads([this]() { - // Callbacks happen on the main thread but not the workers, so track the total. - total_callbacks_++; - }); - slot_->runOnAllThreads([this]() { ++thread_status_.thread_local_calls_; }, - [this]() { - // Callbacks happen on the main thread but not the workers. - EXPECT_EQ(thread_status_.thread_local_calls_, 1); - thread_status_.all_threads_complete_ = true; - }); + slot_->runOnAllThreads( + [this](OptRef obj) { + EXPECT_TRUE(obj.has_value()); + ++thread_status_.thread_local_calls_; + }, + [this]() { + // Callbacks happen on the main thread but not the workers. + EXPECT_EQ(thread_status_.thread_local_calls_, 1); + thread_status_.all_threads_complete_ = true; + }); } // Test that the update callback is called as expected, for the worker and main threads. @@ -175,7 +166,7 @@ TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) { uint32_t update_called = 0; TestThreadLocalObject& object_ref = setObject(slot); - auto update_cb = [&update_called](ThreadLocalObject&) { ++update_called; }; + auto update_cb = [&update_called](OptRef) { ++update_called; }; EXPECT_CALL(thread_dispatcher_, post(_)); EXPECT_CALL(object_ref, onDestroy()); slot.runOnAllThreads(update_cb); @@ -192,7 +183,6 @@ struct StringSlotObject : public ThreadLocalObject { TEST_F(ThreadLocalInstanceImplTest, TypedUpdateCallback) { InSequence s; - TypedSlot slot(tls_); uint32_t update_called = 0; @@ -202,16 +192,43 @@ TEST_F(ThreadLocalInstanceImplTest, TypedUpdateCallback) { s->str_ = "hello"; return s; }); - EXPECT_EQ("hello", slot.get().str_); + EXPECT_EQ("hello", slot.get()->str_); + + auto update_cb = [&update_called](OptRef s) { + ++update_called; + EXPECT_TRUE(s.has_value()); + s->str_ = "goodbye"; + }; + EXPECT_CALL(thread_dispatcher_, post(_)); + slot.runOnAllThreads(update_cb); + + // Tests a few different ways of getting at the slot data. + EXPECT_EQ("goodbye", slot.get()->str_); + EXPECT_EQ("goodbye", slot->str_); + EXPECT_EQ("goodbye", (*slot).str_); + EXPECT_EQ(2, update_called); // 1 worker, 1 main thread. + + tls_.shutdownGlobalThreading(); + tls_.shutdownThread(); +} + +TEST_F(ThreadLocalInstanceImplTest, NoDataCallback) { + InSequence s; + TypedSlot slot(tls_); + + uint32_t update_called = 0; + EXPECT_CALL(thread_dispatcher_, post(_)); + slot.set([](Event::Dispatcher&) -> std::shared_ptr { return nullptr; }); + EXPECT_FALSE(slot.get().has_value()); - auto update_cb = [&update_called](StringSlotObject& s) { + auto update_cb = [&update_called](OptRef s) { ++update_called; - s.str_ = "goodbye"; + EXPECT_FALSE(s.has_value()); }; EXPECT_CALL(thread_dispatcher_, post(_)); slot.runOnAllThreads(update_cb); - EXPECT_EQ("goodbye", slot.get().str_); + EXPECT_FALSE(slot.get().has_value()); EXPECT_EQ(2, update_called); // 1 worker, 1 main thread. tls_.shutdownGlobalThreading(); @@ -232,7 +249,7 @@ TEST_F(ThreadLocalInstanceImplTest, RunOnAllThreads) { // Ensure that the thread local call back and all_thread_complete call back are called. ThreadStatus thread_status; tlsptr->runOnAllThreads( - [&thread_status](ThreadLocal::ThreadLocalObject&) { ++thread_status.thread_local_calls_; }, + [&thread_status](OptRef) { ++thread_status.thread_local_calls_; }, [&thread_status]() { EXPECT_EQ(thread_status.thread_local_calls_, 2); thread_status.all_threads_complete_ = true; diff --git a/test/mocks/thread_local/mocks.h b/test/mocks/thread_local/mocks.h index 6c5f7a9a3450..7b3097f5e0fa 100644 --- a/test/mocks/thread_local/mocks.h +++ b/test/mocks/thread_local/mocks.h @@ -66,10 +66,6 @@ class MockInstance : public Instance { void runOnAllThreads(const UpdateCb& cb, const Event::PostCb& main_callback) override { parent_.runOnAllThreads([cb, this]() { cb(parent_.data_[index_]); }, main_callback); } - void runOnAllThreads(const Event::PostCb& cb) override { parent_.runOnAllThreads(cb); } - void runOnAllThreads(const Event::PostCb& cb, const Event::PostCb& main_callback) override { - parent_.runOnAllThreads(cb, main_callback); - } void set(InitializeCb cb) override { if (parent_.defer_data) {