diff --git a/src/vt-tv/api/info.h b/src/vt-tv/api/info.h index c3231499f..31908c411 100644 --- a/src/vt-tv/api/info.h +++ b/src/vt-tv/api/info.h @@ -245,37 +245,38 @@ struct Info { /** * \brief Returns a getter to a specified object QOI */ - std::function - getObjectQOIGetter(const std::string& object_qoi) const { - std::function qoi_getter; + template + std::function + getObjectQOIGetter(std::string const& object_qoi) const { + std::function qoi_getter; if (object_qoi == "load") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectLoad(obj)); + return convertQOIVariantTypeToT_(getObjectLoad(obj)); }; } else if (object_qoi == "received_volume") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectReceivedVolume(obj)); + return convertQOIVariantTypeToT_(getObjectReceivedVolume(obj)); }; } else if (object_qoi == "sent_volume") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectSentVolume(obj)); + return convertQOIVariantTypeToT_(getObjectSentVolume(obj)); }; } else if (object_qoi == "max_volume") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectMaxVolume(obj)); + return convertQOIVariantTypeToT_(getObjectMaxVolume(obj)); }; } else if (object_qoi == "id") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectID(obj)); + return convertQOIVariantTypeToT_(getObjectID(obj)); }; } else if (object_qoi == "rank_id") { qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_(getObjectRankID(obj)); + return convertQOIVariantTypeToT_(getObjectRankID(obj)); }; } else { // Look in attributes and user_defined (will throw an error if QOI doesn't exist) qoi_getter = [&](ObjectWork obj) { - return convertQOIVariantTypeToT_( + return convertQOIVariantTypeToT_( getObjectAttributeOrUserDefined(obj, object_qoi)); }; } @@ -383,11 +384,23 @@ struct Info { * * \return the object QOI */ - double getObjectQOIAtPhase( - ElementIDType obj_id, PhaseType phase, std::string obj_qoi) const { - auto qoi_getter = getObjectQOIGetter(obj_qoi); + template + T getObjectQOIAtPhase( + ElementIDType obj_id, PhaseType phase, std::string const& obj_qoi + ) const { auto const& objects = this->getPhaseObjects(phase); auto const& obj = objects.at(obj_id); + auto const& ud = obj.getUserDefined(); + + if (auto it = ud.find(obj_qoi); it != ud.end()) { + if (std::holds_alternative(it->second)) { + return static_cast(std::get(it->second)); + } else if (std::holds_alternative(it->second)) { + return static_cast(std::get(it->second)); + } + } + + auto qoi_getter = getObjectQOIGetter(obj_qoi); return qoi_getter(obj); } diff --git a/src/vt-tv/render/render.cc b/src/vt-tv/render/render.cc index 7f0b864ae..199b11e7f 100644 --- a/src/vt-tv/render/render.cc +++ b/src/vt-tv/render/render.cc @@ -190,14 +190,13 @@ Render::computeObjectQOIRange_() { // Initialize object QOI range attributes double oq_max = -1 * std::numeric_limits::infinity(); double oq_min = std::numeric_limits::infinity(); - double oq; std::set> oq_all; // Update the QOI range auto updateQOIRange = [&](auto const& objects, PhaseType phase) { for (auto const& [obj_id, obj_work] : objects) { // Update maximum object qoi - oq = info_.getObjectQOIAtPhase(obj_id, phase, this->object_qoi_); + auto oq = info_.getObjectQOIAtPhase(obj_id, phase, this->object_qoi_); if (!continuous_object_qoi_) { // Allow for integer categorical QOI (i.e. rank_id) if (oq == static_cast(oq)) { @@ -538,7 +537,7 @@ vtkNew Render::createObjectMesh_(PhaseType phase) { // Set object attributes ElementIDType obj_id = objectWork.getID(); - auto oq = this->info_.getObjectQOIAtPhase(obj_id, phase, this->object_qoi_); + auto oq = info_.getObjectQOIAtPhase(obj_id, phase, object_qoi_); q_arr->SetTuple1(point_index, oq); b_arr->SetTuple1(point_index, migratable); if (this->object_qoi_ != "load") { diff --git a/tests/unit/api/test_info.cc b/tests/unit/api/test_info.cc index de5e0bd5b..cdd3d6a4a 100644 --- a/tests/unit/api/test_info.cc +++ b/tests/unit/api/test_info.cc @@ -131,7 +131,7 @@ TEST_P(InfoTest, test_get_object_qoi_getter) { "rank_id", "non-existent"}); for (auto const& qoi : qoi_list) { - auto qoi_getter = info.getObjectQOIGetter(qoi); + auto qoi_getter = info.getObjectQOIGetter(qoi); } } @@ -409,9 +409,9 @@ TEST_F(InfoTest, test_get_object_qoi) { "non-existent"}); for (auto const& qoi : qoi_list) { if (qoi == "non-existent") { - EXPECT_THROW(info.getObjectQOIAtPhase(0, 0, qoi), std::runtime_error); + EXPECT_THROW(info.getObjectQOIAtPhase(0, 0, qoi), std::runtime_error); } else { - ASSERT_NO_THROW(info.getObjectQOIAtPhase(0, 0, qoi)); + ASSERT_NO_THROW(info.getObjectQOIAtPhase(0, 0, qoi)); } } }