diff --git a/packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp b/packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp index 8cd7453d574d..b4d1aee75a96 100644 --- a/packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp +++ b/packages/zoltan2/core/src/input/Zoltan2_Adapter.hpp @@ -259,7 +259,7 @@ template /*! \brief Provide a Kokkos view (Host side) of the weights. \param hostWgts on return a Kokkos view of all the weights */ - virtual void getWeightsHostView(ConstWeightsHostView1D& hostWgts, int idx = 0) const { + virtual void getWeightsHostView(WeightsHostView& hostWgts) const { Z2_THROW_NOT_IMPLEMENTED } @@ -274,7 +274,7 @@ template /*! \brief Provide a Kokkos view (Device side) of the weights. \param deviceWgts on return a Kokkos view of all the weights */ - virtual void getWeightsDeviceView(ConstWeightsDeviceView1D& deviceWgts, int idx = 0) const { + virtual void getWeightsDeviceView(WeightsDeviceView& deviceWgts) const { Z2_THROW_NOT_IMPLEMENTED } diff --git a/packages/zoltan2/core/src/input/Zoltan2_BasicIdentifierAdapter.hpp b/packages/zoltan2/core/src/input/Zoltan2_BasicIdentifierAdapter.hpp index 21adc1b9f2fd..b0be959a4e91 100644 --- a/packages/zoltan2/core/src/input/Zoltan2_BasicIdentifierAdapter.hpp +++ b/packages/zoltan2/core/src/input/Zoltan2_BasicIdentifierAdapter.hpp @@ -201,7 +201,7 @@ template lno_t localNumIDs_; const gno_t *idList_; ArrayRCP > weights_; - size_t numWeightsPerID_; + size_t numWeightsPerID_ = 0; Kokkos::View idsView_; Kokkos::View weightsView_; diff --git a/packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp b/packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp index 6938afb92628..5c4ba2d1306a 100644 --- a/packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp +++ b/packages/zoltan2/core/src/input/Zoltan2_MatrixAdapter.hpp @@ -121,6 +121,8 @@ template using user_t = User; using userCoord_t = UserCoord; using base_adapter_t = MatrixAdapter; + using Base = AdapterWithCoordsWrapper; + using device_t = typename node_t::device_type; #endif enum BaseAdapterType adapterType() const override {return MatrixAdapterType;} @@ -160,12 +162,12 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getRowIDsHostView(typename BaseAdapter::ConstIdsHostView& rowIds) const + virtual void getRowIDsHostView(typename Base::ConstIdsHostView& rowIds) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getRowIDsDeviceView(typename BaseAdapter::ConstIdsDeviceView& rowIds) const + virtual void getRowIDsDeviceView(typename Base::ConstIdsDeviceView& rowIds) const { Z2_THROW_NOT_IMPLEMENTED } @@ -189,14 +191,14 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getCRSHostView(typename BaseAdapter::ConstOffsetsHostView& offsets, - typename BaseAdapter::ConstIdsHostView& colIds) const + virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets, + typename Base::ConstIdsHostView& colIds) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getCRSDeviceView(typename BaseAdapter::ConstOffsetsDeviceView& offsets, - typename BaseAdapter::ConstIdsDeviceView& colIds) const + virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets, + typename Base::ConstIdsDeviceView& colIds) const { Z2_THROW_NOT_IMPLEMENTED } @@ -227,16 +229,16 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getCRSHostView(typename BaseAdapter::ConstOffsetsHostView& offsets, - typename BaseAdapter::ConstIdsHostView& colIds, - typename BaseAdapter::ConstScalarsHostView& values) const + virtual void getCRSHostView(typename Base::ConstOffsetsHostView& offsets, + typename Base::ConstIdsHostView& colIds, + typename Base::ConstScalarsHostView& values) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getCRSDeviceView(typename BaseAdapter::ConstOffsetsDeviceView& offsets, - typename BaseAdapter::ConstIdsDeviceView& colIds, - typename BaseAdapter::ConstScalarsDeviceView& values) const + virtual void getCRSDeviceView(typename Base::ConstOffsetsDeviceView& offsets, + typename Base::ConstIdsDeviceView& colIds, + typename Base::ConstScalarsDeviceView& values) const { Z2_THROW_NOT_IMPLEMENTED } @@ -262,18 +264,28 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getRowWeightsHostView(typename BaseAdapter::ConstWeightsHostView1D& weights, + virtual void getRowWeightsHostView(typename Base::WeightsHostView1D& weights, int /* idx */ = 0) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getRowWeightsDeviceView(typename BaseAdapter::ConstWeightsDeviceView1D& weights, + virtual void + getRowWeightsHostView(typename Base::WeightsHostView &weights) const { + Z2_THROW_NOT_IMPLEMENTED + } + + virtual void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D& weights, int /* idx */ = 0) const { Z2_THROW_NOT_IMPLEMENTED } + virtual void + getRowWeightsDeviceView(typename Base::WeightsDeviceView &weights) const { + Z2_THROW_NOT_IMPLEMENTED + } + /*! \brief Indicate whether row weight with index idx should be the * global number of nonzeros in the row. */ @@ -298,12 +310,12 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getColumnIDsHostView(typename BaseAdapter::ConstIdsHostView& colIds) const + virtual void getColumnIDsHostView(typename Base::ConstIdsHostView& colIds) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getColumnIDsDeviceView(typename BaseAdapter::ConstIdsDeviceView& colIds) const + virtual void getColumnIDsDeviceView(typename Base::ConstIdsDeviceView& colIds) const { Z2_THROW_NOT_IMPLEMENTED } @@ -373,12 +385,12 @@ template Z2_THROW_NOT_IMPLEMENTED } - virtual void getColumnWeightsHostView(typename BaseAdapter::ConstWeightsHostView1D& weights) const + virtual void getColumnWeightsHostView(typename Base::WeightsHostView1D& weights) const { Z2_THROW_NOT_IMPLEMENTED } - virtual void getColumnWeightsDeviceView(typename BaseAdapter::ConstWeightsDeviceView1D& weights) const + virtual void getColumnWeightsDeviceView(typename Base::WeightsDeviceView1D& weights) const { Z2_THROW_NOT_IMPLEMENTED } @@ -497,7 +509,7 @@ template } } - void getIDsHostView(typename BaseAdapter::ConstIdsHostView& ids) const override { + void getIDsHostView(typename Base::ConstIdsHostView& ids) const override { switch (getPrimaryEntityType()) { case MATRIX_ROW: getRowIDsHostView(ids); @@ -519,7 +531,7 @@ template } } - void getIDsDeviceView(typename BaseAdapter::ConstIdsDeviceView& ids) const override { + void getIDsDeviceView(typename Base::ConstIdsDeviceView& ids) const override { switch (getPrimaryEntityType()) { case MATRIX_ROW: getRowIDsDeviceView(ids); @@ -580,7 +592,7 @@ template } } - virtual void getWeightsHostView(typename BaseAdapter::ConstWeightsHostView1D &hostWgts, + void getWeightsHostView(typename Base::WeightsHostView1D &hostWgts, int idx = 0) const override { switch (getPrimaryEntityType()) { case MATRIX_ROW: @@ -603,7 +615,7 @@ template break; } } - virtual void getWeightsDeviceView(typename BaseAdapter::ConstWeightsDeviceView1D& deviceWgts, + void getWeightsDeviceView(typename Base::WeightsDeviceView1D& deviceWgts, int idx = 0) const override { switch (getPrimaryEntityType()) { case MATRIX_ROW: diff --git a/packages/zoltan2/core/src/input/Zoltan2_TpetraCrsMatrixAdapter.hpp b/packages/zoltan2/core/src/input/Zoltan2_TpetraCrsMatrixAdapter.hpp index eee2bc20855b..ee2092e676a2 100644 --- a/packages/zoltan2/core/src/input/Zoltan2_TpetraCrsMatrixAdapter.hpp +++ b/packages/zoltan2/core/src/input/Zoltan2_TpetraCrsMatrixAdapter.hpp @@ -156,7 +156,9 @@ template if (this->nWeightsPerRow_ > 0) { - this->rowWeightsDevice_.resize(this->nWeightsPerRow_); + this->rowWeightsDevice_ = typename Base::WeightsDeviceView( + "rowWeightsDevice_", inmatrix->getLocalNumRows(), + this->nWeightsPerRow_); this->numNzWeight_ = Kokkos::View( "numNzWeight_", this->nWeightsPerRow_); diff --git a/packages/zoltan2/core/src/input/Zoltan2_TpetraRowMatrixAdapter.hpp b/packages/zoltan2/core/src/input/Zoltan2_TpetraRowMatrixAdapter.hpp index bb2ba55201f2..f6eb73c22d66 100644 --- a/packages/zoltan2/core/src/input/Zoltan2_TpetraRowMatrixAdapter.hpp +++ b/packages/zoltan2/core/src/input/Zoltan2_TpetraRowMatrixAdapter.hpp @@ -250,11 +250,17 @@ class TpetraRowMatrixAdapter : public MatrixAdapter { void getRowWeightsView(const scalar_t *&weights, int &stride, int idx = 0) const; - void getRowWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights, - int idx) const; + void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D &weights, + int idx = 0) const; - void getRowWeightsHostView(typename Base::ConstWeightsHostView1D &weights, - int idx) const; + void getRowWeightsDeviceView( + typename Base::WeightsDeviceView &weights) const override; + + void getRowWeightsHostView(typename Base::WeightsHostView1D &weights, + int idx = 0) const; + + void getRowWeightsHostView( + typename Base::WeightsHostView &weights) const override; bool useNumNonzerosAsRowWeight(int idx) const; @@ -290,7 +296,7 @@ class TpetraRowMatrixAdapter : public MatrixAdapter { int nWeightsPerRow_; ArrayRCP> rowWeights_; - std::vector rowWeightsDevice_; + typename Base::WeightsDeviceView rowWeightsDevice_; Kokkos::View numNzWeight_; bool mayHaveDiagonalEntries; @@ -350,7 +356,8 @@ TpetraRowMatrixAdapter::TpetraRowMatrixAdapter( rowWeights_ = arcp(new strided_t[nWeightsPerRow_], 0, nWeightsPerRow_, true); - rowWeightsDevice_.resize(nWeightsPerRow_); + rowWeightsDevice_ = typename Base::WeightsDeviceView( + "rowWeightsDevice_", nrows, nWeightsPerRow_); numNzWeight_ = Kokkos::View( "numNzWeight_", nWeightsPerRow_); @@ -430,7 +437,12 @@ void TpetraRowMatrixAdapter::setRowWeightsDevice( AssertCondition((idx >= 0) and (idx < nWeightsPerRow_), "Invalid row weight index: " + std::to_string(idx)); - rowWeightsDevice_[idx] = weights; + const auto nrows = getLocalNumRows(); + auto weightsSub = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx); + Kokkos::parallel_for( + nrows, KOKKOS_LAMBDA(const int rowID) { + weightsSub(rowID) = weights(rowID); + }); } //////////////////////////////////////////////////////////////////////////// @@ -442,7 +454,13 @@ void TpetraRowMatrixAdapter::setRowWeightsHost( auto weightsDevice = Kokkos::create_mirror_view_and_copy( typename Base::device_t(), weightsHost); - rowWeightsDevice_[idx] = weightsDevice; + + const auto nrows = getLocalNumRows(); + Kokkos::parallel_for( + Kokkos::RangePolicy(0, nrows), + KOKKOS_LAMBDA(const int rowID) { + rowWeightsDevice_(rowID, idx) = weightsDevice(rowID); + }); } //////////////////////////////////////////////////////////////////////////// @@ -620,24 +638,43 @@ void TpetraRowMatrixAdapter::getRowWeightsView(const scalar_t * //////////////////////////////////////////////////////////////////////////// template -void TpetraRowMatrixAdapter::getRowWeightsDeviceView(typename Base::ConstWeightsDeviceView1D &weights, - int idx) const { +void TpetraRowMatrixAdapter::getRowWeightsDeviceView( + typename Base::WeightsDeviceView1D &weights, int idx) const { AssertCondition((idx >= 0) and (idx < nWeightsPerRow_), "Invalid row weight index."); - weights = rowWeightsDevice_.at(idx); + + weights = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx); +} + +//////////////////////////////////////////////////////////////////////////// +template +void TpetraRowMatrixAdapter::getRowWeightsDeviceView( + typename Base::WeightsDeviceView &weights) const { + + weights = rowWeightsDevice_; } //////////////////////////////////////////////////////////////////////////// template -void TpetraRowMatrixAdapter::getRowWeightsHostView(typename Base::ConstWeightsHostView1D &weights, - int idx) const { +void TpetraRowMatrixAdapter::getRowWeightsHostView( + typename Base::WeightsHostView1D &weights, int idx) const { AssertCondition((idx >= 0) and (idx < nWeightsPerRow_), - "Invalid row index."); - const auto weightsDevice = rowWeightsDevice_.at(idx); + "Invalid row weight index."); + + auto weightsDevice = Kokkos::subview(rowWeightsDevice_, Kokkos::ALL, idx); weights = Kokkos::create_mirror_view(weightsDevice); Kokkos::deep_copy(weights, weightsDevice); } +//////////////////////////////////////////////////////////////////////////// +template +void TpetraRowMatrixAdapter::getRowWeightsHostView( + typename Base::WeightsHostView &weights) const { + + weights = Kokkos::create_mirror_view(rowWeightsDevice_); + Kokkos::deep_copy(weights, rowWeightsDevice_); +} + //////////////////////////////////////////////////////////////////////////// template bool TpetraRowMatrixAdapter::useNumNonzerosAsRowWeight(int idx) const { return numNzWeight_[idx]; } diff --git a/packages/zoltan2/test/core/unit/input/MatrixAdapter.cpp b/packages/zoltan2/test/core/unit/input/MatrixAdapter.cpp index 4acc750b58bb..4f0fe9e600d2 100644 --- a/packages/zoltan2/test/core/unit/input/MatrixAdapter.cpp +++ b/packages/zoltan2/test/core/unit/input/MatrixAdapter.cpp @@ -275,10 +275,10 @@ int main(int narg, char *arg[]) { // TEST of getRowWeightsView, getRowWeightsHost0View and // getRowWeightsDeviceView - Zoltan2::BaseAdapter::ConstWeightsHostView1D weightsHost0; - Zoltan2::BaseAdapter::ConstWeightsDeviceView1D weightsDevice0; - Zoltan2::BaseAdapter::ConstWeightsHostView1D weightsHost1; - Zoltan2::BaseAdapter::ConstWeightsDeviceView1D weightsDevice1; + Zoltan2::BaseAdapter::WeightsHostView1D weightsHost0; + Zoltan2::BaseAdapter::WeightsDeviceView1D weightsDevice0; + Zoltan2::BaseAdapter::WeightsHostView1D weightsHost1; + Zoltan2::BaseAdapter::WeightsDeviceView1D weightsDevice1; tmi.getRowWeightsHostView(weightsHost0, 0); tmi.getRowWeightsDeviceView(weightsDevice0, 0); diff --git a/packages/zoltan2/test/core/unit/input/TpetraCrsMatrixInput.cpp b/packages/zoltan2/test/core/unit/input/TpetraCrsMatrixInput.cpp index 1eb5d7347e8b..5e71024dd462 100644 --- a/packages/zoltan2/test/core/unit/input/TpetraCrsMatrixInput.cpp +++ b/packages/zoltan2/test/core/unit/input/TpetraCrsMatrixInput.cpp @@ -182,7 +182,7 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { //// setRowWeightsDevice ///////////////////////////////// Z2_TEST_THROW(ia.setRowWeightsDevice( - typename adapter_t::ConstWeightsDeviceView1D{}, 50), + typename adapter_t::WeightsDeviceView1D{}, 50), std::runtime_error); weightsDevice_t wgts0("wgts0", nrows); @@ -203,10 +203,10 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { //// getRowWeightsDevice ///////////////////////////////// { - constWeightsDevice_t weightsDevice; + weightsDevice_t weightsDevice; Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 0)); - constWeightsHost_t weightsHost; + weightsHost_t weightsHost; Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 0)); TestDeviceHostView(weightsDevice, weightsHost); @@ -214,10 +214,10 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { TestDeviceHostView(wgts0, weightsHost); } { - constWeightsDevice_t weightsDevice; + weightsDevice_t weightsDevice; Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 1)); - constWeightsHost_t weightsHost; + weightsHost_t weightsHost; Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 1)); TestDeviceHostView(weightsDevice, weightsHost); @@ -225,11 +225,11 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { TestDeviceHostView(wgts1, weightsHost); } { - constWeightsDevice_t wgtsDevice; + weightsDevice_t wgtsDevice; Z2_TEST_THROW(ia.getRowWeightsDeviceView(wgtsDevice, 2), std::runtime_error); - constWeightsHost_t wgtsHost; + weightsHost_t wgtsHost; Z2_TEST_THROW(ia.getRowWeightsHostView(wgtsHost, 2), std::runtime_error); } diff --git a/packages/zoltan2/test/core/unit/input/TpetraRowMatrixInput.cpp b/packages/zoltan2/test/core/unit/input/TpetraRowMatrixInput.cpp index 196cb444dc09..6c9fd64ce9e1 100644 --- a/packages/zoltan2/test/core/unit/input/TpetraRowMatrixInput.cpp +++ b/packages/zoltan2/test/core/unit/input/TpetraRowMatrixInput.cpp @@ -184,7 +184,7 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { //// setRowWeightsDevice ///////////////////////////////// Z2_TEST_THROW(ia.setRowWeightsDevice( - typename adapter_t::ConstWeightsDeviceView1D{}, 50), + typename adapter_t::WeightsDeviceView1D{}, 50), std::runtime_error); weightsDevice_t wgts0("wgts0", nrows); @@ -205,10 +205,10 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { //// getRowWeightsDevice ///////////////////////////////// { - constWeightsDevice_t weightsDevice; + weightsDevice_t weightsDevice; Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 0)); - constWeightsHost_t weightsHost; + weightsHost_t weightsHost; Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 0)); TestDeviceHostView(weightsDevice, weightsHost); @@ -216,10 +216,10 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { TestDeviceHostView(wgts0, weightsHost); } { - constWeightsDevice_t weightsDevice; + weightsDevice_t weightsDevice; Z2_TEST_NOTHROW(ia.getRowWeightsDeviceView(weightsDevice, 1)); - constWeightsHost_t weightsHost; + weightsHost_t weightsHost; Z2_TEST_NOTHROW(ia.getRowWeightsHostView(weightsHost, 1)); TestDeviceHostView(weightsDevice, weightsHost); @@ -227,11 +227,11 @@ void verifyInputAdapter(adapter_t &ia, matrix_t &matrix) { TestDeviceHostView(wgts1, weightsHost); } { - constWeightsDevice_t wgtsDevice; + weightsDevice_t wgtsDevice; Z2_TEST_THROW(ia.getRowWeightsDeviceView(wgtsDevice, 2), std::runtime_error); - constWeightsHost_t wgtsHost; + weightsHost_t wgtsHost; Z2_TEST_THROW(ia.getRowWeightsHostView(wgtsHost, 2), std::runtime_error); }