diff --git a/src/vt/vrt/collection/balance/baselb/baselb.cc b/src/vt/vrt/collection/balance/baselb/baselb.cc index b10ac48603..f13963d4da 100644 --- a/src/vt/vrt/collection/balance/baselb/baselb.cc +++ b/src/vt/vrt/collection/balance/baselb/baselb.cc @@ -108,9 +108,9 @@ void BaseLB::importProcessorData( vt_debug_print_verbose( lb, node, - "\t {}: importProcessorData: this_load={}, obj={}, load={}, " + "\t {}: importProcessorData: this_load={}, obj={}, home={}, load={}, " "load_milli={}, bin={}\n", - this_node, this_load, obj, load, load_milli, bin + this_node, this_load, obj.id, obj.home_node, load, load_milli, bin ); } @@ -191,8 +191,8 @@ void BaseLB::applyMigrations(TransferVecType const &transfers) { vt_debug_print( lb, node, - "migrateObjectTo, obj_id={}, from={}, to={}, found={}\n", - obj_id, from, to, has_object + "migrateObjectTo, obj_id={}, home={}, from={}, to={}, found={}\n", + obj_id.id, obj_id.home_node, from, to, has_object ); local_migration_count_++; diff --git a/src/vt/vrt/collection/balance/baselb/baselb.h b/src/vt/vrt/collection/balance/baselb/baselb.h index cf3e546ff9..1a9bd24583 100644 --- a/src/vt/vrt/collection/balance/baselb/baselb.h +++ b/src/vt/vrt/collection/balance/baselb/baselb.h @@ -64,7 +64,7 @@ namespace vt { namespace vrt { namespace collection { namespace lb { static constexpr int32_t const default_bin_size = 10; struct BaseLB { - using ObjIDType = balance::ElementIDType; + using ObjIDType = balance::ElementIDStruct; using ObjBinType = int32_t; using ObjBinListType = std::list; using ObjSampleType = std::map; diff --git a/src/vt/vrt/collection/balance/elm_stats.cc b/src/vt/vrt/collection/balance/elm_stats.cc index 98ba37582d..7861a6eb3f 100644 --- a/src/vt/vrt/collection/balance/elm_stats.cc +++ b/src/vt/vrt/collection/balance/elm_stats.cc @@ -88,26 +88,26 @@ void ElementStats::recvComm( } void ElementStats::recvObjData( - ElementIDType pto, ElementIDType tto, - ElementIDType pfrom, ElementIDType tfrom, double bytes, bool bcast + ElementIDStruct pto, + ElementIDStruct pfrom, double bytes, bool bcast ) { - LBCommKey key(LBCommKey::CollectionTag{}, pfrom, tfrom, pto, tto, bcast); + LBCommKey key(LBCommKey::CollectionTag{}, pfrom, pto, bcast); recvComm(key, bytes); } void ElementStats::recvFromNode( - ElementIDType pto, ElementIDType tto, NodeType from, + ElementIDStruct pto, NodeType from, double bytes, bool bcast ) { - LBCommKey key(LBCommKey::NodeToCollectionTag{}, from, pto, tto, bcast); + LBCommKey key(LBCommKey::NodeToCollectionTag{}, from, pto, bcast); recvComm(key, bytes); } void ElementStats::recvToNode( - NodeType to, ElementIDType pfrom, ElementIDType tfrom, + NodeType to, ElementIDStruct pfrom, double bytes, bool bcast ) { - LBCommKey key(LBCommKey::CollectionToNodeTag{}, pfrom, tfrom, to, bcast); + LBCommKey key(LBCommKey::CollectionToNodeTag{}, pfrom, to, bcast); recvComm(key, bytes); } diff --git a/src/vt/vrt/collection/balance/elm_stats.h b/src/vt/vrt/collection/balance/elm_stats.h index e9fb9bacae..f81a8a17f7 100644 --- a/src/vt/vrt/collection/balance/elm_stats.h +++ b/src/vt/vrt/collection/balance/elm_stats.h @@ -71,15 +71,15 @@ struct ElementStats { void addTime(TimeType const& time); void recvComm(LBCommKey key, double bytes); void recvObjData( - ElementIDType to_perm, ElementIDType to_temp, - ElementIDType from_perm, ElementIDType from_temp, double bytes, bool bcast + ElementIDStruct to_perm, + ElementIDStruct from_perm, double bytes, bool bcast ); void recvFromNode( - ElementIDType to_perm, ElementIDType to_temp, NodeType from, + ElementIDStruct to_perm, NodeType from, double bytes, bool bcast ); void recvToNode( - NodeType to, ElementIDType from_perm, ElementIDType from_temp, + NodeType to, ElementIDStruct from_perm, double bytes, bool bcast ); void updatePhase(PhaseType const& inc = 1); diff --git a/src/vt/vrt/collection/balance/gossiplb/gossiplb.cc b/src/vt/vrt/collection/balance/gossiplb/gossiplb.cc index 23d46baaeb..6294951c76 100644 --- a/src/vt/vrt/collection/balance/gossiplb/gossiplb.cc +++ b/src/vt/vrt/collection/balance/gossiplb/gossiplb.cc @@ -420,12 +420,13 @@ void GossipLB::decide() { vt_debug_print( gossiplb, node, "GossipLB::decide: under.size()={}, selected_node={}, selected_load={}," - "obj_id={:x}, obj_load={}, avg={}, this_new_load_={}, " + "obj_id={:x}, home={}, obj_load={}, avg={}, this_new_load_={}, " "criterion={}\n", under.size(), selected_node, selected_load, - obj_id, + obj_id.id, + obj_id.home_node, obj_load, avg, this_new_load_, diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb.cc b/src/vt/vrt/collection/balance/greedylb/greedylb.cc index dff057e621..26075ab14c 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb.cc +++ b/src/vt/vrt/collection/balance/greedylb/greedylb.cc @@ -210,8 +210,9 @@ void GreedyLB::runBalancer( GreedyLB::ObjIDType GreedyLB::objSetNode( NodeType const& node, ObjIDType const& id ) { - auto const new_id = id & 0xFFFFFFFF0000000; - return new_id | node; + auto new_id = id; + new_id.curr_node = node; + return new_id; } void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) { @@ -223,7 +224,7 @@ void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) { "recvObjsDirect: num_recs={}\n", num_recs ); - for (decltype(+num_recs) i = 0; i < num_recs; i++) { + for (decltype(+num_recs.id) i = 0; i < num_recs.id; i++) { auto const to_node = objGetNode(recs[i]); auto const new_obj_id = objSetNode(this_node,recs[i]); vt_debug_print( @@ -274,7 +275,7 @@ void GreedyLB::transferObjs(std::vector&& in_load) { auto ptr_out = reinterpret_cast(ptr); auto const& proc = node_transfer[node]; auto const& rec_size = proc.size(); - *ptr_out = rec_size; + ptr_out->id = rec_size; for (size_t i = 0; i < rec_size; i++) { *(ptr_out + i + 1) = proc[i]; } diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb_types.h b/src/vt/vrt/collection/balance/greedylb/greedylb_types.h index 2b5252100f..38e7415554 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb_types.h +++ b/src/vt/vrt/collection/balance/greedylb/greedylb_types.h @@ -56,7 +56,7 @@ namespace vt { namespace vrt { namespace collection { namespace lb { struct GreedyLBTypes { - using ObjIDType = balance::ElementIDType; + using ObjIDType = balance::ElementIDStruct; using ObjBinType = int32_t; using ObjBinListType = std::list; using ObjSampleType = std::map; @@ -76,7 +76,9 @@ struct GreedyRecord { ObjType getObj() const { return obj_; } private: - GreedyLBTypes::ObjIDType obj_ = 0; + GreedyLBTypes::ObjIDType obj_ = { + 0, uninitialized_destination, uninitialized_destination + }; LoadType load_ = 0.0f; }; diff --git a/src/vt/vrt/collection/balance/hierarchicallb/hierlb_types.h b/src/vt/vrt/collection/balance/hierarchicallb/hierlb_types.h index 98c9efeb12..cc970b0974 100644 --- a/src/vt/vrt/collection/balance/hierarchicallb/hierlb_types.h +++ b/src/vt/vrt/collection/balance/hierarchicallb/hierlb_types.h @@ -56,7 +56,7 @@ namespace vt { namespace vrt { namespace collection { namespace lb { struct HierLBTypes { - using ObjIDType = balance::ElementIDType; + using ObjIDType = balance::ElementIDStruct; using ObjBinType = int32_t; using ObjBinListType = std::list; using ObjSampleType = std::map; diff --git a/src/vt/vrt/collection/balance/lb_comm.h b/src/vt/vrt/collection/balance/lb_comm.h index c71a43ebec..15ac2c68f8 100644 --- a/src/vt/vrt/collection/balance/lb_comm.h +++ b/src/vt/vrt/collection/balance/lb_comm.h @@ -62,8 +62,8 @@ enum struct CommCategory : int8_t { CollectiveToCollectionBcast = 7 }; -inline NodeType objGetNode(ElementIDType const id) { - return id & 0x0000000FFFFFFFF; +inline NodeType objGetNode(ElementIDStruct const id) { + return id.curr_node; } struct LBCommKey { @@ -79,40 +79,36 @@ struct LBCommKey { LBCommKey( CollectionTag, - ElementIDType from, ElementIDType from_temp, - ElementIDType to, ElementIDType to_temp, + ElementIDStruct from, ElementIDStruct to, bool bcast - ) : from_(from), from_temp_(from_temp), to_(to), to_temp_(to_temp), + ) : from_(from), to_(to), cat_(bcast ? CommCategory::Broadcast : CommCategory::SendRecv) { } LBCommKey( CollectionToNodeTag, - ElementIDType from, ElementIDType from_temp, NodeType to, + ElementIDStruct from, NodeType to, bool bcast - ) : from_(from), from_temp_(from_temp), nto_(to), + ) : from_(from), nto_(to), cat_(bcast ? CommCategory::CollectionToNodeBcast : CommCategory::CollectionToNode) { } LBCommKey( NodeToCollectionTag, - NodeType from, ElementIDType to, ElementIDType to_temp, + NodeType from, ElementIDStruct to, bool bcast - ) : to_(to), to_temp_(to_temp), nfrom_(from), + ) : to_(to), nfrom_(from), cat_(bcast ? CommCategory::NodeToCollectionBcast : CommCategory::NodeToCollection) { } - ElementIDType from_ = no_element_id; - ElementIDType from_temp_ = no_element_id; - ElementIDType to_ = no_element_id; - ElementIDType to_temp_ = no_element_id; + ElementIDStruct from_ = { no_element_id, uninitialized_destination, uninitialized_destination }; + ElementIDStruct to_ = { no_element_id, uninitialized_destination, uninitialized_destination }; + ElementIDType edge_id_ = no_element_id; NodeType nfrom_ = uninitialized_destination; NodeType nto_ = uninitialized_destination; CommCategory cat_ = CommCategory::SendRecv; - ElementIDType fromObj() const { return from_; } - ElementIDType toObj() const { return to_; } - ElementIDType fromObjTemp() const { return from_temp_; } - ElementIDType toObjTemp() const { return to_temp_; } + ElementIDStruct fromObj() const { return from_; } + ElementIDStruct toObj() const { return to_; } ElementIDType fromNode() const { return nfrom_; } ElementIDType toNode() const { return nto_; } ElementIDType edgeID() const { return edge_id_; } @@ -120,11 +116,11 @@ struct LBCommKey { bool selfEdge() const { return cat_ == CommCategory::SendRecv and from_ == to_; } bool offNode() const { if (cat_ == CommCategory::SendRecv) { - return objGetNode(from_temp_) != objGetNode(to_temp_); + return objGetNode(from_) != objGetNode(to_); } else if (cat_ == CommCategory::CollectionToNode) { - return objGetNode(from_temp_) != nto_; + return objGetNode(from_) != nto_; } else if (cat_ == CommCategory::NodeToCollection) { - return objGetNode(to_temp_) != nfrom_; + return objGetNode(to_) != nfrom_; } else { return true; } @@ -140,7 +136,7 @@ struct LBCommKey { template void serialize(SerializerT& s) { - s | from_ | to_ | from_temp_ | to_temp_ | nfrom_ | nto_ | cat_ | edge_id_; + s | from_ | to_ | nfrom_ | nto_ | cat_ | edge_id_; } }; @@ -174,8 +170,9 @@ using CommMapType = std::unordered_map; namespace std { -using CommCategoryType = vt::vrt::collection::balance::CommCategory; -using LBCommKeyType = vt::vrt::collection::balance::LBCommKey; +using CommCategoryType = vt::vrt::collection::balance::CommCategory; +using LBCommKeyType = vt::vrt::collection::balance::LBCommKey; +using ElementIDStructType = vt::vrt::collection::balance::ElementIDStruct; template <> struct hash { @@ -189,7 +186,10 @@ struct hash { template <> struct hash { size_t operator()(LBCommKeyType const& in) const { - return std::hash()(in.from_ ^ in.to_ ^ in.nfrom_ ^ in.nto_); + return std::hash()( + std::hash()(in.from_) ^ + std::hash()(in.to_) ^ in.nfrom_ ^ in.nto_ + ); } }; diff --git a/src/vt/vrt/collection/balance/lb_common.cc b/src/vt/vrt/collection/balance/lb_common.cc new file mode 100644 index 0000000000..de38d2b2bc --- /dev/null +++ b/src/vt/vrt/collection/balance/lb_common.cc @@ -0,0 +1,59 @@ +/* +//@HEADER +// ***************************************************************************** +// +// lb_common.cc +// DARMA Toolkit v. 1.0.0 +// DARMA/vt => Virtual Transport +// +// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + +#include "vt/config.h" +#include "vt/vrt/collection/balance/lb_common.h" + +#include + +namespace vt { namespace vrt { namespace collection { namespace balance { + +std::ostream& operator<<( + std::ostream& os, const ::vt::vrt::collection::balance::ElementIDStruct& id +) { + os << "(" << id.id << "," << id.home_node << "," << id.curr_node << ")"; + return os; +} + +}}}} /* end namespace vt::vrt::collection::balance */ diff --git a/src/vt/vrt/collection/balance/lb_common.h b/src/vt/vrt/collection/balance/lb_common.h index 421c78d1e9..703d5c4a97 100644 --- a/src/vt/vrt/collection/balance/lb_common.h +++ b/src/vt/vrt/collection/balance/lb_common.h @@ -50,6 +50,7 @@ #include #include +#include namespace vt { namespace vrt { namespace collection { namespace balance { @@ -78,8 +79,8 @@ std::ostream& operator<<( static constexpr ElementIDType const no_element_id = 0; -using LoadMapType = std::unordered_map; -using SubphaseLoadMapType = std::unordered_map>; +using LoadMapType = std::unordered_map; +using SubphaseLoadMapType = std::unordered_map>; } /* end namespace balance */ namespace lb { diff --git a/src/vt/vrt/collection/balance/model/comm_overhead.cc b/src/vt/vrt/collection/balance/model/comm_overhead.cc index 52b07e27dc..b4cfa7609b 100644 --- a/src/vt/vrt/collection/balance/model/comm_overhead.cc +++ b/src/vt/vrt/collection/balance/model/comm_overhead.cc @@ -62,7 +62,7 @@ void CommOverhead::setLoads(std::unordered_map const* pr ComposedModel::setLoads(proc_load, proc_subphase_load, proc_comm); } -TimeType CommOverhead::getWork(ElementIDType object, PhaseOffset offset) { +TimeType CommOverhead::getWork(ElementIDStruct object, PhaseOffset offset) { auto work = ComposedModel::getWork(object, offset); auto phase = getNumCompletedPhases() + offset.phases; @@ -71,7 +71,7 @@ TimeType CommOverhead::getWork(ElementIDType object, PhaseOffset offset) { TimeType overhead = 0.; for (auto&& c : comm) { // find messages that go off-node and are sent to this object - if (c.first.offNode() and c.first.toObjTemp() == object) { + if (c.first.offNode() and c.first.toObj() == object) { overhead += per_msg_weight_ * c.second.messages; overhead += per_byte_weight_ * c.second.bytes; } diff --git a/src/vt/vrt/collection/balance/model/comm_overhead.h b/src/vt/vrt/collection/balance/model/comm_overhead.h index 35ad9360f1..590c9c68a3 100644 --- a/src/vt/vrt/collection/balance/model/comm_overhead.h +++ b/src/vt/vrt/collection/balance/model/comm_overhead.h @@ -70,7 +70,7 @@ struct CommOverhead : public ComposedModel { std::unordered_map const* proc_subphase_load, std::unordered_map const* proc_comm) override; - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; private: std::unordered_map const* proc_comm_; /**< Underlying comm data */ diff --git a/src/vt/vrt/collection/balance/model/composed_model.cc b/src/vt/vrt/collection/balance/model/composed_model.cc index 0efc7df909..64d2162220 100644 --- a/src/vt/vrt/collection/balance/model/composed_model.cc +++ b/src/vt/vrt/collection/balance/model/composed_model.cc @@ -56,7 +56,7 @@ void ComposedModel::updateLoads(PhaseType last_completed_phase) { base_->updateLoads(last_completed_phase); } -TimeType ComposedModel::getWork(ElementIDType object, PhaseOffset when) { +TimeType ComposedModel::getWork(ElementIDStruct object, PhaseOffset when) { return base_->getWork(object, when); } diff --git a/src/vt/vrt/collection/balance/model/composed_model.h b/src/vt/vrt/collection/balance/model/composed_model.h index fd99e27f78..e6cf67785d 100644 --- a/src/vt/vrt/collection/balance/model/composed_model.h +++ b/src/vt/vrt/collection/balance/model/composed_model.h @@ -71,7 +71,7 @@ class ComposedModel : public LoadModel void updateLoads(PhaseType last_completed_phase) override; - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; unsigned int getNumPastPhasesNeeded(unsigned int look_back) override; ObjectIterator begin() override; diff --git a/src/vt/vrt/collection/balance/model/linear_model.cc b/src/vt/vrt/collection/balance/model/linear_model.cc index 7a5b56f180..99a78b3f18 100644 --- a/src/vt/vrt/collection/balance/model/linear_model.cc +++ b/src/vt/vrt/collection/balance/model/linear_model.cc @@ -49,7 +49,7 @@ namespace vt { namespace vrt { namespace collection { namespace balance { -TimeType LinearModel::getWork(ElementIDType object, PhaseOffset when) { +TimeType LinearModel::getWork(ElementIDStruct object, PhaseOffset when) { using util::stats::LinearRegression; // Retrospective queries don't call for a prediction diff --git a/src/vt/vrt/collection/balance/model/linear_model.h b/src/vt/vrt/collection/balance/model/linear_model.h index 18554d41ca..19f4532cb7 100644 --- a/src/vt/vrt/collection/balance/model/linear_model.h +++ b/src/vt/vrt/collection/balance/model/linear_model.h @@ -70,7 +70,7 @@ struct LinearModel : ComposedModel { past_len_(in_past_len) { } - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; unsigned int getNumPastPhasesNeeded(unsigned int look_back) override; private: diff --git a/src/vt/vrt/collection/balance/model/load_model.h b/src/vt/vrt/collection/balance/model/load_model.h index caf47d254b..dacef8d328 100644 --- a/src/vt/vrt/collection/balance/model/load_model.h +++ b/src/vt/vrt/collection/balance/model/load_model.h @@ -134,7 +134,7 @@ class LoadModel * The `updateLoads` method must have been called before any call to * this. */ - virtual TimeType getWork(ElementIDType object, PhaseOffset when) = 0; + virtual TimeType getWork(ElementIDStruct object, PhaseOffset when) = 0; /** * \brief Compute how many phases of past load statistics need to be diff --git a/src/vt/vrt/collection/balance/model/multiple_phases.cc b/src/vt/vrt/collection/balance/model/multiple_phases.cc index e4c94f242c..224b33b85f 100644 --- a/src/vt/vrt/collection/balance/model/multiple_phases.cc +++ b/src/vt/vrt/collection/balance/model/multiple_phases.cc @@ -46,7 +46,7 @@ namespace vt { namespace vrt { namespace collection { namespace balance { -TimeType MultiplePhases::getWork(ElementIDType object, PhaseOffset when) { +TimeType MultiplePhases::getWork(ElementIDStruct object, PhaseOffset when) { // Retrospective queries don't call for a prediction if (when.phases < 0) return ComposedModel::getWork(object, when); diff --git a/src/vt/vrt/collection/balance/model/multiple_phases.h b/src/vt/vrt/collection/balance/model/multiple_phases.h index 6b1d6ed5d9..085d6d5467 100644 --- a/src/vt/vrt/collection/balance/model/multiple_phases.h +++ b/src/vt/vrt/collection/balance/model/multiple_phases.h @@ -80,7 +80,7 @@ struct MultiplePhases : ComposedModel { , future_phase_block_size_(in_future_phase_block_size) { } - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; private: int future_phase_block_size_ = 0; diff --git a/src/vt/vrt/collection/balance/model/naive_persistence.cc b/src/vt/vrt/collection/balance/model/naive_persistence.cc index d45596d7e7..f79c75773b 100644 --- a/src/vt/vrt/collection/balance/model/naive_persistence.cc +++ b/src/vt/vrt/collection/balance/model/naive_persistence.cc @@ -51,7 +51,7 @@ NaivePersistence::NaivePersistence(std::shared_ptr base) : ComposedModel(base) { } -TimeType NaivePersistence::getWork(ElementIDType object, PhaseOffset offset) +TimeType NaivePersistence::getWork(ElementIDStruct object, PhaseOffset offset) { if (offset.phases >= 0) offset.phases = -1; diff --git a/src/vt/vrt/collection/balance/model/naive_persistence.h b/src/vt/vrt/collection/balance/model/naive_persistence.h index 251e51b2b9..9ecfc50f29 100644 --- a/src/vt/vrt/collection/balance/model/naive_persistence.h +++ b/src/vt/vrt/collection/balance/model/naive_persistence.h @@ -61,7 +61,7 @@ struct NaivePersistence : public ComposedModel { * \param[in] base: The source of underlying load numbers to return; must not be null */ explicit NaivePersistence(std::shared_ptr base); - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; unsigned int getNumPastPhasesNeeded(unsigned int look_back) override; }; // class NaivePersistence diff --git a/src/vt/vrt/collection/balance/model/norm.cc b/src/vt/vrt/collection/balance/model/norm.cc index 79de55c980..1407d04fef 100644 --- a/src/vt/vrt/collection/balance/model/norm.cc +++ b/src/vt/vrt/collection/balance/model/norm.cc @@ -56,7 +56,7 @@ Norm::Norm(std::shared_ptr base, double power) vtAssert(power >= 0.0, "Reciprocal loads make no sense"); } -TimeType Norm::getWork(ElementIDType object, PhaseOffset offset) +TimeType Norm::getWork(ElementIDStruct object, PhaseOffset offset) { if (offset.subphase != PhaseOffset::WHOLE_PHASE) return ComposedModel::getWork(object, offset); diff --git a/src/vt/vrt/collection/balance/model/norm.h b/src/vt/vrt/collection/balance/model/norm.h index 9fdc9c189b..67ae770c35 100644 --- a/src/vt/vrt/collection/balance/model/norm.h +++ b/src/vt/vrt/collection/balance/model/norm.h @@ -65,7 +65,7 @@ class Norm : public ComposedModel { */ Norm(std::shared_ptr base, double power); - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; private: const double power_; diff --git a/src/vt/vrt/collection/balance/model/per_collection.cc b/src/vt/vrt/collection/balance/model/per_collection.cc index 34f4bec8bf..68d994b8e4 100644 --- a/src/vt/vrt/collection/balance/model/per_collection.cc +++ b/src/vt/vrt/collection/balance/model/per_collection.cc @@ -70,7 +70,7 @@ void PerCollection::updateLoads(PhaseType last_completed_phase) { ComposedModel::updateLoads(last_completed_phase); } -TimeType PerCollection::getWork(ElementIDType object, PhaseOffset when) { +TimeType PerCollection::getWork(ElementIDStruct object, PhaseOffset when) { // See if some specific model has been given for the object in question auto mi = models_.find(theNodeStats()->getCollectionProxyForElement(object)); if (mi != models_.end()) diff --git a/src/vt/vrt/collection/balance/model/per_collection.h b/src/vt/vrt/collection/balance/model/per_collection.h index 33c843c324..54a434c5f7 100644 --- a/src/vt/vrt/collection/balance/model/per_collection.h +++ b/src/vt/vrt/collection/balance/model/per_collection.h @@ -80,7 +80,7 @@ struct PerCollection : public ComposedModel void updateLoads(PhaseType last_completed_phase) override; - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; unsigned int getNumPastPhasesNeeded(unsigned int look_back) override; private: diff --git a/src/vt/vrt/collection/balance/model/persistence_median_last_n.cc b/src/vt/vrt/collection/balance/model/persistence_median_last_n.cc index 4a06f7d56d..845074918b 100644 --- a/src/vt/vrt/collection/balance/model/persistence_median_last_n.cc +++ b/src/vt/vrt/collection/balance/model/persistence_median_last_n.cc @@ -55,7 +55,7 @@ PersistenceMedianLastN::PersistenceMedianLastN(std::shared_ptr base, vtAssert(n > 0, "Cannot take a median over no phases"); } -TimeType PersistenceMedianLastN::getWork(ElementIDType object, PhaseOffset when) +TimeType PersistenceMedianLastN::getWork(ElementIDStruct object, PhaseOffset when) { // Retrospective queries don't call for a prospective calculation if (when.phases < 0) diff --git a/src/vt/vrt/collection/balance/model/persistence_median_last_n.h b/src/vt/vrt/collection/balance/model/persistence_median_last_n.h index 877a49cdcd..a1af7b6220 100644 --- a/src/vt/vrt/collection/balance/model/persistence_median_last_n.h +++ b/src/vt/vrt/collection/balance/model/persistence_median_last_n.h @@ -66,7 +66,7 @@ struct PersistenceMedianLastN : public ComposedModel */ PersistenceMedianLastN(std::shared_ptr base, unsigned int n); - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; unsigned int getNumPastPhasesNeeded(unsigned int look_back) override; private: diff --git a/src/vt/vrt/collection/balance/model/raw_data.cc b/src/vt/vrt/collection/balance/model/raw_data.cc index 85e388358b..783aa6bf6d 100644 --- a/src/vt/vrt/collection/balance/model/raw_data.cc +++ b/src/vt/vrt/collection/balance/model/raw_data.cc @@ -83,7 +83,7 @@ int RawData::getNumSubphases() { return subphases.size(); } -TimeType RawData::getWork(ElementIDType object, PhaseOffset offset) +TimeType RawData::getWork(ElementIDStruct object, PhaseOffset offset) { vtAssert(offset.phases < 0, "RawData makes no predictions. Compose with NaivePersistence or some longer-range forecasting model as needed"); diff --git a/src/vt/vrt/collection/balance/model/raw_data.h b/src/vt/vrt/collection/balance/model/raw_data.h index 4537b3efc1..a54fe04a09 100644 --- a/src/vt/vrt/collection/balance/model/raw_data.h +++ b/src/vt/vrt/collection/balance/model/raw_data.h @@ -60,7 +60,7 @@ namespace vt { namespace vrt { namespace collection { namespace balance { struct RawData : public LoadModel { RawData() = default; void updateLoads(PhaseType last_completed_phase) override; - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; void setLoads(std::unordered_map const* proc_load, std::unordered_map const* proc_subphase_load, diff --git a/src/vt/vrt/collection/balance/model/select_subphases.cc b/src/vt/vrt/collection/balance/model/select_subphases.cc index 0e2bde0e2f..606b46caff 100644 --- a/src/vt/vrt/collection/balance/model/select_subphases.cc +++ b/src/vt/vrt/collection/balance/model/select_subphases.cc @@ -59,7 +59,7 @@ SelectSubphases::SelectSubphases(std::shared_ptr base, std::vector base, std::vector subphases); - TimeType getWork(ElementIDType object, PhaseOffset when) override; + TimeType getWork(ElementIDStruct object, PhaseOffset when) override; int getNumSubphases() override; std::vector subphases_; diff --git a/src/vt/vrt/collection/balance/node_stats.cc b/src/vt/vrt/collection/balance/node_stats.cc index 3e34e05061..216e79a698 100644 --- a/src/vt/vrt/collection/balance/node_stats.cc +++ b/src/vt/vrt/collection/balance/node_stats.cc @@ -70,28 +70,12 @@ void NodeStats::setProxy(objgroup::proxy::Proxy in_proxy) { return ptr; } -ElementIDType NodeStats::tempToPerm(ElementIDType temp_id) const { - auto iter = node_temp_to_perm_.find(temp_id); - if (iter == node_temp_to_perm_.end()) { - return no_element_id; - } - return iter->second; -} - -ElementIDType NodeStats::permToTemp(ElementIDType perm_id) const { - auto iter = node_perm_to_temp_.find(perm_id); - if (iter == node_perm_to_temp_.end()) { - return no_element_id; - } - return iter->second; -} - -bool NodeStats::hasObjectToMigrate(ElementIDType obj_id) const { +bool NodeStats::hasObjectToMigrate(ElementIDStruct obj_id) const { auto iter = node_migrate_.find(obj_id); return iter != node_migrate_.end(); } -bool NodeStats::migrateObjTo(ElementIDType obj_id, NodeType to_node) { +bool NodeStats::migrateObjTo(ElementIDStruct obj_id, NodeType to_node) { auto iter = node_migrate_.find(obj_id); if (iter == node_migrate_.end()) { return false; @@ -126,26 +110,10 @@ void NodeStats::clearStats() { NodeStats::node_data_.clear(); NodeStats::node_subphase_data_.clear(); NodeStats::node_migrate_.clear(); - NodeStats::node_temp_to_perm_.clear(); - NodeStats::node_perm_to_temp_.clear(); next_elm_ = 1; } void NodeStats::startIterCleanup(PhaseType phase, unsigned int look_back) { - // TODO: Add in subphase support here too - - // Convert the temp ID node_data_ for the last iteration into perm ID for - // stats output - auto const prev_data = std::move(node_data_[phase]); - std::unordered_map new_data; - for (auto& elm : prev_data) { - auto iter = node_temp_to_perm_.find(elm.first); - vtAssert(iter != node_temp_to_perm_.end(), "Temp ID must exist"); - auto perm_id = iter->second; - new_data[perm_id] = elm.second; - } - node_data_[phase] = std::move(new_data); - if (phase >= look_back) { node_data_.erase(phase - look_back); node_subphase_data_.erase(phase - look_back); @@ -153,17 +121,16 @@ void NodeStats::startIterCleanup(PhaseType phase, unsigned int look_back) { node_subphase_comm_.erase(phase - look_back); } - // Create migrate lambdas and temp to perm map since LB is complete + // Clear migrate lambdas and proxy lookup since LB is complete NodeStats::node_migrate_.clear(); - NodeStats::node_temp_to_perm_.clear(); - NodeStats::node_perm_to_temp_.clear(); node_collection_lookup_.clear(); } -ElementIDType NodeStats::getNextElm() { +ElementIDStruct NodeStats::getNextElm() { auto const& this_node = theContext()->getNode(); - auto elm = next_elm_++; - return (elm << 32) | this_node; + auto id = (next_elm_++ << 32) | this_node; + ElementIDStruct elm{id, this_node, this_node}; + return elm; } void NodeStats::initialize() { @@ -227,15 +194,15 @@ getRecvSendDirection(CommKeyType const& comm) { switch (comm.cat_) { case CommCategory::SendRecv: case CommCategory::Broadcast: - return std::make_pair(comm.toObj(), comm.fromObj()); + return std::make_pair(comm.toObj().id, comm.fromObj().id); case CommCategory::NodeToCollection: case CommCategory::NodeToCollectionBcast: - return std::make_pair(comm.toObj(), comm.fromNode()); + return std::make_pair(comm.toObj().id, comm.fromNode()); case CommCategory::CollectionToNode: case CommCategory::CollectionToNodeBcast: - return std::make_pair(comm.toNode(), comm.fromObj()); + return std::make_pair(comm.toNode(), comm.fromObj().id); // Comm stats are not recorded for collective bcast // this case is just to avoid warning of not handled enum @@ -258,12 +225,12 @@ void NodeStats::outputStatsForPhase(PhaseType phase) { vt_print(lb, "NodeStats::outputStatsForPhase: phase={}\n", phase); for (auto&& elm : node_data_.at(phase)) { - ElementIDType id = elm.first; + ElementIDStruct id = elm.first; TimeType time = elm.second; const auto& subphase_times = node_subphase_data_.at(phase)[id]; size_t subphases = subphase_times.size(); - auto obj_str = fmt::format("{},{},{},{},[", phase, id, time, subphases); + auto obj_str = fmt::format("{},{},{},{},[", phase, id.id, time, subphases); for (size_t s = 0; s < subphases; s++) { if (s > 0) { @@ -294,38 +261,37 @@ void NodeStats::outputStatsForPhase(PhaseType phase) { fflush(stats_file_); } -ElementIDType NodeStats::addNodeStats( +ElementIDStruct NodeStats::addNodeStats( Migratable* col_elm, PhaseType const& phase, TimeType const& time, std::vector const& subphase_time, CommMapType const& comm, std::vector const& subphase_comm ) { - // A new temp ID gets assigned when a object is migrated into a node + // The ID struct is modified when a object is migrated into a node - auto const temp_id = col_elm->temp_elm_id_; - auto const perm_id = col_elm->stats_elm_id_; + auto const obj_id = col_elm->elm_id_; vt_debug_print( lb, node, - "NodeStats::addNodeStats: temp_id={}, perm_id={}, phase={}, subphases={}, load={}\n", - temp_id, perm_id, phase, subphase_time.size(), time + "NodeStats::addNodeStats: obj_id={}, phase={}, subphases={}, load={}\n", + obj_id, phase, subphase_time.size(), time ); auto &phase_data = node_data_[phase]; - auto elm_iter = phase_data.find(temp_id); + auto elm_iter = phase_data.find(obj_id); vtAssert(elm_iter == phase_data.end(), "Must not exist"); phase_data.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(obj_id), std::forward_as_tuple(time) ); auto &subphase_data = node_subphase_data_[phase]; - auto elm_subphase_iter = subphase_data.find(temp_id); + auto elm_subphase_iter = subphase_data.find(obj_id); vtAssert(elm_subphase_iter == subphase_data.end(), "Must not exist"); subphase_data.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(obj_id), std::forward_as_tuple(subphase_time) ); @@ -341,14 +307,11 @@ ElementIDType NodeStats::addNodeStats( } } - node_temp_to_perm_[temp_id] = perm_id; - node_perm_to_temp_[perm_id] = temp_id; - - auto migrate_iter = node_migrate_.find(temp_id); + auto migrate_iter = node_migrate_.find(obj_id); if (migrate_iter == node_migrate_.end()) { node_migrate_.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(obj_id), std::forward_as_tuple([col_elm](NodeType node){ col_elm->migrate(node); }) @@ -356,15 +319,15 @@ ElementIDType NodeStats::addNodeStats( } auto const col_proxy = col_elm->getProxy(); - node_collection_lookup_[temp_id] = col_proxy; + node_collection_lookup_[obj_id] = col_proxy; - return temp_id; + return obj_id; } VirtualProxyType NodeStats::getCollectionProxyForElement( - ElementIDType temp_id + ElementIDStruct obj_id ) const { - auto iter = node_collection_lookup_.find(temp_id); + auto iter = node_collection_lookup_.find(obj_id); if (iter == node_collection_lookup_.end()) { return no_vrt_proxy; } diff --git a/src/vt/vrt/collection/balance/node_stats.h b/src/vt/vrt/collection/balance/node_stats.h index 0e3689071f..fc569dbd82 100644 --- a/src/vt/vrt/collection/balance/node_stats.h +++ b/src/vt/vrt/collection/balance/node_stats.h @@ -109,9 +109,9 @@ struct NodeStats : runtime::component::Component { * \param[in] time the time the object took * \param[in] comm the comm graph for the object * - * \return the temporary ID for the object assigned for this phase + * \return the ID struct for the object assigned for this phase */ - ElementIDType addNodeStats( + ElementIDStruct addNodeStats( Migratable* col_elm, PhaseType const& phase, TimeType const& time, std::vector const& subphase_time, @@ -124,7 +124,7 @@ struct NodeStats : runtime::component::Component { void clearStats(); /** - * \internal \brief Cleanup after LB runs; convert temporary to permanent IDs + * \internal \brief Cleanup after LB runs */ void startIterCleanup(PhaseType phase, unsigned int look_back); @@ -162,7 +162,7 @@ struct NodeStats : runtime::component::Component { /** * \internal \brief Generate the next object element ID for LB */ - ElementIDType getNextElm(); + ElementIDStruct getNextElm(); /** * \internal \brief Get stored object loads @@ -195,50 +195,31 @@ struct NodeStats : runtime::component::Component { /** * \internal \brief Test if this node has an object to migrate * - * \param[in] obj_id the object temporary ID + * \param[in] obj_id the object ID struct * * \return whether this node has the object */ - bool hasObjectToMigrate(ElementIDType obj_id) const; + bool hasObjectToMigrate(ElementIDStruct obj_id) const; /** * \internal \brief Migrate an local object to another node * - * \param[in] obj_id the object temporary ID + * \param[in] obj_id the object ID struct * \param[in] to_node the node to migrate to * * \return whether this node has the object */ - bool migrateObjTo(ElementIDType obj_id, NodeType to_node); - - /** - * \internal \brief Convert temporary element ID to permanent Returns - * \c no_element_id if not found. - * \param[in] temp_id temporary ID - * - * \return permanent ID - */ - ElementIDType tempToPerm(ElementIDType temp_id) const; - - /** - * \internal \brief Convert permanent element ID to temporary. Returns - * \c no_element_id if not found. - * - * \param[in] perm_id permanent ID - * - * \return temporary ID - */ - ElementIDType permToTemp(ElementIDType perm_id) const; + bool migrateObjTo(ElementIDStruct obj_id, NodeType to_node); /** * \internal \brief Get the collection proxy for a given element ID * - * \param[in] temp_id the temporary ID for the element for a given phase + * \param[in] obj_id the ID struct for the element * * \return the virtual proxy if the element is part of the collection; * otherwise \c no_vrt_proxy */ - VirtualProxyType getCollectionProxyForElement(ElementIDType temp_id) const; + VirtualProxyType getCollectionProxyForElement(ElementIDStruct obj_id) const; void initialize() override; void finalize() override; @@ -262,13 +243,9 @@ struct NodeStats : runtime::component::Component { /// Node subphase timings for each local object std::unordered_map node_subphase_data_; /// Local migration type-free lambdas for each object - std::unordered_map node_migrate_; - /// Map of temporary ID to permanent ID - std::unordered_map node_temp_to_perm_; - /// Map of permanent ID to temporary ID - std::unordered_map node_perm_to_temp_; - /// Map from element temporary ID to the collection's virtual proxy (untyped) - std::unordered_map node_collection_lookup_; + std::unordered_map node_migrate_; + /// Map from element ID to the collection's virtual proxy (untyped) + std::unordered_map node_collection_lookup_; /// Node communication graph for each local object std::unordered_map node_comm_; /// Node communication graph for each subphase diff --git a/src/vt/vrt/collection/balance/randomlb/randomlb.cc b/src/vt/vrt/collection/balance/randomlb/randomlb.cc index 598b7e7780..4251b7486c 100644 --- a/src/vt/vrt/collection/balance/randomlb/randomlb.cc +++ b/src/vt/vrt/collection/balance/randomlb/randomlb.cc @@ -95,8 +95,8 @@ void RandomLB::runLB() { if (to_node != this_node) { vt_debug_print( lb, node, - "RandomLB: migrating obj={:x} from={} to={}\n", - *it, this_node, to_node + "RandomLB: migrating obj={:x} home={} from={} to={}\n", + it->id, it->home_node, this_node, to_node ); migrateObjectTo(*it, to_node); } diff --git a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc index 4848c32d45..2b285d34c8 100644 --- a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc +++ b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc @@ -58,11 +58,9 @@ void StatsMapLB::init(objgroup::proxy::Proxy in_proxy) { void StatsMapLB::runLB() { auto const& myNewList = theStatsReader()->getMoveList(phase_); for (size_t in = 0; in < myNewList.size(); in += 2) { - auto temp_id = theNodeStats()->permToTemp(myNewList[in]); - - vtAssert(temp_id != balance::no_element_id, "Must have valid ID here"); - - migrateObjectTo(temp_id, myNewList[in+1]); + auto this_node = theContext()->getNode(); + ObjIDType id{myNewList[in], this_node, this_node}; + migrateObjectTo(id, myNewList[in+1]); } theStatsReader()->clearMoveList(phase_); diff --git a/src/vt/vrt/collection/manager.cc b/src/vt/vrt/collection/manager.cc index 0ffc8ad153..62ec78e37b 100644 --- a/src/vt/vrt/collection/manager.cc +++ b/src/vt/vrt/collection/manager.cc @@ -84,19 +84,14 @@ getDispatcher(auto_registry::AutoHandlerType const han) { return theCollection()->getDispatcher(han); } -balance::ElementIDType CollectionManager::getCurrentContextPerm() const { - return cur_context_perm_elm_id_; -} - -balance::ElementIDType CollectionManager::getCurrentContextTemp() const { - return cur_context_temp_elm_id_; +balance::ElementIDStruct CollectionManager::getCurrentContext() const { + return cur_context_elm_id_; } void CollectionManager::setCurrentContext( - balance::ElementIDType perm, balance::ElementIDType temp + balance::ElementIDStruct elm ) { - cur_context_perm_elm_id_ = perm; - cur_context_temp_elm_id_ = temp; + cur_context_elm_id_ = elm; } void CollectionManager::schedule(ActionType action) { diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 3035eb8b23..450d68de6c 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -1769,29 +1769,20 @@ struct CollectionManager static BcastBufferType broadcasts_; /** - * \internal \brief Get the current LB temporary element ID based on handler + * \internal \brief Get the current LB element ID struct based on handler * context * * \return the element ID */ - balance::ElementIDType getCurrentContextTemp() const; - - /** - * \internal \brief Get the current LB permanent element ID based on handler - * context - * - * \return the element ID - */ - balance::ElementIDType getCurrentContextPerm() const; + balance::ElementIDStruct getCurrentContext() const; /** * \internal \brief Set the current LB element ID * - * \param[in] elm_perm permanent ID - * \param[in] elm_temp temporary ID + * \param[in] elm ID struct */ void setCurrentContext( - balance::ElementIDType elm_perm, balance::ElementIDType elm_temp + balance::ElementIDStruct elm ); private: @@ -1953,8 +1944,9 @@ struct CollectionManager std::unordered_map user_insert_action_ = {}; std::unordered_map dist_tag_id_ = {}; std::unordered_map release_lb_ = {}; - balance::ElementIDType cur_context_temp_elm_id_ = balance::no_element_id; - balance::ElementIDType cur_context_perm_elm_id_ = balance::no_element_id; + balance::ElementIDStruct cur_context_elm_id_ = { + balance::no_element_id, uninitialized_destination, uninitialized_destination + }; }; // These are static variables in class templates because Intel 18 diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index fc366d1c9d..2c7b08e5b8 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -470,24 +470,22 @@ template // Set the current context (element ID) that is executing (having a message // delivered). This is used for load balancing to build the communication // graph - auto const perm_elm_id = base->getElmID(); - auto const temp_elm_id = base->getTempID(); - auto const perm_prev_elm = theCollection()->getCurrentContextPerm(); - auto const temp_prev_elm = theCollection()->getCurrentContextTemp(); + auto const elm_id = base->getElmID(); + auto const prev_elm = theCollection()->getCurrentContext(); - theCollection()->setCurrentContext(perm_elm_id, temp_elm_id); + theCollection()->setCurrentContext(elm_id); vt_debug_print( vrt_coll, node, - "collectionBcastHandler: current context: perm={}, temp={}\n", - perm_elm_id, temp_elm_id + "collectionBcastHandler: current context: elm={}\n", + elm_id ); std::unique_ptr listener = std::make_unique( [&](NodeType dest, MsgSizeType size, bool bcast){ auto& stats = base->getStats(); - stats.recvToNode(dest, perm_elm_id, temp_elm_id, size, bcast); + stats.recvToNode(dest, elm_id, size, bcast); } ); theMsg()->addSendListener(std::move(listener)); @@ -508,7 +506,7 @@ template theMsg()->clearListeners(); // Unset the element ID context - theCollection()->setCurrentContext(perm_prev_elm, temp_prev_elm); + theCollection()->setCurrentContext(prev_elm); if (msg->lbLiteInstrument()) { auto& stats = base->getStats(); @@ -695,24 +693,22 @@ template // Set the current context (element ID) that is executing (having a message // delivered). This is used for load balancing to build the communication // graph - auto const perm_elm_id = col_ptr->getElmID(); - auto const temp_elm_id = col_ptr->getTempID(); - auto const perm_prev_elm = theCollection()->getCurrentContextPerm(); - auto const temp_prev_elm = theCollection()->getCurrentContextTemp(); + auto const elm_id = col_ptr->getElmID(); + auto const prev_elm = theCollection()->getCurrentContext(); - theCollection()->setCurrentContext(perm_elm_id, temp_elm_id); + theCollection()->setCurrentContext(elm_id); vt_debug_print( vrt_coll, node, - "collectionMsgTypedHandler: current context: perm={}, temp={}\n", - perm_elm_id, temp_elm_id + "collectionMsgTypedHandler: current context: elm={}\n", + elm_id ); std::unique_ptr listener = std::make_unique( [&](NodeType dest, MsgSizeType size, bool bcast){ auto& stats = col_ptr->getStats(); - stats.recvToNode(dest, perm_elm_id, temp_elm_id, size, bcast); + stats.recvToNode(dest, elm_id, size, bcast); } ); theMsg()->addSendListener(std::move(listener)); @@ -734,7 +730,7 @@ template #if vt_check_enabled(lblite) theMsg()->clearListeners(); - theCollection()->setCurrentContext(perm_prev_elm, temp_prev_elm); + theCollection()->setCurrentContext(prev_elm); if (col_msg->lbLiteInstrument()) { auto& stats = col_ptr->getStats(); @@ -746,33 +742,31 @@ template template /*static*/ void CollectionManager::recordStats(ColT* col_ptr, MsgT* msg) { auto const pto = col_ptr->getElmID(); - auto const tto = col_ptr->getTempID(); auto const pfrom = msg->getElm(); - auto const tfrom = msg->getElmTemp(); auto& stats = col_ptr->getStats(); auto const msg_size = serialization::MsgSizer::get(msg); auto const cat = msg->getCat(); vt_debug_print( vrt_coll, node, - "recordStats: receive msg: perm(to={}, from={}), temp(to={}, from={})" + "recordStats: receive msg: elm(to={}, from={})," " no={}, size={}, category={}\n", - pto, pfrom, tto, tfrom, balance::no_element_id, msg_size, + pto, pfrom, balance::no_element_id, msg_size, static_cast::type>(cat) ); if ( cat == balance::CommCategory::SendRecv or cat == balance::CommCategory::Broadcast ) { - vtAssert(pfrom != balance::no_element_id, "Must not be no element ID"); + vtAssert(pfrom.id != balance::no_element_id, "Must not be no element ID"); bool bcast = cat == balance::CommCategory::SendRecv ? false : true; - stats.recvObjData(pto, tto, pfrom, tfrom, msg_size, bcast); + stats.recvObjData(pto, pfrom, msg_size, bcast); } else if ( cat == balance::CommCategory::NodeToCollection or cat == balance::CommCategory::NodeToCollectionBcast ) { bool bcast = cat == balance::CommCategory::NodeToCollection ? false : true; auto nfrom = msg->getFromNode(); - stats.recvFromNode(pto, tto, nfrom, msg_size, bcast); + stats.recvFromNode(pto, nfrom, msg_size, bcast); } } @@ -1070,17 +1064,16 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler( # if vt_check_enabled(lblite) msg->setLBLiteInstrument(instrument); - auto const temp_elm_id = getCurrentContextTemp(); - auto const perm_elm_id = getCurrentContextPerm(); + auto const elm_id = getCurrentContext(); vt_debug_print( vrt_coll, node, - "broadcasting msg: LB current elm context perm={}, temp={}\n", - perm_elm_id, temp_elm_id + "broadcasting msg: LB current elm context elm={}\n", + elm_id ); - if (perm_elm_id != balance::no_element_id) { - msg->setElm(perm_elm_id, temp_elm_id); + if (elm_id.id != balance::no_element_id) { + msg->setElm(elm_id); msg->setCat(balance::CommCategory::Broadcast); } else { msg->setCat(balance::CommCategory::NodeToCollection); @@ -1377,17 +1370,16 @@ messaging::PendingSend CollectionManager::sendMsgUntypedHandler( # if vt_check_enabled(lblite) msg->setLBLiteInstrument(true); - auto const temp_elm_id = getCurrentContextTemp(); - auto const perm_elm_id = getCurrentContextPerm(); + auto const elm_id = getCurrentContext(); vt_debug_print( vrt_coll, node, - "sending msg: LB current elm context perm={}, temp={}\n", - perm_elm_id, temp_elm_id + "sending msg: LB current elm context elm={}\n", + elm_id ); - if (perm_elm_id != balance::no_element_id) { - msg->setElm(perm_elm_id, temp_elm_id); + if (elm_id.id != balance::no_element_id) { + msg->setElm(elm_id); msg->setCat(balance::CommCategory::SendRecv); } else { msg->setCat(balance::CommCategory::NodeToCollection); @@ -2730,8 +2722,9 @@ MigrateStatus CollectionManager::migrateIn( CollectionProxy(proxy).operator()(idx); - // Always assign a new temp element ID for LB statistic tracking - vrt_elm_ptr->temp_elm_id_ = theNodeStats()->getNextElm(); + // Always update the element ID struct for LB statistic tracking + auto const& this_node = theContext()->getNode(); + vrt_elm_ptr->elm_id_.curr_node = this_node; bool const is_static = ColT::isStaticSized(); diff --git a/src/vt/vrt/collection/messages/user.h b/src/vt/vrt/collection/messages/user.h index 17d07c1bf9..bf143e9c8e 100644 --- a/src/vt/vrt/collection/messages/user.h +++ b/src/vt/vrt/collection/messages/user.h @@ -115,9 +115,8 @@ struct CollectionMessage : RoutedMessageType { #if vt_check_enabled(lblite) bool lbLiteInstrument() const; void setLBLiteInstrument(bool const& val); - balance::ElementIDType getElm() const; - balance::ElementIDType getElmTemp() const; - void setElm(balance::ElementIDType perm, balance::ElementIDType temp); + balance::ElementIDStruct getElm() const; + void setElm(balance::ElementIDStruct elm); balance::CommCategory getCat() const; void setCat(balance::CommCategory cat); #endif @@ -142,8 +141,9 @@ struct CollectionMessage : RoutedMessageType { * (sendMsg,broadcastMsg) they are automatically instrumented */ bool lb_lite_instrument_ = false; - balance::ElementIDType elm_perm_ = 0; - balance::ElementIDType elm_temp_ = 0; + balance::ElementIDStruct elm_ = { + 0, uninitialized_destination, uninitialized_destination + }; balance::CommCategory cat_ = balance::CommCategory::SendRecv; #endif diff --git a/src/vt/vrt/collection/messages/user.impl.h b/src/vt/vrt/collection/messages/user.impl.h index 57f756067c..2f592a5397 100644 --- a/src/vt/vrt/collection/messages/user.impl.h +++ b/src/vt/vrt/collection/messages/user.impl.h @@ -124,8 +124,7 @@ void CollectionMessage::serialize(SerializerT& s) { #if vt_check_enabled(lblite) s | lb_lite_instrument_; - s | elm_perm_; - s | elm_temp_; + s | elm_; s | cat_; #endif @@ -156,21 +155,15 @@ void CollectionMessage::setLBLiteInstrument(bool const& val) { } template -balance::ElementIDType CollectionMessage::getElm() const { - return elm_perm_; +balance::ElementIDStruct CollectionMessage::getElm() const { + return elm_; } template void CollectionMessage::setElm( - balance::ElementIDType perm, balance::ElementIDType temp + balance::ElementIDStruct elm ) { - elm_perm_ = perm; - elm_temp_ = temp; -} - -template -balance::ElementIDType CollectionMessage::getElmTemp() const { - return elm_temp_; + elm_ = elm; } template diff --git a/src/vt/vrt/collection/types/migratable.cc b/src/vt/vrt/collection/types/migratable.cc index bad9d47b3b..c340245c11 100644 --- a/src/vt/vrt/collection/types/migratable.cc +++ b/src/vt/vrt/collection/types/migratable.cc @@ -50,8 +50,7 @@ namespace vt { namespace vrt { namespace collection { Migratable::Migratable() - : stats_elm_id_(theNodeStats()->getNextElm()), - temp_elm_id_(theNodeStats()->getNextElm()) + : elm_id_(theNodeStats()->getNextElm()) { } /*virtual*/ void Migratable::destroy() { diff --git a/src/vt/vrt/collection/types/migratable.h b/src/vt/vrt/collection/types/migratable.h index 014e3c9db7..e64bcb900d 100644 --- a/src/vt/vrt/collection/types/migratable.h +++ b/src/vt/vrt/collection/types/migratable.h @@ -99,8 +99,7 @@ struct Migratable : MigrateHookBase { */ virtual void destroy(); - balance::ElementIDType getElmID() const { return stats_elm_id_; } - balance::ElementIDType getTempID() const { return temp_elm_id_; } + balance::ElementIDStruct getElmID() const { return elm_id_; } protected: template @@ -113,8 +112,7 @@ struct Migratable : MigrateHookBase { public: balance::ElementStats& getStats() { return stats_; } protected: - balance::ElementIDType stats_elm_id_ = 0; - balance::ElementIDType temp_elm_id_ = 0; + balance::ElementIDStruct elm_id_ = {0, uninitialized_destination, uninitialized_destination}; }; }}} /* end namespace vt::vrt::collection */ diff --git a/src/vt/vrt/collection/types/migratable.impl.h b/src/vt/vrt/collection/types/migratable.impl.h index a050bfc9fa..810561a65d 100644 --- a/src/vt/vrt/collection/types/migratable.impl.h +++ b/src/vt/vrt/collection/types/migratable.impl.h @@ -54,8 +54,7 @@ template void Migratable::serialize(Serializer& s) { MigrateHookBase::serialize(s); s | stats_; - s | stats_elm_id_; - s | temp_elm_id_; + s | elm_id_; } }}} /* end namespace vt::vrt::collection */