Skip to content

Commit

Permalink
Merge pull request #49 from NexGenAnalytics/26-zoltan2-matrixadapter-…
Browse files Browse the repository at this point in the history
…refactor-api

#26: Zoltan2: MatrixAdapter refactor API
  • Loading branch information
stmcgovern authored Aug 1, 2023
2 parents 35b8f66 + adbbb49 commit a24bce4
Show file tree
Hide file tree
Showing 11 changed files with 1,237 additions and 425 deletions.
4 changes: 2 additions & 2 deletions packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ template <typename User>
using ConstScalarsDeviceView = Kokkos::View<const scalar_t *, device_t>;
using ConstScalarsHostView = typename ConstScalarsDeviceView::HostMirror;

using scalarsDeviceView = Kokkos::View<scalar_t *, device_t>;
using scalarsHostView = typename scalarsDeviceView::HostMirror;
using ScalarsDeviceView = Kokkos::View<scalar_t *, device_t>;
using ScalarsHostView = typename ScalarsDeviceView::HostMirror;

using ConstWeightsDeviceView1D = Kokkos::View<const scalar_t *, device_t>;
using ConstWeightsHostView1D = typename ConstWeightsDeviceView1D::HostMirror;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ template <typename User>
lno_t localNumIDs_;
const gno_t *idList_;
ArrayRCP<StridedData<lno_t, scalar_t> > weights_;
size_t numWeightsPerID_;
size_t numWeightsPerID_ = 0;

Kokkos::View<gno_t *, device_t> idsView_;
Kokkos::View<scalar_t **, device_t> weightsView_;
Expand Down
63 changes: 40 additions & 23 deletions packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ template <typename User, typename UserCoord=User>
using user_t = User;
using userCoord_t = UserCoord;
using base_adapter_t = MatrixAdapter<User,UserCoord>;
using Base = AdapterWithCoordsWrapper<User, UserCoord>;
using device_t = typename node_t::device_type;
#endif

enum BaseAdapterType adapterType() const override {return MatrixAdapterType;}
Expand Down Expand Up @@ -160,12 +162,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& rowIds) const
virtual void getRowIDsHostView(typename Base::ConstIdsHostView& rowIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& rowIds) const
virtual void getRowIDsDeviceView(typename Base::ConstIdsDeviceView& rowIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand All @@ -189,14 +191,14 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSHostView(typename BaseAdapter<User>::ConstOffsetsHostView& offsets,
typename BaseAdapter<User>::ConstIdsHostView& colIds) const
virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets,
typename Base::ConstIdsHostView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSDeviceView(typename BaseAdapter<User>::ConstOffsetsDeviceView& offsets,
typename BaseAdapter<User>::ConstIdsDeviceView& colIds) const
virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets,
typename Base::ConstIdsDeviceView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -227,16 +229,16 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSHostView(typename BaseAdapter<User>::ConstOffsetsHostView& offsets,
typename BaseAdapter<User>::ConstIdsHostView& colIds,
typename BaseAdapter<User>::ConstScalarsHostView& values) const
virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets,
typename Base::ConstIdsHostView& colIds,
typename Base::ConstScalarsHostView& values) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getCRSDeviceView(typename BaseAdapter<User>::ConstOffsetsDeviceView& offsets,
typename BaseAdapter<User>::ConstIdsDeviceView& colIds,
typename BaseAdapter<User>::ConstScalarsDeviceView& values) const
virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets,
typename Base::ConstIdsDeviceView& colIds,
typename Base::ConstScalarsDeviceView& values) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand All @@ -262,16 +264,28 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsHostView(typename BaseAdapter<User>::WeightsHostView& weights) const
virtual void getRowWeightsHostView(typename Base::WeightsHostView1D& weights,
int /* idx */ = 0) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsDeviceView(typename BaseAdapter<User>::WeightsDeviceView& weights) const
virtual void
getRowWeightsHostView(typename Base::WeightsHostView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D& weights,
int /* idx */ = 0) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void
getRowWeightsDeviceView(typename Base::WeightsDeviceView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Indicate whether row weight with index idx should be the
* global number of nonzeros in the row.
*/
Expand All @@ -296,12 +310,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& colIds) const
virtual void getColumnIDsHostView(typename Base::ConstIdsHostView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& colIds) const
virtual void getColumnIDsDeviceView(typename Base::ConstIdsDeviceView& colIds) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -371,12 +385,12 @@ template <typename User, typename UserCoord=User>
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnWeightsHostView(typename BaseAdapter<User>::WeightsHostView& weights) const
virtual void getColumnWeightsHostView(typename Base::WeightsHostView1D& weights) const
{
Z2_THROW_NOT_IMPLEMENTED
}

virtual void getColumnWeightsDeviceView(typename BaseAdapter<User>::WeightsDeviceView& weights) const
virtual void getColumnWeightsDeviceView(typename Base::WeightsDeviceView1D& weights) const
{
Z2_THROW_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -495,7 +509,7 @@ template <typename User, typename UserCoord=User>
}
}

void getIDsHostView(typename BaseAdapter<User>::ConstIdsHostView& ids) const override {
void getIDsHostView(typename Base::ConstIdsHostView& ids) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowIDsHostView(ids);
Expand All @@ -517,7 +531,7 @@ template <typename User, typename UserCoord=User>
}
}

void getIDsDeviceView(typename BaseAdapter<User>::ConstIdsDeviceView& ids) const override {
void getIDsDeviceView(typename Base::ConstIdsDeviceView& ids) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowIDsDeviceView(ids);
Expand Down Expand Up @@ -553,7 +567,8 @@ template <typename User, typename UserCoord=User>
}
}

void getWeightsView(const scalar_t *&wgt, int &stride, int idx = 0) const override
void getWeightsView(const scalar_t *&wgt, int &stride,
int idx = 0) const override
{
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
Expand All @@ -577,7 +592,8 @@ template <typename User, typename UserCoord=User>
}
}

virtual void getWeightsHostView(typename BaseAdapter<User>::WeightsHostView& hostWgts) const {
void getWeightsHostView(typename Base::WeightsHostView1D &hostWgts,
int idx = 0) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowWeightsHostView(hostWgts);
Expand All @@ -599,7 +615,8 @@ template <typename User, typename UserCoord=User>
break;
} }

virtual void getWeightsDeviceView(typename BaseAdapter<User>::WeightsDeviceView& deviceWgts) const {
void getWeightsDeviceView(typename Base::WeightsDeviceView1D& deviceWgts,
int idx = 0) const override {
switch (getPrimaryEntityType()) {
case MATRIX_ROW:
getRowWeightsDeviceView(deviceWgts);
Expand Down
Loading

0 comments on commit a24bce4

Please sign in to comment.