Skip to content

Commit

Permalink
Merge pull request #248 from DARMA-tasking/247-dynrankview-fails-to-s…
Browse files Browse the repository at this point in the history
…erialize-when-rank-is-0

247 dynrankview fails to serialize when rank is 0
  • Loading branch information
PhilMiller authored Jul 18, 2022
2 parents 594584f + c9d6f9f commit 6c3f519
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
16 changes: 12 additions & 4 deletions src/checkpoint/container/view_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,17 @@ inline void serialize_impl(SerializerT& s, Kokkos::DynRankView<T,Args...>& view)
serializeLayout<SerializerT>(s, 7, layout_cur);
}

// Serialize the total number of elements in the Kokkos::View
size_t num_elms = view.size();
s | num_elms;
bool is_uninitialized = dims == 0 and num_elms == 0;
// Construct a view with the layout and use operator= to propagate out
if (s.isUnpacking()) {
if (dims == 1) {
if (is_uninitialized) {
view = ViewType{};
} else if (dims == 0) {
view = constructRankedView<ViewType,0>(label, std::make_tuple(layout));
} else if (dims == 1) {
view = constructRankedView<ViewType,1>(label, std::make_tuple(layout));
} else if (dims == 2) {
view = constructRankedView<ViewType,2>(label, std::make_tuple(layout));
Expand All @@ -337,9 +345,9 @@ inline void serialize_impl(SerializerT& s, Kokkos::DynRankView<T,Args...>& view)
}
}

// Serialize the total number of elements in the Kokkos::View
size_t num_elms = view.size();
s | num_elms;
if (is_uninitialized) {
return;
}

// Serialize whether the view is contiguous or not. Is this required?
bool is_contig = view.span_is_contiguous();
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/test_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,20 @@ struct TestFactory {
namespace {
template <typename T>
std::unique_ptr<T> serializeAny(
T& view, std::function<void(T const&,T const&)> compare
T& view, std::function<void(T const&,T const&)> compare = nullptr
) {
using namespace checkpoint;

auto ret = serialize<T>(view);
auto out_view = deserialize<T>(ret->getBuffer());
auto const& out_view_ref = *out_view;
#if CHECKPOINT_USE_ND_COMPARE
compareND(view, out_view_ref);
#else
compare(view, out_view_ref);
#endif
if (compare) {
#if CHECKPOINT_USE_ND_COMPARE
compareND(view, out_view_ref);
#else
compare(view, out_view_ref);
#endif
}
return out_view;
}
} //end namespace
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_kokkos_serialize_dynrankview.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,19 @@

#include "test_commons.h"
#include "test_harness.h"
#include "test_kokkos_0d_commons.h"
#include "test_kokkos_1d_commons.h"
#include "test_kokkos_2d_commons.h"
#include "test_kokkos_3d_commons.h"

#include <Kokkos_DynRankView.hpp>

template <typename ParamT>
struct KokkosDynRankViewTestEmpty : KokkosViewTest<ParamT> { };

template <typename ParamT>
struct KokkosDynRankViewTest0D : KokkosViewTest<ParamT> { };

template <typename ParamT>
struct KokkosDynRankViewTest1D : KokkosViewTest<ParamT> { };

Expand All @@ -59,10 +66,41 @@ struct KokkosDynRankViewTest2D : KokkosViewTest<ParamT> { };
template <typename ParamT>
struct KokkosDynRankViewTest3D : KokkosViewTest<ParamT> { };

TYPED_TEST_CASE_P(KokkosDynRankViewTestEmpty);
TYPED_TEST_CASE_P(KokkosDynRankViewTest0D);
TYPED_TEST_CASE_P(KokkosDynRankViewTest1D);
TYPED_TEST_CASE_P(KokkosDynRankViewTest2D);
TYPED_TEST_CASE_P(KokkosDynRankViewTest3D);

TYPED_TEST_P(KokkosDynRankViewTestEmpty, test_empty_any) {
using namespace checkpoint;

using DataType = TypeParam;
using ViewType = Kokkos::DynRankView<DataType>;

ViewType in_view{};
auto out_view = serializeAny<ViewType>(in_view);
EXPECT_EQ(out_view->rank(), unsigned(0));
EXPECT_EQ(out_view->size(), unsigned(0));
}

TYPED_TEST_P(KokkosDynRankViewTest0D, test_0d_any) {
using namespace checkpoint;

using DataType = TypeParam;
using ViewType = Kokkos::DynRankView<DataType>;

static constexpr size_t const N = 1;

ViewType in_view("test");

EXPECT_EQ(in_view.size(), N);

init1d(in_view);
auto out_view = serializeAny<ViewType>(in_view, &compare1d<ViewType>);
EXPECT_EQ(out_view->rank(), unsigned(0));
}

TYPED_TEST_P(KokkosDynRankViewTest1D, test_1d_any) {
using namespace checkpoint;

Expand Down Expand Up @@ -111,12 +149,16 @@ TYPED_TEST_P(KokkosDynRankViewTest3D, test_3d_any) {
EXPECT_EQ(out_view->rank(), unsigned(3));
}

REGISTER_TYPED_TEST_CASE_P(KokkosDynRankViewTestEmpty, test_empty_any);
REGISTER_TYPED_TEST_CASE_P(KokkosDynRankViewTest0D, test_0d_any);
REGISTER_TYPED_TEST_CASE_P(KokkosDynRankViewTest1D, test_1d_any);
REGISTER_TYPED_TEST_CASE_P(KokkosDynRankViewTest2D, test_2d_any);
REGISTER_TYPED_TEST_CASE_P(KokkosDynRankViewTest3D, test_3d_any);

#if DO_UNIT_TESTS_FOR_VIEW

INSTANTIATE_TYPED_TEST_CASE_P(test_dynrank_empty , KokkosDynRankViewTestEmpty, DynRankViewTestTypes, );
INSTANTIATE_TYPED_TEST_CASE_P(test_dynrank_0, KokkosDynRankViewTest0D, DynRankViewTestTypes, );
INSTANTIATE_TYPED_TEST_CASE_P(test_dynrank_1, KokkosDynRankViewTest1D, DynRankViewTestTypes, );
INSTANTIATE_TYPED_TEST_CASE_P(test_dynrank_2, KokkosDynRankViewTest2D, DynRankViewTestTypes, );
INSTANTIATE_TYPED_TEST_CASE_P(test_dynrank_3, KokkosDynRankViewTest3D, DynRankViewTestTypes, );
Expand Down

0 comments on commit 6c3f519

Please sign in to comment.