Skip to content

Commit

Permalink
Zoltan2: Update Weights getters
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Jul 26, 2023
1 parent 6298185 commit e066b2f
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 37 deletions.
18 changes: 16 additions & 2 deletions packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,29 @@ template <typename User>
\param hostWgts on return a Kokkos view of the weights for this idx
\param idx the weight index, zero or greater
*/
virtual void getWeightsHostView(ConstWeightsHostView1D& hostWgts, int idx = 0) const {
virtual void getWeightsHostView(WeightsHostView1D& hostWgts, int idx = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a Kokkos view (Host side) of the weights.
\param hostWgts on return a Kokkos view of all the weights
*/
virtual void getWeightsHostView(WeightsHostView& hostWgts) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a Kokkos view (Device side) of the weights.
\param deviceWgts on return a Kokkos view of the weights for this idx
\param idx the weight index, zero or greater
*/
virtual void getWeightsDeviceView(ConstWeightsDeviceView1D& deviceWgts, int idx = 0) const {
virtual void getWeightsDeviceView(WeightsDeviceView1D& deviceWgts, int idx = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a Kokkos view (Device side) of the weights.
\param deviceWgts on return a Kokkos view of all the weights
*/
virtual void getWeightsDeviceView(WeightsDeviceView& deviceWgts) const {
Z2_THROW_NOT_IMPLEMENTED
}

Expand Down
58 changes: 52 additions & 6 deletions packages/zoltan2/core/src/input/Zoltan2_GraphAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,22 +226,38 @@ class GraphAdapter : public AdapterWithCoordsWrapper<User, UserCoord> {
\param idx ranges from zero to one less than getNumWeightsPerVertex().
*/
virtual void
getVertexWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
getVertexWeightsDeviceView(typename Base::WeightsDeviceView1D &weights,
int /* idx */ = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a device view of the vertex weights, if any.
\param weights is the view of all the weights for the vertices returned in getVertexIDsView().
*/
virtual void
getVertexWeightsDeviceView(typename Base::WeightsDeviceView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a host view of the vertex weights, if any.
\param weights is the list of weights of the given index for
the vertices returned in getVertexIDsView().
\param idx ranges from zero to one less than getNumWeightsPerVertex().
*/
virtual void
getVertexWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
getVertexWeightsHostView(typename Base::WeightsHostView1D &weights,
int /* idx */ = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a host view of the vertex weights, if any.
\param weights is the list of all the weights for the vertices returned in getVertexIDsView()
*/
virtual void
getVertexWeightsHostView(typename Base::WeightsHostView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Indicate whether vertex weight with index idx should be the
* global degree of the vertex
*/
Expand Down Expand Up @@ -270,22 +286,38 @@ class GraphAdapter : public AdapterWithCoordsWrapper<User, UserCoord> {
\param idx ranges from zero to one less than getNumWeightsPerEdge().
*/
virtual void
getEdgeWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
getEdgeWeightsDeviceView(typename Base::WeightsDeviceView1D &weights,
int /* idx */ = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a device view of the edge weights, if any.
\param weights is the list of weights for the edges returned in getEdgeView().
*/
virtual void
getEdgeWeightsDeviceView(typename Base::WeightsDeviceView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a host view of the edge weights, if any.
\param weights is the list of weights of the given index for
the edges returned in getEdgeView().
\param idx ranges from zero to one less than getNumWeightsPerEdge().
*/
virtual void
getEdgeWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
getEdgeWeightsHostView(typename Base::WeightsHostView1D &weights,
int /* idx */ = 0) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Provide a host view of the edge weights, if any.
\param weights is the list of weights for the edges returned in getEdgeView().
*/
virtual void
getEdgeWeightsHostView(typename Base::WeightsHostView &weights) const {
Z2_THROW_NOT_IMPLEMENTED
}

/*! \brief Allow user to provide additional data that contains coordinate
* info associated with the MatrixAdapter's primaryEntityType_.
* Associated data must have the same parallel distribution and
Expand Down Expand Up @@ -410,22 +442,36 @@ class GraphAdapter : public AdapterWithCoordsWrapper<User, UserCoord> {
getVertexWeightsView(wgt, stride, idx);
}

void getWeightsHostView(typename Base::ConstWeightsHostView1D &hostWgts,
void getWeightsHostView(typename Base::WeightsHostView1D &hostWgts,
int idx = 0) const override {
AssertCondition(getPrimaryEntityType() == GRAPH_VERTEX,
"getWeightsHostView not yet supported for graph edges.");

getVertexWeightsHostView(hostWgts, idx);
}

void getWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &deviceWgts,
void getWeightsHostView(typename Base::WeightsHostView &hostWgts) const override {
AssertCondition(getPrimaryEntityType() == GRAPH_VERTEX,
"getWeightsHostView not yet supported for graph edges.");

getVertexWeightsHostView(hostWgts);
}

void getWeightsDeviceView(typename Base::WeightsDeviceView1D &deviceWgts,
int idx = 0) const override {
AssertCondition(getPrimaryEntityType() == GRAPH_VERTEX,
"getWeightsDeviceView not yet supported for graph edges.");

getVertexWeightsDeviceView(deviceWgts, idx);
}

void getWeightsDeviceView(typename Base::WeightsDeviceView &deviceWgts) const override {
AssertCondition(getPrimaryEntityType() == GRAPH_VERTEX,
"getWeightsDeviceView not yet supported for graph edges.");

getVertexWeightsDeviceView(deviceWgts);
}

bool useDegreeAsWeight(int idx) const {
AssertCondition(this->getPrimaryEntityType() == GRAPH_VERTEX,
"useDegreeAsWeight not yet supported for graph edges.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ TpetraCrsGraphAdapter<User, UserCoord>::TpetraCrsGraphAdapter(

if (this->nWeightsPerVertex_ > 0) {

// should we create underlying Views aswell?
this->vertexWeightsDevice_.resize(this->nWeightsPerVertex_);
this->vertexWeightsDevice_ = typename Base::WeightsDeviceView(
"vertexWeightsDevice_", graph->getLocalNumRows(),
this->nWeightsPerVertex_);

this->vertexDegreeWeightsHost_ = typename Base::VtxDegreeHostView(
"vertexDegreeWeightsHost_", this->nWeightsPerVertex_);
Expand All @@ -152,7 +153,10 @@ TpetraCrsGraphAdapter<User, UserCoord>::TpetraCrsGraphAdapter(
}
}

this->edgeWeightsDevice_.resize(this->nWeightsPerEdge_);
if (this->nWeightsPerEdge_) {
this->edgeWeightsDevice_ = typename Base::WeightsDeviceView(
"nWeightsPerEdge_", graph->getLocalNumRows(), this->nWeightsPerEdge_);
}
}

////////////////////////////////////////////////////////////////////////////
Expand Down
107 changes: 85 additions & 22 deletions packages/zoltan2/core/src/input/Zoltan2_TpetraRowGraphAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#ifndef _ZOLTAN2_TPETRAROWGRAPHADAPTER_HPP_
#define _ZOLTAN2_TPETRAROWGRAPHADAPTER_HPP_

#include "Kokkos_DualView.hpp"
#include "Kokkos_UnorderedMap.hpp"
#include <Tpetra_RowGraph.hpp>
#include <Zoltan2_GraphAdapter.hpp>
Expand Down Expand Up @@ -273,27 +274,37 @@ class TpetraRowGraphAdapter : public GraphAdapter<User, UserCoord> {
void getVertexWeightsView(const scalar_t *&weights, int &stride,
int idx) const override;

void
getVertexWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
int idx = 0) const override;
void getVertexWeightsDeviceView(typename Base::WeightsDeviceView1D &weights,
int idx = 0) const override;

void getVertexWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const override;

void getVertexWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
void getVertexWeightsHostView(typename Base::WeightsHostView1D &weights,
int idx = 0) const override;

void getVertexWeightsHostView(
typename Base::WeightsHostView &weights) const override;

bool useDegreeAsVertexWeight(int idx) const override;

int getNumWeightsPerEdge() const override;

void getEdgeWeightsView(const scalar_t *&weights, int &stride,
int idx) const override;

void
getEdgeWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights,
int idx = 0) const override;
void getEdgeWeightsDeviceView(typename Base::WeightsDeviceView1D &weights,
int idx = 0) const override;

void getEdgeWeightsHostView(typename Base::ConstWeightsHostView1D &weights,
void getEdgeWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const override;

void getEdgeWeightsHostView(typename Base::WeightsHostView1D &weights,
int idx = 0) const override;

void getEdgeWeightsHostView(
typename Base::WeightsHostView &weights) const override;

template <typename Adapter>
void applyPartitioningSolution(
const User &in, User *&out,
Expand Down Expand Up @@ -321,12 +332,12 @@ class TpetraRowGraphAdapter : public GraphAdapter<User, UserCoord> {

int nWeightsPerVertex_;
ArrayRCP<StridedData<lno_t, scalar_t>> vertexWeights_;
std::vector<typename Base::ConstWeightsDeviceView1D> vertexWeightsDevice_;
typename Base::WeightsDeviceView vertexWeightsDevice_;
typename Base::VtxDegreeHostView vertexDegreeWeightsHost_;

int nWeightsPerEdge_;
ArrayRCP<StridedData<lno_t, scalar_t>> edgeWeights_;
std::vector<typename Base::ConstWeightsDeviceView1D> edgeWeightsDevice_;
typename Base::WeightsDeviceView edgeWeightsDevice_;

virtual RCP<User> doMigration(const User &from, size_t numLocalRows,
const gno_t *myNewRows) const;
Expand Down Expand Up @@ -378,7 +389,8 @@ TpetraRowGraphAdapter<User, UserCoord>::TpetraRowGraphAdapter(
vertexWeights_ =
arcp(new strided_t[nWeightsPerVertex_], 0, nWeightsPerVertex_, true);

vertexWeightsDevice_.resize(nWeightsPerVertex_);
vertexWeightsDevice_ = typename Base::WeightsDeviceView(
"vertexWeightsDevice_", nvtx, nWeightsPerVertex_);

vertexDegreeWeightsHost_ = typename Base::VtxDegreeHostView(
"vertexDegreeWeightsHost_", nWeightsPerVertex_);
Expand All @@ -392,7 +404,8 @@ TpetraRowGraphAdapter<User, UserCoord>::TpetraRowGraphAdapter(
edgeWeights_ =
arcp(new strided_t[nWeightsPerEdge_], 0, nWeightsPerEdge_, true);

edgeWeightsDevice_.resize(nWeightsPerEdge_);
edgeWeightsDevice_ = typename Base::WeightsDeviceView(
"nWeightsPerEdge_", graph_->getLocalNumRows(), nWeightsPerEdge_);
}
}

Expand Down Expand Up @@ -446,7 +459,12 @@ void TpetraRowGraphAdapter<User, UserCoord>::setVertexWeightsDevice(
AssertCondition((idx >= 0) and (idx < nWeightsPerVertex_),
"Invalid vertex weight index: " + std::to_string(idx));

vertexWeightsDevice_[idx] = weights;
const auto nVtx = getLocalNumVertices();
auto weightsSub = Kokkos::subview(vertexWeightsDevice_, Kokkos::ALL, idx);
Kokkos::parallel_for(
nVtx, KOKKOS_LAMBDA(const int vertexID) {
weightsSub(vertexID) = weights(vertexID);
});
}

////////////////////////////////////////////////////////////////////////////
Expand All @@ -458,7 +476,14 @@ void TpetraRowGraphAdapter<User, UserCoord>::setVertexWeightsHost(

auto weightsDevice = Kokkos::create_mirror_view_and_copy(
typename Base::device_t(), weightsHost);
vertexWeightsDevice_[idx] = weightsDevice;
// vertexWeightsDevice_[idx] = weightsDevice;

const auto nVtx = getLocalNumVertices();
Kokkos::parallel_for(
Kokkos::RangePolicy<Kokkos::HostSpace::execution_space>(0, nVtx),
KOKKOS_LAMBDA(const int vertexID) {
vertexWeightsDevice_(vertexID, idx) = weightsDevice(vertexID);
});
}

////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -619,23 +644,42 @@ void TpetraRowGraphAdapter<User, UserCoord>::getVertexWeightsView(
////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getVertexWeightsDeviceView(
typename Base::ConstWeightsDeviceView1D &weights, int idx) const {
typename Base::WeightsDeviceView1D &weights, int idx) const {
AssertCondition((idx >= 0) and (idx < nWeightsPerVertex_),
"Invalid vertex weight index.");
weights = vertexWeightsDevice_.at(idx);

weights = Kokkos::subview(vertexWeightsDevice_, Kokkos::ALL, idx);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getVertexWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const {

weights = vertexWeightsDevice_;
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getVertexWeightsHostView(
typename Base::ConstWeightsHostView1D &weights, int idx) const {
typename Base::WeightsHostView1D &weights, int idx) const {
AssertCondition((idx >= 0) and (idx < nWeightsPerVertex_),
"Invalid vertex weight index.");
const auto weightsDevice = vertexWeightsDevice_.at(idx);

auto weightsDevice = Kokkos::subview(vertexWeightsDevice_, Kokkos::ALL, idx);
weights = Kokkos::create_mirror_view(weightsDevice);
Kokkos::deep_copy(weights, weightsDevice);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getVertexWeightsHostView(
typename Base::WeightsHostView &weights) const {

weights = Kokkos::create_mirror_view(vertexWeightsDevice_);
Kokkos::deep_copy(weights, vertexWeightsDevice_);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
bool TpetraRowGraphAdapter<User, UserCoord>::useDegreeAsVertexWeight(
Expand Down Expand Up @@ -663,19 +707,38 @@ void TpetraRowGraphAdapter<User, UserCoord>::getEdgeWeightsView(
////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getEdgeWeightsDeviceView(
typename Base::ConstWeightsDeviceView1D &weights, int idx) const {
weights = edgeWeightsDevice_.at(idx);
typename Base::WeightsDeviceView1D &weights, int idx) const {

weights = Kokkos::subview(edgeWeightsDevice_, Kokkos::ALL, idx);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getEdgeWeightsDeviceView(
typename Base::WeightsDeviceView &weights) const {

weights = edgeWeightsDevice_;
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getEdgeWeightsHostView(
typename Base::ConstWeightsHostView1D &weights, int idx) const {
const auto weightsDevice = edgeWeightsDevice_.at(idx);
typename Base::WeightsHostView1D &weights, int idx) const {

auto weightsDevice = Kokkos::subview(edgeWeightsDevice_, Kokkos::ALL, idx);
weights = Kokkos::create_mirror_view(weightsDevice);
Kokkos::deep_copy(weights, weightsDevice);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
void TpetraRowGraphAdapter<User, UserCoord>::getEdgeWeightsHostView(
typename Base::WeightsHostView &weights) const {

weights = Kokkos::create_mirror_view(edgeWeightsDevice_);
Kokkos::deep_copy(weights, edgeWeightsDevice_);
}

////////////////////////////////////////////////////////////////////////////
template <typename User, typename UserCoord>
template <typename Adapter>
Expand Down
Loading

0 comments on commit e066b2f

Please sign in to comment.