diff --git a/src/vt/topos/index/dense/dense_array.h b/src/vt/topos/index/dense/dense_array.h index 51ce06ff66..a40a05d215 100644 --- a/src/vt/topos/index/dense/dense_array.h +++ b/src/vt/topos/index/dense/dense_array.h @@ -102,7 +102,7 @@ struct DenseIndexArray : BaseIndex, serialization::ByteCopyTrait { DenseIndexArray(std::array in_array); DenseIndexArray(DenseIndexArraySingleInitTag, IndexType const& init_value); - NumDimensionsType ndims() const { return ndim; } + static constexpr NumDimensionsType ndims() { return ndim; } IndexType& operator[](IndexType const& index); IndexType const& operator[](IndexType const& index) const; diff --git a/src/vt/topos/index/index.h b/src/vt/topos/index/index.h index fa064bb2a8..4be63a31a5 100644 --- a/src/vt/topos/index/index.h +++ b/src/vt/topos/index/index.h @@ -60,6 +60,7 @@ template using Index1D = DenseIndexArray; template using Index2D = DenseIndexArray; template using Index3D = DenseIndexArray; template using IdxType = DenseIndexArray; +template using IndexND = DenseIndexArray; static_assert(IndexTraits>::is_index, "Does not conform"); static_assert(IndexTraits>::is_index, "Does not conform"); @@ -76,11 +77,14 @@ using IdxBase = index::IdxBase; using Index1D = index::Index1D; using Index2D = index::Index2D; using Index3D = index::Index3D; +template +using IndexND = index::IndexND; template using IdxType = index::IdxType; template using IdxType1D = index::Index1D; template using IdxType2D = index::Index2D; template using IdxType3D = index::Index3D; +template using IdxTypeND = index::IndexND; } // end namespace vt diff --git a/src/vt/topos/mapping/dense/dense.h b/src/vt/topos/mapping/dense/dense.h index 83598acfab..f78c8dec3a 100644 --- a/src/vt/topos/mapping/dense/dense.h +++ b/src/vt/topos/mapping/dense/dense.h @@ -74,7 +74,7 @@ using IdxPtr = Index*; template using Idx1DPtr = IdxType1D*; template using Idx2DPtr = IdxType2D*; template using Idx3DPtr = IdxType3D*; -template using IdxNDPtr = vt::index::IdxType*; +template using IdxNDPtr = vt::index::IdxType*; template NodeType denseBlockMap(IdxPtr idx, IdxPtr max_idx, NodeType nnodes); @@ -85,8 +85,8 @@ template NodeType defaultDenseIndex2DMap(Idx2DPtr idx, Idx2DPtr max, NodeType n); template NodeType defaultDenseIndex3DMap(Idx3DPtr idx, Idx3DPtr max, NodeType n); -template -NodeType defaultDenseIndex3DMap(IdxNDPtr idx, IdxNDPtr max, NodeType n); +template +NodeType defaultDenseIndexNDMap(IdxNDPtr idx, IdxNDPtr max, NodeType n); template NodeType dense1DRoundRobinMap( Idx1DPtr idx, Idx1DPtr max, NodeType n); @@ -94,6 +94,8 @@ template NodeType dense2DRoundRobinMap( Idx2DPtr idx, Idx2DPtr max, NodeType n); template NodeType dense3DRoundRobinMap( Idx3DPtr idx, Idx3DPtr max, NodeType n); +template +NodeType denseNDRoundRobinMap( IdxNDPtr idx, IdxNDPtr max, NodeType n); template NodeType dense1DBlockMap( Idx1DPtr idx, Idx1DPtr max, NodeType n); @@ -101,10 +103,13 @@ template NodeType dense2DBlockMap( Idx2DPtr idx, Idx2DPtr max, NodeType n); template NodeType dense3DBlockMap( Idx3DPtr idx, Idx3DPtr max, NodeType n); +template +NodeType denseNDBlockMap( IdxNDPtr idx, IdxNDPtr max, NodeType n); template using i1D = IdxType1D; template using i2D = IdxType2D; template using i3D = IdxType3D; +template using iND = IdxTypeND; template using Adapt = MapFunctorAdapt; @@ -115,18 +120,26 @@ template using dense2DMapFn = Adapt>, defaultDenseIndex2DMap, i2D >; template using dense3DMapFn = Adapt>, defaultDenseIndex3DMap, i3D >; +template +using denseNDMapFn = Adapt>, defaultDenseIndexNDMap, iND >; + template using dense1DRRMapFn = Adapt>, dense1DRoundRobinMap, i1D>; template using dense2DRRMapFn = Adapt>, dense2DRoundRobinMap, i2D>; template using dense3DRRMapFn = Adapt>, dense3DRoundRobinMap, i3D>; +template +using denseNDRRMapFn = Adapt>, denseNDRoundRobinMap, iND >; + template using dense1DBlkMapFn = Adapt>, dense1DBlockMap, i1D>; template using dense2DBlkMapFn = Adapt>, dense2DBlockMap, i2D>; template using dense3DBlkMapFn = Adapt>, dense3DBlockMap, i3D>; +template +using denseNDBlkMapFn = Adapt>, denseNDBlockMap, iND >; }} // end namespace vt::mapping diff --git a/src/vt/topos/mapping/dense/dense.impl.h b/src/vt/topos/mapping/dense/dense.impl.h index c608104ec7..2d6d83b8c1 100644 --- a/src/vt/topos/mapping/dense/dense.impl.h +++ b/src/vt/topos/mapping/dense/dense.impl.h @@ -68,6 +68,11 @@ NodeType defaultDenseIndex3DMap(Idx3DPtr idx, Idx3DPtr max, NodeType nx) { return dense3DBlockMap(idx, max, nx); } +template +NodeType defaultDenseIndexNDMap(IdxNDPtr idx, IdxNDPtr max, NodeType nx) { + return denseNDBlockMap(idx, max, nx); +} + // Default round robin mappings template NodeType dense1DRoundRobinMap(Idx1DPtr idx, Idx1DPtr max, NodeType nx) { @@ -104,6 +109,11 @@ NodeType dense3DBlockMap(Idx3DPtr idx, Idx3DPtr max, NodeType nx) { return denseBlockMap, 3>(idx, max, nx); } +template +NodeType denseNDBlockMap(IdxNDPtr idx, IdxNDPtr max, NodeType nx) { + return denseBlockMap, N>(idx, max, nx); +} + template inline NodeType blockMapDenseFlatIndex( IndexElmType* flat_idx_ptr, IndexElmType* num_elems_ptr, diff --git a/src/vt/vrt/collection/defaults/default_map.h b/src/vt/vrt/collection/defaults/default_map.h index 36b2b5979f..af3091f4f2 100644 --- a/src/vt/vrt/collection/defaults/default_map.h +++ b/src/vt/vrt/collection/defaults/default_map.h @@ -61,8 +61,17 @@ struct DefaultMapBase { using MapParamPackType = std::tuple; }; -template -struct DefaultMap; +template +struct DefaultMap : DefaultMapBase { + using BaseType = typename CollectionT::IndexType::DenseIndexType; + using BlockMapType = + ::vt::mapping::denseNDMapFn; + using RRMapType = + ::vt::mapping::denseNDRRMapFn; + using DefaultMapType = + ::vt::mapping::denseNDMapFn; + using MapType = DefaultMapType; +}; /* * Default mappings for Index1D: RR, Block, etc.