Skip to content

Commit

Permalink
#26: fix compilation errors after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
cwschilly committed Jul 31, 2023
1 parent 931871e commit 1429716
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 59 deletions.
4 changes: 2 additions & 2 deletions packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ template <typename User>
/*! \brief Provide a Kokkos view (Host side) of the weights.
\param hostWgts on return a Kokkos view of all the weights
*/
virtual void getWeightsHostView(ConstWeightsHostView1D& hostWgts, int idx = 0) const {
virtual void getWeightsHostView(WeightsHostView& hostWgts) const {
Z2_THROW_NOT_IMPLEMENTED
}

Expand All @@ -274,7 +274,7 @@ template <typename User>
/*! \brief Provide a Kokkos view (Device side) of the weights.
\param deviceWgts on return a Kokkos view of all the weights
*/
virtual void getWeightsDeviceView(ConstWeightsDeviceView1D& deviceWgts, int idx = 0) const {
virtual void getWeightsDeviceView(WeightsDeviceView& deviceWgts) const {
Z2_THROW_NOT_IMPLEMENTED
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ template <typename User>
lno_t localNumIDs_;
const gno_t *idList_;
ArrayRCP<StridedData<lno_t, scalar_t> > weights_;
size_t numWeightsPerID_;
size_t numWeightsPerID_ = 0;

Kokkos::View<gno_t *, device_t> idsView_;
Kokkos::View<scalar_t **, device_t> weightsView_;
Expand Down
56 changes: 34 additions & 22 deletions packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ template <typename User, typename UserCoord=User>
using user_t = User;
using userCoord_t = UserCoord;
using base_adapter_t = MatrixAdapter<User,UserCoord>;
using Base = AdapterWithCoordsWrapper<User, UserCoord>;
using device_t = typename node_t::device_type;
#endif

enum BaseAdapterType adapterType() const override {return MatrixAdapterType;}
Expand Down Expand Up @@ -160,12 +162,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& rowIds) const
virtual void getRowIDsHostView(typename Base::ConstIdsHostView& rowIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& rowIds) const
virtual void getRowIDsDeviceView(typename Base::ConstIdsDeviceView& rowIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand All @@ -189,14 +191,14 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSHostView(typename BaseAdapter<User>::ConstOffsetsHostView& offsets,
typename BaseAdapter<User>::ConstIdsHostView& colIds) const
virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets,
typename Base::ConstIdsHostView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSDeviceView(typename BaseAdapter<User>::ConstOffsetsDeviceView& offsets,
typename BaseAdapter<User>::ConstIdsDeviceView& colIds) const
virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets,
typename Base::ConstIdsDeviceView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -227,16 +229,16 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSHostView(typename BaseAdapter<User>::ConstOffsetsHostView& offsets,
typename BaseAdapter<User>::ConstIdsHostView& colIds,
typename BaseAdapter<User>::ConstScalarsHostView& values) const
virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets,
typename Base::ConstIdsHostView& colIds,
typename Base::ConstScalarsHostView& values) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSDeviceView(typename BaseAdapter<User>::ConstOffsetsDeviceView& offsets,
typename BaseAdapter<User>::ConstIdsDeviceView& colIds,
typename BaseAdapter<User>::ConstScalarsDeviceView& values) const
virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets,
typename Base::ConstIdsDeviceView& colIds,
typename Base::ConstScalarsDeviceView& values) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand All @@ -262,18 +264,28 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsHostView(typename BaseAdapter<User>::ConstWeightsHostView1D& weights,
virtual void getRowWeightsHostView(typename Base::WeightsHostView1D& weights,
int /* idx */ = 0) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsDeviceView(typename BaseAdapter<User>::ConstWeightsDeviceView1D& weights,
virtual void
getRowWeightsHostView(typename Base::WeightsHostView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D& weights,
int /* idx */ = 0) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void
getRowWeightsDeviceView(typename Base::WeightsDeviceView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Indicate whether row weight with index idx should be the
* global number of nonzeros in the row.
*/
Expand All @@ -298,12 +310,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& colIds) const
virtual void getColumnIDsHostView(typename Base::ConstIdsHostView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& colIds) const
virtual void getColumnIDsDeviceView(typename Base::ConstIdsDeviceView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -373,12 +385,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnWeightsHostView(typename BaseAdapter<User>::ConstWeightsHostView1D& weights) const
virtual void getColumnWeightsHostView(typename Base::WeightsHostView1D& weights) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnWeightsDeviceView(typename BaseAdapter<User>::ConstWeightsDeviceView1D& weights) const
virtual void getColumnWeightsDeviceView(typename Base::WeightsDeviceView1D& weights) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -497,7 +509,7 @@ template <typename User, typename UserCoord=User>
}
}

void getIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& ids) const override {
void getIDsHostView(typename Base::ConstIdsHostView& ids) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowIDsHostView(ids);
Expand All @@ -519,7 +531,7 @@ template <typename User, typename UserCoord=User>
}
}

void getIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& ids) const override {
void getIDsDeviceView(typename Base::ConstIdsDeviceView& ids) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowIDsDeviceView(ids);
Expand Down Expand Up @@ -580,7 +592,7 @@ template <typename User, typename UserCoord=User>
}
}

virtual void getWeightsHostView(typename BaseAdapter<User>::ConstWeightsHostView1D &hostWgts,
void getWeightsHostView(typename Base::WeightsHostView1D &hostWgts,
int idx = 0) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
Expand All @@ -603,7 +615,7 @@ template <typename User, typename UserCoord=User>
break;
} }

virtual void getWeightsDeviceView(typename BaseAdapter<User>::ConstWeightsDeviceView1D& deviceWgts,
void getWeightsDeviceView(typename Base::WeightsDeviceView1D& deviceWgts,
int idx = 0) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ template <typename User, typename UserCoord = User>

if (this->nWeightsPerRow_ > 0) {

this->rowWeightsDevice_.resize(this->nWeightsPerRow_);
this->rowWeightsDevice_ = typename Base::WeightsDeviceView(
"rowWeightsDevice_", inmatrix->getLocalNumRows(),
this->nWeightsPerRow_);

this->numNzWeight_ = Kokkos::View<bool *, host_t>(
"numNzWeight_", this->nWeightsPerRow_);
Expand Down
67 changes: 52 additions & 15 deletions packages/zoltan2/core/src/input/Zoltan2_TpetraRowMatrixAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,17 @@ class TpetraRowMatrixAdapter : public MatrixAdapter<User, UserCoord> {
void getRowWeightsView(const scalar_t *&weights, int &stride,
int idx = 0) const;

void getRowWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
int idx) const;
void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D &weights,
int idx = 0) const;

void getRowWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
int idx) const;
void getRowWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const override;

void getRowWeightsHostView(typename Base::WeightsHostView1D &weights,
int idx = 0) const;

void getRowWeightsHostView(
typename Base::WeightsHostView &weights) const override;

bool useNumNonzerosAsRowWeight(int idx) const;

Expand Down Expand Up @@ -290,7 +296,7 @@ class TpetraRowMatrixAdapter : public MatrixAdapter<User, UserCoord> {

int nWeightsPerRow_;
ArrayRCP<StridedData<lno_t, scalar_t>> rowWeights_;
std::vector<typename Base::ConstWeightsDeviceView1D> rowWeightsDevice_;
typename Base::WeightsDeviceView rowWeightsDevice_;
Kokkos::View<bool *, host_t> numNzWeight_;

bool mayHaveDiagonalEntries;
Expand Down Expand Up @@ -350,7 +356,8 @@ TpetraRowMatrixAdapter<User, UserCoord>::TpetraRowMatrixAdapter(
rowWeights_ =
arcp(new strided_t[nWeightsPerRow_], 0, nWeightsPerRow_, true);

rowWeightsDevice_.resize(nWeightsPerRow_);
rowWeightsDevice_ = typename Base::WeightsDeviceView(
"rowWeightsDevice_", nrows, nWeightsPerRow_);

numNzWeight_ = Kokkos::View<bool *, host_t>(
"numNzWeight_", nWeightsPerRow_);
Expand Down Expand Up @@ -430,7 +437,12 @@ void TpetraRowMatrixAdapter<User, UserCoord>::setRowWeightsDevice(
AssertCondition((idx >= 0) and (idx < nWeightsPerRow_),
"Invalid row weight index: " + std::to_string(idx));

rowWeightsDevice_[idx] = weights;
const auto nrows = getLocalNumRows();
auto weightsSub = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx);
Kokkos::parallel_for(
nrows, KOKKOS_LAMBDA(const int rowID) {
weightsSub(rowID) = weights(rowID);
});
}

////////////////////////////////////////////////////////////////////////////
Expand All @@ -442,7 +454,13 @@ void TpetraRowMatrixAdapter<User, UserCoord>::setRowWeightsHost(

auto weightsDevice = Kokkos::create_mirror_view_and_copy(
typename Base::device_t(), weightsHost);
rowWeightsDevice_[idx] = weightsDevice;

const auto nrows = getLocalNumRows();
Kokkos::parallel_for(
Kokkos::RangePolicy<Kokkos::HostSpace::execution_space>(0, nrows),
KOKKOS_LAMBDA(const int rowID) {
rowWeightsDevice_(rowID, idx) = weightsDevice(rowID);
});
}

////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -620,24 +638,43 @@ void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsView(const scalar_t *

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
int idx) const {
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsDeviceView(
typename Base::WeightsDeviceView1D &weights, int idx) const {
AssertCondition((idx >= 0) and (idx < nWeightsPerRow_),
"Invalid row weight index.");
weights = rowWeightsDevice_.at(idx);

weights = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const {

weights = rowWeightsDevice_;
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
int idx) const {
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsHostView(
typename Base::WeightsHostView1D &weights, int idx) const {
AssertCondition((idx >= 0) and (idx < nWeightsPerRow_),
"Invalid row index.");
const auto weightsDevice = rowWeightsDevice_.at(idx);
"Invalid row weight index.");

auto weightsDevice = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx);
weights = Kokkos::create_mirror_view(weightsDevice);
Kokkos::deep_copy(weights, weightsDevice);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowMatrixAdapter<User, UserCoord>::getRowWeightsHostView(
typename Base::WeightsHostView &weights) const {

weights = Kokkos::create_mirror_view(rowWeightsDevice_);
Kokkos::deep_copy(weights, rowWeightsDevice_);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
bool TpetraRowMatrixAdapter<User, UserCoord>::useNumNonzerosAsRowWeight(int idx) const { return numNzWeight_[idx]; }
Expand Down
8 changes: 4 additions & 4 deletions packages/zoltan2/test/core/unit/input/MatrixAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ int main(int narg, char *arg[]) {
// TEST of getRowWeightsView, getRowWeightsHost0View and
// getRowWeightsDeviceView

Zoltan2::BaseAdapter<trowMatrix_t>::ConstWeightsHostView1D weightsHost0;
Zoltan2::BaseAdapter<trowMatrix_t>::ConstWeightsDeviceView1D weightsDevice0;
Zoltan2::BaseAdapter<trowMatrix_t>::ConstWeightsHostView1D weightsHost1;
Zoltan2::BaseAdapter<trowMatrix_t>::ConstWeightsDeviceView1D weightsDevice1;
Zoltan2::BaseAdapter<trowMatrix_t>::WeightsHostView1D weightsHost0;
Zoltan2::BaseAdapter<trowMatrix_t>::WeightsDeviceView1D weightsDevice0;
Zoltan2::BaseAdapter<trowMatrix_t>::WeightsHostView1D weightsHost1;
Zoltan2::BaseAdapter<trowMatrix_t>::WeightsDeviceView1D weightsDevice1;

tmi.getRowWeightsHostView(weightsHost0, 0);
tmi.getRowWeightsDeviceView(weightsDevice0, 0);
Expand Down
14 changes: 7 additions & 7 deletions packages/zoltan2/test/core/unit/input/TpetraCrsMatrixInput.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) {
//// setRowWeightsDevice
/////////////////////////////////
Z2_TEST_THROW(ia.setRowWeightsDevice(
typename adapter_t::ConstWeightsDeviceView1D{}, 50),
typename adapter_t::WeightsDeviceView1D{}, 50),
std::runtime_error);

weightsDevice_t wgts0("wgts0", nrows);
Expand All @@ -203,33 +203,33 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) {
//// getRowWeightsDevice
/////////////////////////////////
{
constWeightsDevice_t weightsDevice;
weightsDevice_t weightsDevice;
Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 0));

constWeightsHost_t weightsHost;
weightsHost_t weightsHost;
Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 0));

TestDeviceHostView(weightsDevice, weightsHost);

TestDeviceHostView(wgts0, weightsHost);
}
{
constWeightsDevice_t weightsDevice;
weightsDevice_t weightsDevice;
Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 1));

constWeightsHost_t weightsHost;
weightsHost_t weightsHost;
Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 1));

TestDeviceHostView(weightsDevice, weightsHost);

TestDeviceHostView(wgts1, weightsHost);
}
{
constWeightsDevice_t wgtsDevice;
weightsDevice_t wgtsDevice;
Z2_TEST_THROW(ia.getRowWeightsDeviceView(wgtsDevice, 2),
std::runtime_error);

constWeightsHost_t wgtsHost;
weightsHost_t wgtsHost;
Z2_TEST_THROW(ia.getRowWeightsHostView(wgtsHost, 2), std::runtime_error);
}

Expand Down
Loading

0 comments on commit 1429716

Please sign in to comment.