Skip to content

Commit

Permalink
fix priorities (#999)
Browse files (browse the repository at this point in the history)
Fix Cholesky priorities
  • Loading branch information
rasolca authored Oct 2, 2023
1 parent 998c7bd commit fbab08c
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions include/dlaf/factorization/cholesky/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void Cholesky<backend, device, T>::call_L(Matrix<T, device>& mat_a) {
// Cholesky decomposition of the diagonal tile mat_a.readwrite(k,k) (read/write) using potrf (lapack operation)
auto kk = LocalTileIndex{k, k};

potrfDiagTile<backend>(thread_priority::normal, mat_a.readwrite(kk));
potrfDiagTile<backend>(thread_priority::high, mat_a.readwrite(kk));

for (SizeType i = k + 1; i < nrtile; ++i) {
// Update panel mat_a.readwrite(i,k) with trsm (blas operation), using data mat_a.read(k,k)
Expand All @@ -160,7 +160,7 @@ void Cholesky<backend, device, T>::call_L(Matrix<T, device>& mat_a) {
for (SizeType i = j + 1; i < nrtile; ++i) {
// Update remaining trailing matrix mat_a.readwrite(i,j), reading
// mat_a.read(i,k) and mat_a.read(j,k), using gemm (blas operation)
gemmTrailingMatrixTile<backend>(thread_priority::normal, mat_a.read(LocalTileIndex{i, k}),
gemmTrailingMatrixTile<backend>(trailing_matrix_priority, mat_a.read(LocalTileIndex{i, k}),
mat_a.read(LocalTileIndex{j, k}),
mat_a.readwrite(LocalTileIndex{i, j}));
}
Expand Down Expand Up @@ -193,7 +193,7 @@ void Cholesky<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,

// Factorization of diagonal tile and broadcast it along the k-th column
if (kk_rank == this_rank)
potrfDiagTile<backend>(thread_priority::normal, mat_a.readwrite(kk_idx));
potrfDiagTile<backend>(thread_priority::high, mat_a.readwrite(kk_idx));

// If there is no trailing matrix
const SizeType kt = k + 1;
Expand Down Expand Up @@ -260,9 +260,7 @@ void Cholesky<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,
continue;

const auto i = distr.localTileFromGlobalTile<Coord::Row>(i_idx);
// TODO: This was using executor_np. Was that intentional, or should it
// be trailing_matrix_executor/priority?
gemmTrailingMatrixTile<backend>(thread_priority::normal, panel.read({Coord::Row, i}),
gemmTrailingMatrixTile<backend>(trailing_matrix_priority, panel.read({Coord::Row, i}),
panelT.read({Coord::Col, j}),
mat_a.readwrite(LocalTileIndex{i, j}));
}
Expand All @@ -285,7 +283,7 @@ void Cholesky<backend, device, T>::call_U(Matrix<T, device>& mat_a) {
for (SizeType k = 0; k < nrtile; ++k) {
auto kk = LocalTileIndex{k, k};

potrfDiagTile<backend>(thread_priority::normal, mat_a.readwrite(kk));
potrfDiagTile<backend>(thread_priority::high, mat_a.readwrite(kk));

for (SizeType j = k + 1; j < nrtile; ++j) {
trsmPanelTile<backend>(thread_priority::high, mat_a.read(kk),
Expand All @@ -300,7 +298,7 @@ void Cholesky<backend, device, T>::call_U(Matrix<T, device>& mat_a) {
mat_a.readwrite(LocalTileIndex{i, i}));

for (SizeType j = i + 1; j < nrtile; ++j) {
gemmTrailingMatrixTile<backend>(thread_priority::normal, mat_a.read(LocalTileIndex{k, i}),
gemmTrailingMatrixTile<backend>(trailing_matrix_priority, mat_a.read(LocalTileIndex{k, i}),
mat_a.read(LocalTileIndex{k, j}),
mat_a.readwrite(LocalTileIndex{i, j}));
}
Expand Down Expand Up @@ -333,7 +331,7 @@ void Cholesky<backend, device, T>::call_U(comm::CommunicatorGrid grid, Matrix<T,

// Factorization of diagonal tile and broadcast it along the k-th column
if (kk_rank == this_rank) {
potrfDiagTile<backend>(thread_priority::normal, mat_a.readwrite(kk_idx));
potrfDiagTile<backend>(thread_priority::high, mat_a.readwrite(kk_idx));
}

// If there is no trailing matrix
Expand Down Expand Up @@ -401,7 +399,7 @@ void Cholesky<backend, device, T>::call_U(comm::CommunicatorGrid grid, Matrix<T,

const auto j = distr.localTileFromGlobalTile<Coord::Col>(j_idx);

gemmTrailingMatrixTile<backend>(thread_priority::normal, panelT.read({Coord::Row, i}),
gemmTrailingMatrixTile<backend>(trailing_matrix_priority, panelT.read({Coord::Row, i}),
panel.read({Coord::Col, j}),
mat_a.readwrite(LocalTileIndex{i, j}));
}
Expand Down

0 comments on commit fbab08c

Please sign in to comment.