diff --git a/include/dlaf/eigensolver/band_to_tridiag/mc.h b/include/dlaf/eigensolver/band_to_tridiag/mc.h index 49d5576adf..198b034132 100644 --- a/include/dlaf/eigensolver/band_to_tridiag/mc.h +++ b/include/dlaf/eigensolver/band_to_tridiag/mc.h @@ -427,8 +427,6 @@ template SizeType b, std::shared_ptr> a_block, SizeType j, DepSender&& dep) { using dlaf::comm::internal::transformMPI; - using dlaf::internal::whenAllLift; - namespace ex = pika::execution::experimental; auto send = [dest, tag, b, j](const comm::Communicator& comm, @@ -436,7 +434,8 @@ template DLAF_MPI_CHECK_ERROR(MPI_Isend(a_block->ptr(0, j), to_int(2 * b), dlaf::comm::mpi_datatype::type, dest, tag, comm, req)); }; - return whenAllLift(std::forward(pcomm), std::move(a_block), std::forward(dep)) | + return ex::when_all(std::forward(pcomm), ex::just(std::move(a_block)), + std::forward(dep)) | transformMPI(send); } @@ -445,7 +444,7 @@ template SizeType b, std::shared_ptr> a_block, SizeType j, DepSender&& dep) { using dlaf::comm::internal::transformMPI; - using dlaf::internal::whenAllLift; + namespace ex = pika::execution::experimental; auto recv = [src, tag, b, j](const comm::Communicator& comm, std::shared_ptr>& a_block, MPI_Request* req) { @@ -453,7 +452,8 @@ template src, tag, comm, req)); }; - return whenAllLift(std::forward(pcomm), std::move(a_block), std::forward(dep)) | + return ex::when_all(std::forward(pcomm), ex::just(std::move(a_block)), + std::forward(dep)) | transformMPI(recv); } @@ -622,13 +622,13 @@ template [[nodiscard]] auto schedule_send_worker(CommSender&& pcomm, comm::IndexT_MPI dest, comm::IndexT_MPI tag, PromiseSender&& worker) { using dlaf::comm::internal::transformMPI; - using dlaf::internal::whenAllLift; + namespace ex = pika::execution::experimental; auto send = [dest, tag](const comm::Communicator& comm, const auto& worker, MPI_Request* req) { worker.send(comm, dest, tag, req); }; - return whenAllLift(std::forward(pcomm), std::forward(worker)) | + return ex::when_all(std::forward(pcomm), std::forward(worker)) | transformMPI(send); } @@ -637,13 +637,13 @@ template comm::IndexT_MPI src, comm::IndexT_MPI tag, PromiseSender&& worker) { using dlaf::comm::internal::transformMPI; - using dlaf::internal::whenAllLift; + namespace ex = pika::execution::experimental; auto recv = [sweep, step, src, tag](const comm::Communicator& comm, auto& worker, MPI_Request* req) { worker.recv(sweep, step, comm, src, tag, req); }; - return whenAllLift(std::forward(pcomm), std::forward(worker)) | + return ex::when_all(std::forward(pcomm), std::forward(worker)) | transformMPI(recv); } @@ -852,7 +852,6 @@ struct VAccessHelper { // Rank 0 Rank1 Rank2 Rank3 Rank 0 Rank1 Rank2 Rank3 // A0 A4 A8 A12 B0 B4 B8 B12 // A1 A5 A9 A13 B1 B5 B9 B13 - // -- -- --- --- -- -- --- --- // A2 A6 A10 A14 B2 B6 B10 B14 // A3 A7 A11 A15 B3 B7 B11 B15 // --- --- --- --- --- --- --- --- @@ -871,136 +870,107 @@ struct VAccessHelper { // A4 B3 | C2 D1 | E0 .. // .. .. | .. .. | .. .. // - // the communication of a tile of a panel might be splitted in two parts (e.g. B2 B3 tile). - // If copyIsSplitted() is true the copy/communication has to happen in two parts Top and Bottom, - // otherwise only Top is set. + // the communication of a tile of a panel might be splitted in multiple parts (e.g. B2 B3 tile). - VAccessHelper(const comm::CommunicatorGrid& grid, const SizeType band, const SizeType sweeps, - const SizeType sweep0, const SizeType step0, const matrix::Distribution& dist_band, - const matrix::Distribution& dist_panel, const matrix::Distribution& dist_v) noexcept { - rank_panel_ = rankPanel(band, step0, dist_band); - const auto rank = grid.rankFullCommunicator(grid.rank()); - if (rank == rank_panel_) - index_panel_ = indexPanel(band, step0, dist_band, dist_panel); + VAccessHelper(const comm::CommunicatorGrid& grid, const SizeType sweeps, const SizeType sweep0, + const SizeType step0, const matrix::Distribution& dist_panel, + const matrix::Distribution& dist_v) noexcept + : grid_(grid), dist_v_(dist_v) { + const SizeType b = dist_panel.baseTileSize().cols(); + const SizeType nb = dist_v.baseTileSize().rows(); - const GlobalElementIndex id_v{(sweep0 / band + step0) * band, sweep0}; - index_v_ = dist_v.globalTileIndex(id_v); - index_element_tile_v_ = dist_v.tileElementIndex(id_v); + DLAF_ASSERT_HEAVY(sweep0 % b == 0, sweep0, b); + DLAF_ASSERT_HEAVY(step0 * b % dist_panel.baseTileSize().rows() == 0, step0, b, + dist_panel.baseTileSize().rows()); - rank_v_top_ = grid.rankFullCommunicator(dist_v.rankGlobalTile(index_v_)); + index_panel_ = dist_panel.globalTileIndex(GlobalElementIndex{step0 * b, 0}); + rank_panel_ = dist_panel.rankGlobalTile(index_panel_).row(); + + DLAF_ASSERT_HEAVY(dist_panel.tileSize(index_panel_).rows() > 0, + dist_panel.tileSize(index_panel_).rows()); + + const GlobalElementIndex ij_el_v_top{(sweep0 / b + step0) * b, sweep0}; + ij_v_top_ = dist_v.globalTileIndex(ij_el_v_top); + offset_v_top_ = dist_v.tileElementIndex(ij_el_v_top); const SizeType rows_panel = - std::min(dist_panel.blockSize().rows(), dist_v.size().rows() - 1 - id_v.row()); - const SizeType rows_v_top = dist_v.tileSize(index_v_).rows() - index_element_tile_v_.row(); + std::min(dist_panel.tileSize(index_panel_).rows(), dist_v.size().rows() - 1 - ij_el_v_top.row()); - const SizeType cols = std::min(rows_panel, std::min(band, sweeps - sweep0)); + const SizeType cols = std::min(rows_panel, std::min(b, sweeps - sweep0)); - if (rows_v_top < rows_panel) { - rank_v_bottom_ = grid.rankFullCommunicator(dist_v.rankGlobalTile(indexVBottomInternal(index_v_))); - size_top_ = TileElementSize{rows_v_top, cols}; - size_bottom_ = TileElementSize{rows_panel - rows_v_top, cols}; - } - else { - rank_v_bottom_ = -1; - size_top_ = TileElementSize{rows_panel, cols}; - size_bottom_ = TileElementSize{0, 0}; - } - } + size_top_ = {std::min(nb - offset_v_top_.row(), rows_panel), cols}; + size_middle_ = {nb, cols}; + + nr_tiles_ = 1 + util::ceilDiv(rows_panel - size_top_.rows(), nb); - bool copyIsSplitted() const noexcept { - return size_bottom_.rows() > 0; + size_bottom_ = {rows_panel - size_top_.rows() - (nr_tiles_ - 2) * size_middle_.rows(), cols}; } - LocalTileIndex indexPanel() const noexcept { + GlobalTileIndex index_panel() const noexcept { DLAF_ASSERT_HEAVY(index_panel_.isValid(), index_panel_); return index_panel_; } - comm::IndexT_MPI rankPanel() const noexcept { + comm::IndexT_MPI rank_panel() const noexcept { return rank_panel_; } - matrix::SubTileSpec specPanelTop() const noexcept { - return {{0, 0}, size_top_}; - } - matrix::SubTileSpec specPanelBottom() const noexcept { - DLAF_ASSERT_HEAVY(copyIsSplitted(), size_bottom_); - return {{size_top_.rows(), 0}, size_bottom_}; + SizeType nr_tiles() const noexcept { + return nr_tiles_; } - GlobalTileIndex indexVTop() const noexcept { - return index_v_; - } - matrix::SubTileSpec specVTop() const noexcept { - return {index_element_tile_v_, size_top_}; - } - comm::IndexT_MPI rankVTop() const noexcept { - return rank_v_top_; - } - GlobalTileIndex indexVBottom() const noexcept { - DLAF_ASSERT_HEAVY(copyIsSplitted(), size_bottom_); - return indexVBottomInternal(index_v_); - } - matrix::SubTileSpec specVBottom() const noexcept { - DLAF_ASSERT_HEAVY(copyIsSplitted(), size_bottom_); - return {{0, index_element_tile_v_.col()}, size_bottom_}; - } - comm::IndexT_MPI rankVBottom() const noexcept { - DLAF_ASSERT_HEAVY(copyIsSplitted(), size_bottom_); - return rank_v_bottom_; + TileElementIndex spec_panel_origin(SizeType i) const noexcept { + DLAF_ASSERT_HEAVY(0 <= i && i < nr_tiles_, i, nr_tiles_); + if (i == 0) + return {0, 0}; + + return {size_top_.rows() + (i - 1) * size_middle_.rows(), 0}; } - static comm::IndexT_MPI rankPanel(const SizeType band, const SizeType step, - const matrix::Distribution& dist_band) noexcept { - // Need to use dist_band to identify the rank - // as the panel distribution is local (mismatch between tile size and distribution block-size). - const GlobalElementIndex id{0, step * band}; - const GlobalTileIndex index = dist_band.globalTileIndex(id); + TileElementSize spec_size(SizeType i) const noexcept { + DLAF_ASSERT_HEAVY(0 <= i && i < nr_tiles_, i, nr_tiles_); + if (i == 0) + return size_top_; + else if (i == nr_tiles_ - 1) + return size_bottom_; - return dist_band.rankGlobalTile(index).col(); + return size_middle_; } - static LocalTileIndex indexPanel(const SizeType band, const SizeType step, - const matrix::Distribution& dist_band, - const matrix::Distribution& dist_panel) noexcept { - // Need to use dist_band to compute the local element index - // as dist_panel is local (mismatch between tile size and distribution block-size). - // Then dist_panel is used to compute the local tile index. - const GlobalElementIndex id{0, step * band}; - const GlobalTileIndex index = dist_band.globalTileIndex(id); - - DLAF_ASSERT_HEAVY(dist_band.rankIndex() == dist_band.rankGlobalTile(index), dist_band.rankIndex(), - dist_band.rankGlobalTile(index)); + GlobalTileIndex index_v(SizeType i) const noexcept { + DLAF_ASSERT_HEAVY(0 <= i && i < nr_tiles_, i, nr_tiles_); + return ij_v_top_ + GlobalTileSize{i, 0}; + ; + } - const SizeType local_row_panel_v = - dist_band.localTileIndex(index).col() * dist_band.blockSize().cols() + - dist_band.tileElementIndex(id).col(); + comm::IndexT_MPI rank_v(SizeType i) const noexcept { + return grid_.rankFullCommunicator(dist_v_.rankGlobalTile(index_v(i))); + } - DLAF_ASSERT_HEAVY(dist_panel.tileElementIndex(GlobalElementIndex{local_row_panel_v, 0}) == - TileElementIndex(0, 0), - local_row_panel_v, dist_panel.blockSize().rows()); + TileElementIndex spec_v_origin(SizeType i) const noexcept { + DLAF_ASSERT_HEAVY(0 <= i && i < nr_tiles_, i, nr_tiles_); + if (i == 0) + return offset_v_top_; - return dist_panel.localTileIndex(dist_panel.globalTileIndex(GlobalElementIndex{local_row_panel_v, - 0})); + return {0, offset_v_top_.col()}; } private: - static GlobalTileIndex indexVBottomInternal(const GlobalTileIndex& index_v_top) noexcept { - return index_v_top + GlobalTileSize{1, 0}; - } - - LocalTileIndex index_panel_; + const comm::CommunicatorGrid& grid_; + const matrix::Distribution& dist_v_; + GlobalTileIndex index_panel_; comm::IndexT_MPI rank_panel_; - GlobalTileIndex index_v_; - TileElementIndex index_element_tile_v_; - comm::IndexT_MPI rank_v_top_; - comm::IndexT_MPI rank_v_bottom_; + SizeType nr_tiles_; + GlobalTileIndex ij_v_top_; + TileElementIndex offset_v_top_; TileElementSize size_top_{0, 0}; + TileElementSize size_middle_{0, 0}; TileElementSize size_bottom_{0, 0}; }; template TridiagResult BandToTridiag::call_L( comm::CommunicatorGrid grid, const SizeType b, Matrix& mat_a) noexcept { + // TODO rewrite // Note on the algorithm, data distribution and dependency tracking: // The band matrix is redistribuited in 1D block cyclic. The new block size is a multiple of the // block_size of mat_a. As sweeps are performed the matrix is shifted one column to the left (The @@ -1038,6 +1008,9 @@ TridiagResult BandToTridiag::call_L( using util::ceilDiv; using pika::resource::get_num_threads; + using SemaphorePtr = std::shared_ptr>; + using Tile = matrix::Tile; + using TilePtr = std::shared_ptr; namespace ex = pika::execution::experimental; @@ -1046,7 +1019,7 @@ TridiagResult BandToTridiag::call_L( // note: A is square and has square blocksize SizeType size = mat_a.size().cols(); - SizeType nrtile = mat_a.nrTiles().cols(); + SizeType n = mat_a.nrTiles().cols(); SizeType nb = mat_a.blockSize().cols(); auto& dist_a = mat_a.distribution(); @@ -1068,6 +1041,8 @@ TridiagResult BandToTridiag::call_L( const auto prev_rank = (rank == 0 ? ranks - 1 : rank - 1); const auto next_rank = (rank + 1 == ranks ? 0 : rank + 1); + auto policy_hp = dlaf::internal::Policy(pika::execution::thread_priority::high); + const SizeType nb_band = get1DBlockSize(nb); const SizeType tiles_per_block = nb_band / nb; matrix::Distribution dist({1, size}, {1, nb_band}, {1, ranks}, {0, rank}, {0, 0}); @@ -1087,7 +1062,7 @@ TridiagResult BandToTridiag::call_L( return static_cast(2 * j + (is_offdiag ? 1 : 0)); }; // The offset is set to the first unused tag by compute_copy_tag. - const comm::IndexT_MPI offset_v_tag = compute_copy_tag(nrtile, false); + const comm::IndexT_MPI offset_v_tag = compute_copy_tag(n, false); auto compute_v_tag = [offset_v_tag](SizeType i, bool is_bottom) { // only the row index is needed as dependencies are added to avoid @@ -1096,17 +1071,17 @@ TridiagResult BandToTridiag::call_L( }; // The offset is set to the first unused tag by compute_v_tag. - const comm::IndexT_MPI offset_col_tag = compute_v_tag(nrtile, false); + const comm::IndexT_MPI offset_col_tag = compute_v_tag(n, false); - auto compute_col_tag = [offset_col_tag, ranks](SizeType block_id, bool last_col) { + auto compute_col_tag = [offset_col_tag, ranks](SizeType id_block, bool last_col) { // By construction the communication from block j+1 to block j are dependent, therefore a tag per - // block is enough. Moreover block_id is divided by the number of ranks as only the local index is + // block is enough. Moreover id_block is divided by the number of ranks as only the local index is // needed. // When the last column ((size-1)-th column) is communicated the tag is incremented by 1 as in // some case it can mix with the (size-2)-th columnn. - // Note: Passing the local_block_id is not an option as the sender local index might be different + // Note: Passing the id_block_local is not an option as the sender local index might be different // from the receiver index. - return offset_col_tag + static_cast(block_id / ranks) + (last_col ? 1 : 0); + return offset_col_tag + static_cast(id_block / ranks) + (last_col ? 1 : 0); }; // Same offset if ranks > 2, otherwise add the first unused tag of compute_col_tag. @@ -1115,26 +1090,30 @@ TridiagResult BandToTridiag::call_L( ; auto compute_worker_tag = [offset_worker_tag, workers_per_block, ranks](SizeType sweep, - SizeType block_id) { + SizeType id_block) { // As only workers_per_block are available a dependency is introduced by reusing it, therefore // a different tag for all sweeps is not needed. - // Moreover block_id is divided by the number of ranks as only the local index is needed. - // Note: Passing the local_block_id is not an option as the sender local index might be different + // Moreover id_block is divided by the number of ranks as only the local index is needed. + // Note: Passing the id_block_local is not an option as the sender local index might be different // from the receiver index. return offset_worker_tag + static_cast(sweep % workers_per_block + - block_id / ranks * workers_per_block); + id_block / ranks * workers_per_block); }; // Need shared pointer to keep the allocation until all the tasks are executed. vector>> a_ws; + a_ws.reserve(dist.localNrTiles().cols()); + + vector sems; + sems.reserve(dist.localNrTiles().cols()); + for (SizeType i = 0; i < dist.localNrTiles().cols(); ++i) { a_ws.emplace_back(std::make_shared>(size, b, rank + i * ranks, nb_band)); + sems.emplace_back(std::make_shared>(0)); } - vector>> deps(dist.localNrTiles().cols()); - for (auto& dep : deps) { - dep.reserve(nb_band / nb); - } + vector> tiles_v(dist.localNrTiles().cols()); + vector> deps(dist.localNrTiles().cols()); { constexpr std::size_t n_workspaces = 4; @@ -1152,64 +1131,79 @@ TridiagResult BandToTridiag::call_L( }; // Copy the band matrix - for (SizeType k = 0; k < nrtile; ++k) { - const auto id_block = k / tiles_per_block; - const GlobalTileIndex index_diag{k, k}; - const GlobalTileIndex index_offdiag{k + 1, k}; - const auto rank_block = dist.rankGlobalTile(id_block); - const auto rank_diag = grid.rankFullCommunicator(dist_a.rankGlobalTile(index_diag)); - const auto rank_offdiag = - (k == nrtile - 1 ? -1 : grid.rankFullCommunicator(dist_a.rankGlobalTile(index_offdiag))); - const auto tag_diag = compute_copy_tag(k, false); - const auto tag_offdiag = compute_copy_tag(k, true); - - if (rank == rank_block) { - ex::any_sender<> dep; - const auto id_block_local = dist.localTileFromGlobalTile(id_block); - - if (rank == rank_diag) { - dep = copy_diag(a_ws[id_block_local], k * nb, mat_a.read(index_diag)) | ex::split(); - } - else { - auto& temp = temps.nextResource(); - auto diag_tile = comm::scheduleRecv(ex::make_unique_any_sender(comm), rank_diag, tag_diag, - splitTile(temp.readwrite(LocalTileIndex{0, 0}), - {{0, 0}, dist_a.tileSize(index_diag)})); - dep = copy_diag(a_ws[id_block_local], k * nb, std::move(diag_tile)) | ex::split(); - } - - if (k < nrtile - 1) { - if (rank == rank_offdiag) { - dep = copy_offdiag(a_ws[id_block_local], k * nb, - ex::when_all(std::move(dep), mat_a.read(index_offdiag))) | - ex::split(); + for (SizeType k_block = 0; k_block < dist.nrTiles().cols(); ++k_block) { + const SizeType k_start = k_block * tiles_per_block; + const SizeType k_end = std::min(k_start + tiles_per_block, n); + const auto rank_block = dist.rankGlobalTile(k_block); + + ex::unique_any_sender<> prev_dep = ex::just(); + for (SizeType k = k_start; k < k_end; ++k) { + const GlobalTileIndex index_diag{k, k}; + const GlobalTileIndex index_offdiag{k + 1, k}; + const auto rank_diag = grid.rankFullCommunicator(dist_a.rankGlobalTile(index_diag)); + const auto rank_offdiag = + (k == n - 1 ? -1 : grid.rankFullCommunicator(dist_a.rankGlobalTile(index_offdiag))); + const auto tag_diag = compute_copy_tag(k, false); + const auto tag_offdiag = compute_copy_tag(k, true); + + if (rank == rank_block) { + SizeType nr_release = nb / b; + ex::unique_any_sender<> dep; + const auto k_block_local = dist.localTileFromGlobalTile(k_block); + + if (rank == rank_diag) { + dep = copy_diag(a_ws[k_block_local], k * nb, mat_a.read(index_diag)); } else { auto& temp = temps.nextResource(); - auto offdiag_tile = - comm::scheduleRecv(ex::make_unique_any_sender(comm), rank_offdiag, tag_offdiag, - splitTile(temp.readwrite(LocalTileIndex{0, 0}), - {{0, 0}, dist_a.tileSize(index_offdiag)})); - dep = copy_offdiag(a_ws[id_block_local], k * nb, - ex::when_all(std::move(dep), std::move(offdiag_tile))) | - ex::split(); + auto diag_tile = comm::scheduleRecv(ex::make_unique_any_sender(comm), rank_diag, tag_diag, + splitTile(temp.readwrite(LocalTileIndex{0, 0}), + {{0, 0}, dist_a.tileSize(index_diag)})); + dep = ex::ensure_started(copy_diag(a_ws[k_block_local], k * nb, std::move(diag_tile))); } - } - deps[id_block_local].push_back(std::move(dep)); - } - else { - if (rank == rank_diag) { - ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_block, tag_diag, - mat_a.read(index_diag))); + if (k < n - 1) { + if (rank == rank_offdiag) { + dep = copy_offdiag(a_ws[k_block_local], k * nb, + ex::when_all(std::move(dep), mat_a.read(index_offdiag))); + } + else { + auto& temp = temps.nextResource(); + auto offdiag_tile = + comm::scheduleRecv(ex::make_unique_any_sender(comm), rank_offdiag, tag_offdiag, + splitTile(temp.readwrite(LocalTileIndex{0, 0}), + {{0, 0}, dist_a.tileSize(index_offdiag)})); + dep = ex::ensure_started(copy_offdiag( + a_ws[k_block_local], k * nb, ex::when_all(std::move(dep), std::move(offdiag_tile)))); + } + } + else { + // Add one to make sure to unlock the last step of the first sweep. + nr_release = ceilDiv(size - k * nb, b) + 1; + } + + prev_dep = + ex::when_all(ex::just(nr_release, sems[k_block_local]), std::move(prev_dep), + std::move(dep)) | + dlaf::internal::transform(policy_hp, [](SizeType nr, auto&& sem) { sem->release(nr); }); } - if (k < nrtile - 1) { - if (rank == rank_offdiag) { - ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_block, - tag_offdiag, mat_a.read(index_offdiag))); + else { + if (rank == rank_diag) { + ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_block, tag_diag, + mat_a.read(index_diag))); + } + if (k < n - 1) { + if (rank == rank_offdiag) { + ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_block, + tag_offdiag, mat_a.read(index_offdiag))); + } } } } + if (rank == rank_block) { + const auto k_block_local = dist.localTileFromGlobalTile(k_block); + deps[k_block_local] = ex::ensure_started(std::move(prev_dep)); + } } } @@ -1221,149 +1215,170 @@ TridiagResult BandToTridiag::call_L( } constexpr std::size_t n_workspaces = 2; - // As the panel has tiles of size (nb x b), while it should be distributed with a row block-size - // of nb * tiles_per_block, we use a local distribution and we manage the computation of the - // local panel index with VAccessHelper. - matrix::Distribution dist_panel({dist.localSize().cols(), b}, {nb, b}); + matrix::Distribution dist_panel({size, b}, {nb_band, b}, {ranks, 1}, {rank, 0}, {0, 0}); common::RoundRobin> v_panels(n_workspaces, dist_panel); - auto init_sweep = [](std::shared_ptr> a_block, SizeType sweep, - SweepWorkerDist& worker) { worker.start_sweep(sweep, *a_block); }; - auto cont_sweep = [b](std::shared_ptr> a_block, SizeType nr_steps, - SweepWorkerDist& worker, matrix::Tile&& tile_v, - TileElementIndex index) { - for (SizeType j = 0; j < nr_steps; ++j) { - worker.compact_copy_to_tile(tile_v, index + TileElementSize(j * b, 0)); - worker.do_step(*a_block); + auto run_steps = [b](std::shared_ptr>&& a_bl, SemaphorePtr&& sem, + SemaphorePtr&& sem_next, SizeType nr_steps, bool last_step, + SweepWorkerDist& worker, const TilePtr& tile_v, SizeType j_el_tl) { + for (SizeType step = 0; step < nr_steps; ++step) { + worker.compact_copy_to_tile(*tile_v, TileElementIndex(step * b, j_el_tl)); + sem->acquire(); + worker.do_step(*a_bl); + sem_next->release(1); + } + if (last_step) { + // Make sure to unlock the last step of the next sweep + sem_next->release(1); } }; - auto policy_hp = dlaf::internal::Policy(pika::execution::thread_priority::high); - auto copy_tridiag = [policy_hp, &mat_trid](std::shared_ptr> a_block, SizeType sweep, - auto&& dep) { - auto copy_tridiag_task = [](std::shared_ptr> a_block, SizeType start, - SizeType n_d, SizeType n_e, auto tile_t) { - DLAF_ASSERT_HEAVY(n_e >= 0 && (n_e == n_d || n_e == n_d - 1), n_e, n_d); - DLAF_ASSERT_HEAVY(tile_t.size().cols() == 2, tile_t); - DLAF_ASSERT_HEAVY(tile_t.size().rows() >= n_d, tile_t, n_d); - - auto inc = a_block->ld() + 1; - if (isComplex_v) - // skip imaginary part if Complex. - inc *= 2; - - common::internal::SingleThreadedBlasScope single; - - if (auto n1 = a_block->next_split(start); n1 < n_d) { - blas::copy(n1, (BaseType*) a_block->ptr(0, start), inc, tile_t.ptr({0, 0}), 1); - blas::copy(n_d - n1, (BaseType*) a_block->ptr(0, start + n1), inc, tile_t.ptr({n1, 0}), 1); - blas::copy(n1, (BaseType*) a_block->ptr(1, start), inc, tile_t.ptr({0, 1}), 1); - blas::copy(n_e - n1, (BaseType*) a_block->ptr(1, start + n1), inc, tile_t.ptr({n1, 1}), 1); - } - else { - blas::copy(n_d, (BaseType*) a_block->ptr(0, start), inc, tile_t.ptr({0, 0}), 1); - blas::copy(n_e, (BaseType*) a_block->ptr(1, start), inc, tile_t.ptr({0, 1}), 1); - } - }; + auto copy_tridiag_task = [](std::shared_ptr>&& a_bl, SizeType start, SizeType n_d, + SizeType n_e, const matrix::Tile, Device::CPU>& tile_t) { + DLAF_ASSERT_HEAVY(n_e >= 0 && (n_e == n_d || n_e == n_d - 1), n_e, n_d); + DLAF_ASSERT_HEAVY(tile_t.size().cols() == 2, tile_t); + DLAF_ASSERT_HEAVY(tile_t.size().rows() >= n_d, tile_t, n_d); - const auto size = mat_trid.size().rows(); - const auto nb = mat_trid.blockSize().rows(); - if (sweep % nb == nb - 1 || sweep == size - 1) { - const auto tile_index = sweep / nb; - const auto start = tile_index * nb; - dlaf::internal::whenAllLift(std::move(a_block), start, std::min(nb, size - start), - std::min(nb, size - 1 - start), - mat_trid.readwrite(GlobalTileIndex{tile_index, 0}), - std::forward(dep)) | - dlaf::internal::transformDetach(policy_hp, copy_tridiag_task); + auto inc = a_bl->ld() + 1; + if (isComplex_v) + // skip imaginary part if Complex. + inc *= 2; + + common::internal::SingleThreadedBlasScope single; + + if (auto n1 = a_bl->next_split(start); n1 < n_d) { + blas::copy(n1, (BaseType*) a_bl->ptr(0, start), inc, tile_t.ptr({0, 0}), 1); + blas::copy(n_d - n1, (BaseType*) a_bl->ptr(0, start + n1), inc, tile_t.ptr({n1, 0}), 1); + blas::copy(n1, (BaseType*) a_bl->ptr(1, start), inc, tile_t.ptr({0, 1}), 1); + blas::copy(n_e - n1, (BaseType*) a_bl->ptr(1, start + n1), inc, tile_t.ptr({n1, 1}), 1); } else { - ex::start_detached(std::forward(dep)); + blas::copy(n_d, (BaseType*) a_bl->ptr(0, start), inc, tile_t.ptr({0, 0}), 1); + blas::copy(n_e, (BaseType*) a_bl->ptr(1, start), inc, tile_t.ptr({0, 1}), 1); } }; + auto init_sweep = [](std::shared_ptr>&& a_bl, SemaphorePtr&& sem, SizeType sweep, + SweepWorkerDist& worker) { + sem->acquire(); + worker.start_sweep(sweep, *a_bl); + return std::move(sem); + }; + + auto init_sweep_copy_tridiag = [copy_tridiag_task, + nb](std::shared_ptr>&& a_bl, SemaphorePtr&& sem, + SizeType sweep, SweepWorkerDist& worker, + const matrix::Tile, Device::CPU>& tile_t) { + sem->acquire(); + worker.start_sweep(sweep, *a_bl); + copy_tridiag_task(std::move(a_bl), sweep - (nb - 1), nb, nb, tile_t); + return std::move(sem); + }; + const SizeType steps_per_block = nb_band / b; - const SizeType steps_per_task = nb / b; const SizeType sweeps = nrSweeps(size); for (SizeType sweep = 0; sweep < sweeps; ++sweep) { const SizeType steps = nrStepsForSweep(sweep, size, b); - auto& v_panel = sweep % b == 0 ? v_panels.nextResource() : v_panels.currentResource(); + auto& panel_v = sweep % b == 0 ? v_panels.nextResource() : v_panels.currentResource(); - SizeType last_step = 0; + ex::any_sender<> send_col_dep; for (SizeType init_step = 0; init_step < steps; init_step += steps_per_block) { - const auto block_id = dist.globalTileIndex(GlobalElementIndex{0, init_step * b}); - const auto rank_block = dist.rankGlobalTile(block_id).col(); - const SizeType block_steps = std::min(steps_per_block, steps - init_step); + const auto id_block = dist.globalTileIndex(GlobalElementIndex{0, init_step * b}); + const auto rank_block = dist.rankGlobalTile(id_block).col(); if (prev_rank == rank_block) { const SizeType next_j = sweep + (init_step + steps_per_block) * b; if (next_j < size) { - const auto block_local_id = dist.localTileIndex(block_id + GlobalTileSize{0, 1}).col(); - auto a_block = a_ws[block_local_id]; - auto& deps_block = deps[block_local_id]; + const auto id_block_local = dist.localTileIndex(id_block + GlobalTileSize{0, 1}).col(); + auto& a_block = a_ws[id_block_local]; + auto& sem = sems[id_block_local]; + + send_col_dep = + ex::just(sem) | + dlaf::internal::transform(policy_hp, [](SemaphorePtr&& sem) { sem->acquire(); }) | + ex::split(); ex::start_detached(schedule_send_col(comm, prev_rank, - compute_col_tag(block_id.col(), next_j == size - 1), b, - a_block, next_j, deps_block[0])); + compute_col_tag(id_block.col(), next_j == size - 1), b, + a_block, next_j, send_col_dep)); } } else if (rank == rank_block) { - const auto block_local_id = dist.localTileIndex(block_id).col(); - auto a_block = a_ws[block_local_id]; - auto& w_pipeline = workers[block_local_id][sweep % workers_per_block]; - auto& deps_block = deps[block_local_id]; + const auto id_block_local = dist.localTileIndex(id_block).col(); + auto& a_block = a_ws[id_block_local]; + auto& sem = sems[id_block_local]; + auto& w_pipeline = workers[id_block_local][sweep % workers_per_block]; + auto& dep_block = deps[id_block_local]; + auto& tile_v = tiles_v[id_block_local]; + + if (sweep % b == 0) { + tile_v = panel_v.readwrite(LocalTileIndex{id_block_local, 0}) | + ex::then([](Tile&& tile) { return std::make_shared(std::move(tile)); }) | + ex::split(); + } + + ex::unique_any_sender sem_sender; + + // Index of the first column (currently) after this block (if exists). + const SizeType next_j = sweep + (init_step + steps_per_block) * b; + if (next_j < size) { + // Sweep 0: + // As copy are independent we need to make sure all that the band matrix copy is finished + // before releasing the semaphore. + // Sweeps 1, 2, ...: + // The dependency on the operation of the previous sweep is real as the Worker cannot be sent + // before dep_block gets ready, and the Worker is needed in the next rank to update the + // column before is sent here. + // Therefore sem can be released without other dependencies + ex::start_detached( + ex::when_all(ex::just(sem), + schedule_recv_col(comm, next_rank, + compute_col_tag(id_block.col(), next_j == size - 1), b, + a_block, next_j, std::move(dep_block))) | + ex::then([](SemaphorePtr&& sem) { sem->release(1); })); + } // Sweep initialization if (init_step == 0) { - auto dep = dlaf::internal::whenAllLift(a_block, sweep, w_pipeline(), deps_block[0]) | - dlaf::internal::transform(policy_hp, init_sweep); - - copy_tridiag(a_block, sweep, std::move(dep)); + if ((sweep + 1) % nb != 0) { + sem_sender = + ex::ensure_started(ex::when_all(ex::just(a_block, std::move(sem), sweep), w_pipeline()) | + dlaf::internal::transform(policy_hp, init_sweep)); + } + else { + const auto tile_index = sweep / nb; + sem_sender = + ex::ensure_started(ex::when_all(ex::just(a_block, std::move(sem), sweep), w_pipeline(), + mat_trid.readwrite(GlobalTileIndex{tile_index, 0})) | + dlaf::internal::transform(policy_hp, init_sweep_copy_tridiag)); + } } else { ex::start_detached(schedule_recv_worker(sweep, init_step, comm, prev_rank, - compute_worker_tag(sweep, block_id.col()), + compute_worker_tag(sweep, id_block.col()), w_pipeline())); - } - // Index of the first column (currently) after this block (if exists). - const SizeType next_j = sweep + (init_step + steps_per_block) * b; - if (next_j < size) { - // The dependency on the operation of the previous sweep is real as the Worker cannot be sent - // before deps_block.back() gets ready, and the Worker is needed in the next rank to update the - // column before is sent here. - deps_block.push_back(schedule_recv_col(comm, next_rank, - compute_col_tag(block_id.col(), next_j == size - 1), b, - a_block, next_j, deps_block.back()) | - ex::split()); + // SendCol already acquired once the semaphore, so no need to acquire it here. + sem_sender = ex::when_all(ex::just(std::move(sem)), std::move(send_col_dep)); } - for (SizeType block_step = 0; block_step < block_steps; block_step += steps_per_task) { - // Last task only applies the remaining steps to the block boundary - const SizeType nr_steps = std::min(steps_per_task, block_steps - block_step); - - auto dep_index = - std::min(ceilDiv(block_step + nr_steps, steps_per_task), deps_block.size() - 1); + auto sem_next = std::make_shared>(0); - const auto local_index_tile_panel_v = - VAccessHelper::indexPanel(b, init_step + block_step, dist, dist_panel); + // When doing the last step of the sweep that receive the col size-2 the extra semaphore + // release shouldn't occour, as it will be done by the receive of the last column. + bool last_step = steps - init_step <= steps_per_block && next_j != size - 2; + const SizeType block_steps = std::min(steps_per_block, steps - init_step); - deps_block[ceilDiv(block_step, steps_per_task)] = - dlaf::internal::whenAllLift(a_block, nr_steps, w_pipeline(), - v_panel.readwrite(local_index_tile_panel_v), - TileElementIndex{0, sweep % b}, deps_block[dep_index]) | - dlaf::internal::transform(policy_hp, cont_sweep) | ex::split(); - - last_step = block_step; - } - - // Shrink the dependency vector to only include the senders generated by this block in this sweep. - deps_block.resize(ceilDiv(last_step, steps_per_task) + 1); + dep_block = ex::ensure_started(ex::when_all(ex::just(a_block), std::move(sem_sender), + ex::just(sem_next, block_steps, last_step), + w_pipeline(), tile_v, ex::just(sweep % b)) | + dlaf::internal::transform(policy_hp, run_steps)); + sem = std::move(sem_next); if (init_step + block_steps < steps) { ex::start_detached(schedule_send_worker( - comm, next_rank, compute_worker_tag(sweep, block_id.col() + 1), w_pipeline())); + comm, next_rank, compute_worker_tag(sweep, id_block.col() + 1), w_pipeline())); } } } @@ -1373,66 +1388,65 @@ TridiagResult BandToTridiag::call_L( const SizeType base_sweep_steps = nrStepsForSweep(base_sweep, size, b); for (SizeType init_step = 0; init_step < base_sweep_steps; init_step += steps_per_block) { - const SizeType base_sweep_block_steps = std::min(steps_per_block, base_sweep_steps - init_step); - - for (SizeType block_step = 0; block_step < base_sweep_block_steps; - block_step += steps_per_task) { - VAccessHelper helper(grid, b, sweeps, base_sweep, init_step + block_step, dist, dist_panel, - dist_v); - - if (rank == helper.rankPanel()) { - auto copy_or_send = - [&comm, rank, &v_panel, &mat_v, - &compute_v_tag](const LocalTileIndex index_panel, const matrix::SubTileSpec spec_panel, - const comm::IndexT_MPI rank_v, const GlobalTileIndex index_v, - const matrix::SubTileSpec spec_v, const bool bottom) { - auto tile_v_panel = splitTile(v_panel.read(index_panel), spec_panel); - if (rank == rank_v) { - auto tile_v = splitTile(mat_v.readwrite(index_v), spec_v); - ex::start_detached(ex::when_all(std::move(tile_v_panel), std::move(tile_v)) | - copy(Policy>{})); - } - else { - ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_v, - compute_v_tag(index_v.row(), bottom), - std::move(tile_v_panel))); - } - }; - - copy_or_send(helper.indexPanel(), helper.specPanelTop(), helper.rankVTop(), - helper.indexVTop(), helper.specVTop(), false); - if (helper.copyIsSplitted()) { - copy_or_send(helper.indexPanel(), helper.specPanelBottom(), helper.rankVBottom(), - helper.indexVBottom(), helper.specVBottom(), true); - } + VAccessHelper helper(grid, sweeps, base_sweep, init_step, dist_panel, dist_v); + + const auto id_panel = helper.index_panel(); + const auto rank_panel = helper.rank_panel(); + + if (rank == rank_panel) { + const auto id_panel_local = dist_panel.localTileIndex(id_panel); + tiles_v[id_panel_local.row()] = ex::any_sender{}; + + auto copy_or_send = + [&comm, rank, &panel_v, &mat_v, + &compute_v_tag](const LocalTileIndex index_panel, + const TileElementIndex spec_panel_origin, const TileElementSize spec_size, + const comm::IndexT_MPI rank_v, const GlobalTileIndex index_v, + const TileElementIndex spec_v_origin, const bool bottom) { + auto tile_v_panel = splitTile(panel_v.read(index_panel), {spec_panel_origin, spec_size}); + if (rank == rank_v) { + auto tile_v = splitTile(mat_v.readwrite(index_v), {spec_v_origin, spec_size}); + ex::start_detached(ex::when_all(std::move(tile_v_panel), std::move(tile_v)) | + copy(Policy>{})); + } + else { + ex::start_detached(comm::scheduleSend(ex::make_unique_any_sender(comm), rank_v, + compute_v_tag(index_v.row(), bottom), + std::move(tile_v_panel))); + } + }; + + for (SizeType i = 0; i < helper.nr_tiles(); ++i) { + copy_or_send(id_panel_local, helper.spec_panel_origin(i), helper.spec_size(i), + helper.rank_v(i), helper.index_v(i), helper.spec_v_origin(i), (i != 0)); } - else { - auto recv = [&comm, rank, &dist_v, &mat_v, - &compute_v_tag](const comm::IndexT_MPI rank_panel, - const comm::IndexT_MPI rank_v, const GlobalTileIndex index_v, - const matrix::SubTileSpec spec_v, const bool bottom) { - if (rank == rank_v) { - auto tile_v = splitTile(mat_v.readwrite(index_v), spec_v); - auto local_index_v = dist_v.localTileIndex(index_v); - - ex::any_sender<> dep; - if (local_index_v.col() == 0) - dep = ex::just(); - else - dep = ex::drop_value(mat_v.read(local_index_v - LocalTileSize{0, 1})); - - ex::start_detached(comm::scheduleRecv( - ex::make_unique_any_sender(comm), rank_panel, compute_v_tag(index_v.row(), bottom), - matrix::ReadWriteTileSender(ex::when_all(std::move(tile_v), - std::move(dep))))); - } - }; - - recv(helper.rankPanel(), helper.rankVTop(), helper.indexVTop(), helper.specVTop(), false); - if (helper.copyIsSplitted()) { - recv(helper.rankPanel(), helper.rankVBottom(), helper.indexVBottom(), helper.specVBottom(), - true); + } + else { + auto recv = [&comm, rank, &dist_v, &mat_v, + &compute_v_tag](const comm::IndexT_MPI rank_panel, const comm::IndexT_MPI rank_v, + const GlobalTileIndex index_v, + const TileElementIndex spec_v_origin, + const TileElementSize spec_size, const bool bottom) { + if (rank == rank_v) { + auto tile_v = splitTile(mat_v.readwrite(index_v), {spec_v_origin, spec_size}); + auto local_index_v = dist_v.localTileIndex(index_v); + + ex::any_sender<> dep; + if (local_index_v.col() == 0) + dep = ex::just(); + else + dep = ex::drop_value(mat_v.read(local_index_v - LocalTileSize{0, 1})); + + ex::start_detached(comm::scheduleRecv( + ex::make_unique_any_sender(comm), rank_panel, compute_v_tag(index_v.row(), bottom), + matrix::ReadWriteTileSender(ex::when_all(std::move(tile_v), + std::move(dep))))); } + }; + + for (SizeType i = 0; i < helper.nr_tiles(); ++i) { + recv(rank_panel, helper.rank_v(i), helper.index_v(i), helper.spec_v_origin(i), + helper.spec_size(i), i != 0); } } } @@ -1441,15 +1455,30 @@ TridiagResult BandToTridiag::call_L( // Rank 0 (owner of the first band matrix block) copies the last parts of the tridiag matrix. if (rank == 0) { + auto copy_tridiag = [policy_hp, size, nb, &mat_trid, ©_tridiag_task]( + std::shared_ptr> a_block, SizeType i, auto&& dep) { + const auto tile_index = (i - 1) / nb; + const auto start = tile_index * nb; + ex::when_all(ex::just(std::move(a_block), start, std::min(nb, size - start), + std::min(nb, size - 1 - start)), + mat_trid.readwrite(GlobalTileIndex{tile_index, 0}), + std::forward(dep)) | + dlaf::internal::transformDetach(policy_hp, copy_tridiag_task); + }; + + auto dep = ex::just(std::move(sems[0])) | + dlaf::internal::transform(policy_hp, [](SemaphorePtr&& sem) { sem->acquire(); }) | + ex::split(); + // copy the last elements of the diagonals if constexpr (!isComplex_v) { // only needed for real types as they don't perform sweep size-2 - copy_tridiag(a_ws[0], size - 2, deps[0][0]); + copy_tridiag(a_ws[0], size - 1, dep); } - copy_tridiag(a_ws[0], size - 1, std::move(deps[0][0])); + copy_tridiag(a_ws[0], size, std::move(dep)); } - // only rank0 has mat_trid -> bcast to everyone. + // only Rank 0 has mat_trid -> bcast to everyone. for (const auto& index : iterate_range2d(mat_trid.nrTiles())) { if (rank == 0) ex::start_detached(comm::scheduleSendBcast(comm_bcast(), mat_trid.read(index)));