Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove additional copies from tridiagonal eigensolver #819

Merged
merged 3 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 36 additions & 38 deletions include/dlaf/eigensolver/tridiag_solver/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,25 +194,21 @@ void TridiagSolver<B, D, T>::call(Matrix<T, Device::CPU>& tridiag, Matrix<T, D>&
const matrix::Distribution& distr = evecs.distribution();
LocalElementSize vec_size(distr.size().rows(), 1);
TileElementSize vec_tile_size(distr.blockSize().rows(), 1);
WorkSpace<T, D> ws{Matrix<T, D>(distr), // mat1
Matrix<T, D>(distr), // mat2
Matrix<T, D>(vec_size, vec_tile_size), // dtmp
Matrix<T, D>(vec_size, vec_tile_size), // z
Matrix<T, D>(vec_size, vec_tile_size), // ztmp
Matrix<SizeType, D>(vec_size, vec_tile_size), // i1
Matrix<SizeType, D>(vec_size, vec_tile_size), // i2
Matrix<SizeType, D>(vec_size, vec_tile_size), // i3
Matrix<ColType, D>(vec_size, vec_tile_size)}; // c
WorkSpace<T, D> ws{Matrix<T, D>(distr), // mat1
Matrix<T, D>(distr), // mat2
Matrix<T, D>(vec_size, vec_tile_size), // z
Matrix<T, D>(vec_size, vec_tile_size), // ztmp
Matrix<SizeType, D>(vec_size, vec_tile_size)}; // i2

// Mirror workspace on host memory for CPU-only kernels
WorkSpaceHostMirror<T, D> ws_h{initMirrorMatrix(evals), initMirrorMatrix(ws.mat1),
initMirrorMatrix(ws.dtmp), initMirrorMatrix(ws.z),
initMirrorMatrix(ws.ztmp), initMirrorMatrix(ws.i2),
initMirrorMatrix(ws.c),
// TODO: Not needed: for local version (appease warning)
initMirrorMatrix(evecs), initMirrorMatrix(ws.mat2)
WorkSpaceHost<T> ws_h{Matrix<T, Device::CPU>(vec_size, vec_tile_size), // dtmp
Matrix<SizeType, Device::CPU>(vec_size, vec_tile_size), // i1
Matrix<SizeType, Device::CPU>(vec_size, vec_tile_size), // i3
Matrix<ColType, Device::CPU>(vec_size, vec_tile_size)}; // c

};
// Mirror workspace on host memory for CPU-only kernels
WorkSpaceHostMirror<T, D> ws_hm{initMirrorMatrix(evals), initMirrorMatrix(ws.mat1),
initMirrorMatrix(ws.z), initMirrorMatrix(ws.ztmp),
initMirrorMatrix(ws.i2)};

// Set `evecs` to `zero` (needed for Given's rotation to make sure no random values are picked up)
matrix::util::set0<B, T, D>(pika::execution::thread_priority::normal, evecs);
Expand All @@ -226,17 +222,19 @@ void TridiagSolver<B, D, T>::call(Matrix<T, Device::CPU>& tridiag, Matrix<T, D>&
solveLeaf(tridiag, evecs);
}
else {
solveLeaf(tridiag, evecs, ws_h.evecs);
solveLeaf(tridiag, evecs, ws_hm.mat1);
}

// Offload the diagonal from `tridiag` to `evals`
offloadDiagonal(tridiag, evals);
offloadDiagonal(tridiag, ws_hm.evals);

// Each triad represents two subproblems to be merged
for (auto [i_begin, i_split, i_end] : generateSubproblemIndices(distr.nrTiles().rows())) {
mergeSubproblems<B>(i_begin, i_split, i_end, offdiag_vals[to_sizet(i_split)], ws, ws_h, evals,
mergeSubproblems<B>(i_begin, i_split, i_end, offdiag_vals[to_sizet(i_split)], ws, ws_h, ws_hm, evals,
evecs);
}

copy({0, 0}, evals.distribution().localNrTiles(), ws_hm.evals, evals);
}

// Overload which provides the eigenvector matrix as complex values where the imaginery part is set to zero.
Expand Down Expand Up @@ -331,22 +329,20 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
const matrix::Distribution& dist_evecs = evecs.distribution();
const matrix::Distribution& dist_evals = evals.distribution();

WorkSpace<T, D> ws{Matrix<T, D>(dist_evecs), // mat1
Matrix<T, D>(dist_evecs), // mat2
Matrix<T, D>(dist_evals), // dtmp
Matrix<T, D>(dist_evals), // z
Matrix<T, D>(dist_evals), // ztmp
Matrix<SizeType, D>(dist_evals), // i1
Matrix<SizeType, D>(dist_evals), // i2
Matrix<SizeType, D>(dist_evals), // i3
Matrix<ColType, D>(dist_evals)}; // c
DistWorkSpace<T, D> ws{Matrix<T, D>(dist_evecs), // mat1
Matrix<T, D>(dist_evecs), // mat2
Matrix<T, D>(dist_evals), // z
Matrix<T, D>(dist_evals)}; // ztmp

// Mirror workspace on host memory for CPU-only kernels
WorkSpaceHostMirror<T, D> ws_h{initMirrorMatrix(evals), initMirrorMatrix(ws.mat1),
initMirrorMatrix(ws.dtmp), initMirrorMatrix(ws.z),
initMirrorMatrix(ws.ztmp), initMirrorMatrix(ws.i2),
initMirrorMatrix(ws.c), initMirrorMatrix(evecs),
initMirrorMatrix(ws.mat2)};
DistWorkSpaceHost<T> ws_h{Matrix<T, Device::CPU>(dist_evals), // dtmp
Matrix<SizeType, Device::CPU>(dist_evals), // i1
Matrix<SizeType, Device::CPU>(dist_evals), // i2
Matrix<SizeType, Device::CPU>(dist_evals), // i3
Matrix<ColType, Device::CPU>(dist_evals)}; // c

DistWorkSpaceHostMirror<T, D> ws_hm{initMirrorMatrix(evals), initMirrorMatrix(evecs),
initMirrorMatrix(ws.mat1), initMirrorMatrix(ws.mat2),
initMirrorMatrix(ws.z), initMirrorMatrix(ws.ztmp)};

// Set `evecs` to `zero` (needed for Given's rotation to make sure no random values are picked up)
matrix::util::set0<B, T, D>(pika::execution::thread_priority::normal, evecs);
Expand All @@ -364,18 +360,20 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
solveDistLeaf(grid, full_task_chain, tridiag, evecs);
}
else {
solveDistLeaf(grid, full_task_chain, tridiag, evecs, ws_h.evecs);
solveDistLeaf(grid, full_task_chain, tridiag, evecs, ws_hm.evecs);
}

// Offload the diagonal from `tridiag` to `evals`
offloadDiagonal(tridiag, evals);
offloadDiagonal(tridiag, ws_hm.evals);

// Each triad represents two subproblems to be merged
SizeType nrtiles = dist_evecs.nrTiles().rows();
for (auto [i_begin, i_split, i_end] : generateSubproblemIndices(nrtiles)) {
mergeDistSubproblems<B>(grid, full_task_chain, row_task_chain, col_task_chain, i_begin, i_split,
i_end, offdiag_vals[to_sizet(i_split)], ws, ws_h, evals, evecs);
i_end, offdiag_vals[to_sizet(i_split)], ws, ws_h, ws_hm, evals, evecs);
}

copy({0, 0}, evals.distribution().localNrTiles(), ws_hm.evals, evals);
}

// \overload TridiagSolver<B, D, T>::call()
Expand Down
Loading