Skip to content

Commit

Permalink
#128: render: finish object qoi selection from user defined
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Dec 9, 2024
1 parent 6bf446e commit eac8a4e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
39 changes: 26 additions & 13 deletions src/vt-tv/api/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,37 +245,38 @@ struct Info {
/**
* \brief Returns a getter to a specified object QOI
*/
std::function<double(ObjectWork)>
getObjectQOIGetter(const std::string& object_qoi) const {
std::function<double(ObjectWork)> qoi_getter;
template <typename T>
std::function<T(ObjectWork)>
getObjectQOIGetter(std::string const& object_qoi) const {
std::function<T(ObjectWork)> qoi_getter;
if (object_qoi == "load") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectLoad(obj));
return convertQOIVariantTypeToT_<T>(getObjectLoad(obj));
};
} else if (object_qoi == "received_volume") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectReceivedVolume(obj));
return convertQOIVariantTypeToT_<T>(getObjectReceivedVolume(obj));
};
} else if (object_qoi == "sent_volume") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectSentVolume(obj));
return convertQOIVariantTypeToT_<T>(getObjectSentVolume(obj));
};
} else if (object_qoi == "max_volume") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectMaxVolume(obj));
return convertQOIVariantTypeToT_<T>(getObjectMaxVolume(obj));
};
} else if (object_qoi == "id") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectID(obj));
return convertQOIVariantTypeToT_<T>(getObjectID(obj));
};
} else if (object_qoi == "rank_id") {
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(getObjectRankID(obj));
return convertQOIVariantTypeToT_<T>(getObjectRankID(obj));
};
} else {
// Look in attributes and user_defined (will throw an error if QOI doesn't exist)
qoi_getter = [&](ObjectWork obj) {
return convertQOIVariantTypeToT_<double>(
return convertQOIVariantTypeToT_<T>(
getObjectAttributeOrUserDefined(obj, object_qoi));
};
}
Expand Down Expand Up @@ -383,11 +384,23 @@ struct Info {
*
* \return the object QOI
*/
double getObjectQOIAtPhase(
ElementIDType obj_id, PhaseType phase, std::string obj_qoi) const {
auto qoi_getter = getObjectQOIGetter(obj_qoi);
template <typename T>
T getObjectQOIAtPhase(
ElementIDType obj_id, PhaseType phase, std::string const& obj_qoi
) const {
auto const& objects = this->getPhaseObjects(phase);
auto const& obj = objects.at(obj_id);
auto const& ud = obj.getUserDefined();

if (auto it = ud.find(obj_qoi); it != ud.end()) {
if (std::holds_alternative<double>(it->second)) {
return static_cast<T>(std::get<double>(it->second));
} else if (std::holds_alternative<int>(it->second)) {
return static_cast<T>(std::get<int>(it->second));
}
}

auto qoi_getter = getObjectQOIGetter<T>(obj_qoi);
return qoi_getter(obj);
}

Expand Down
5 changes: 2 additions & 3 deletions src/vt-tv/render/render.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,13 @@ Render::computeObjectQOIRange_() {
// Initialize object QOI range attributes
double oq_max = -1 * std::numeric_limits<double>::infinity();
double oq_min = std::numeric_limits<double>::infinity();
double oq;
std::set<std::variant<double, int>> oq_all;

// Update the QOI range
auto updateQOIRange = [&](auto const& objects, PhaseType phase) {
for (auto const& [obj_id, obj_work] : objects) {
// Update maximum object qoi
oq = info_.getObjectQOIAtPhase(obj_id, phase, this->object_qoi_);
auto oq = info_.getObjectQOIAtPhase<double>(obj_id, phase, this->object_qoi_);
if (!continuous_object_qoi_) {
// Allow for integer categorical QOI (i.e. rank_id)
if (oq == static_cast<int>(oq)) {
Expand Down Expand Up @@ -538,7 +537,7 @@ vtkNew<vtkPolyData> Render::createObjectMesh_(PhaseType phase) {

// Set object attributes
ElementIDType obj_id = objectWork.getID();
auto oq = this->info_.getObjectQOIAtPhase(obj_id, phase, this->object_qoi_);
auto oq = info_.getObjectQOIAtPhase<double>(obj_id, phase, object_qoi_);
q_arr->SetTuple1(point_index, oq);
b_arr->SetTuple1(point_index, migratable);
if (this->object_qoi_ != "load") {
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/api/test_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TEST_P(InfoTest, test_get_object_qoi_getter) {
"rank_id",
"non-existent"});
for (auto const& qoi : qoi_list) {
auto qoi_getter = info.getObjectQOIGetter(qoi);
auto qoi_getter = info.getObjectQOIGetter<double>(qoi);
}
}

Expand Down Expand Up @@ -409,9 +409,9 @@ TEST_F(InfoTest, test_get_object_qoi) {
"non-existent"});
for (auto const& qoi : qoi_list) {
if (qoi == "non-existent") {
EXPECT_THROW(info.getObjectQOIAtPhase(0, 0, qoi), std::runtime_error);
EXPECT_THROW(info.getObjectQOIAtPhase<double>(0, 0, qoi), std::runtime_error);
} else {
ASSERT_NO_THROW(info.getObjectQOIAtPhase(0, 0, qoi));
ASSERT_NO_THROW(info.getObjectQOIAtPhase<double>(0, 0, qoi));
}
}
}
Expand Down

0 comments on commit eac8a4e

Please sign in to comment.