diff --git a/include/plssvm/csvm.hpp b/include/plssvm/csvm.hpp index 2d25d3dce..a36156e45 100644 --- a/include/plssvm/csvm.hpp +++ b/include/plssvm/csvm.hpp @@ -41,11 +41,13 @@ #include "fmt/color.h" // fmt::fg, fmt::color::orange #include "fmt/format.h" // fmt::format +#include "fmt/ranges.h" // fmt::join #include "igor/igor.hpp" // igor::parser #include // std::max_element, std::all_of #include // std::chrono::{time_point, steady_clock, duration_cast} #include // std::size_t +#include // std::int64_t #include // std::numeric_limits::lowest #include // std::unique_ptr #include // std::optional, std::make_optional, std::nullopt @@ -967,10 +969,22 @@ std::tuple, std::vector, std::vector(assembly_end_time - assembly_start_time); if (used_solver != solver_type::cg_implicit) { - detail::log(verbosity_level::full | verbosity_level::timing, - comm_, - "Assembled the kernel matrix in {}.\n", - assembly_duration); + if (comm_.size() > 1) { + // gather kernel matrix assembly runtimes from each MPI rank + const std::vector durations = comm_.gather(assembly_duration); + + detail::log(verbosity_level::full | verbosity_level::timing, + comm_, + "Assembled the kernel matrix in {} ({}).\n", + *std::max_element(durations.cbegin(), durations.cend()), + fmt::join(durations, "|")); + + } else { + detail::log(verbosity_level::full | verbosity_level::timing, + comm_, + "Assembled the kernel matrix in {}.\n", + assembly_duration); + } } PLSSVM_DETAIL_TRACKING_PERFORMANCE_TRACKER_ADD_TRACKING_ENTRY((detail::tracking::tracking_entry{ "kernel_matrix", "kernel_matrix_assembly", assembly_duration }));