Skip to content

Commit

Permalink
Merge pull request #152 from DARMA-tasking/151-virtual-allocate
Browse files Browse the repository at this point in the history
#151: Split up virtual serialization logic and better support in-place deserialization
  • Loading branch information
PhilMiller authored Nov 9, 2020
2 parents 318d5bc + e4cf1f8 commit f64650d
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 46 deletions.
18 changes: 14 additions & 4 deletions src/checkpoint/container/unique_ptr_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,24 @@ void serialize(Serializer& s, std::unique_ptr<T, Deleter>& ptr) {
s | is_null;
}

if (not is_null) {
T* t = ptr.get();
allocateConstructForPointer(s, t);
if (is_null) {
if (s.isUnpacking()) {
ptr.reset();
}
return;
}

T* t = ptr.get();
reconstructPointedToObjectIfNeeded(s, t);

if (s.isUnpacking()) {
// Support deserialization in place
if (ptr.get() != t) {
ptr = std::unique_ptr<T, Deleter>(t);
}
s | *ptr;
}

s | *ptr;
}

} /* end namespace checkpoint */
Expand Down
19 changes: 16 additions & 3 deletions src/checkpoint/dispatch/dispatch_virtual.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
* ----------------------------- API usage: ----------------------------------
* ---------------------------------------------------------------------------
*
* ----------------------------- Declarations --------------------------------
* === Option 1 ===
*
* - Make your virtual class hierarchy you want to serialize all inherit from
Expand All @@ -71,15 +72,27 @@
* - \c checkpoint_virtual_serialize_root()
* - \c checkpoint_virtual_serialize_derived_from(ParentT)
*
* Invoking the virtual serializer:
* --------------------------- Invocation ------------------------------------
*
* - If you have a \c std::unique_ptr<T>, where T is virtually serializable (by
* using the macros or inheriting from \c SerializableBase<T> and
* \c SerializableDervived<T,BaseT> ), they will automatically be virtually
* serialized.
*
* Example using unique_ptr:
* template <typename T>
* struct MyObjectWithUniquePointer {
* unique_ptr<T> ptr;
*
* template <typename SerializerT>
* void serialize(SerializerT& s) {
* s | ptr;
* }
* };
*
*
* - If you have a raw pointer, \c Teuchos::RCP, or \c std::shared_ptr<T>,
* you must invoke: \c checkpoint::allocateConstructForPointer<SerializerT,T>
* you must invoke \c checkpoint::reconstructPointedToObjectIfNeeded(s, t);
*
* Example with raw pointer:
*
Expand All @@ -94,7 +107,7 @@
* if (!is_null) {
* // During size/pack, save the actual derived type of raw_ptr;
* // During unpack, allocate/construct raw_ptr with correct virtual type
* checkpoint::allocateConstructForPointer(s, raw_ptr);
* checkpoint::reconstructPointedToObjectIfNeeded(s, raw_ptr);
* s | *raw_ptr;
* }
* }
Expand Down
117 changes: 85 additions & 32 deletions src/checkpoint/dispatch/vrt/virtual_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,63 @@ void virtualSerialize(T*& base, SerializerT& s) {
namespace checkpoint {

/**
* \struct SerializeAsVirtualIfNeeded
* \struct SerializeVirtualTypeIfNeeded
*
* \brief Do a static trait test on type to check for virtual
* serializability. If virtually serializable, we need to perform some
* extra work to register the type, and serialize its index.
*/
template <typename T, typename SerializerT, typename _enabled = void>
struct SerializeVirtualTypeIfNeeded;

template <typename T, typename SerializerT>
struct SerializeVirtualTypeIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize
>
>
{
static dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) {
// no type idx needed in this case
return dispatch::vrt::no_type_idx;
}
};

template <typename T, typename SerializerT>
struct SerializeVirtualTypeIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_virtual_serialize
>
>
{
static typename dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) {
dispatch::vrt::TypeIdx entry = dispatch::vrt::no_type_idx;

if (not s.isUnpacking()) {
entry = target->_checkpointDynamicTypeIndex();
}

// entry doesn't count as part of a footprint
if (not s.isFootprinting()) {
s | entry;
}

if (target != nullptr) {
// Support deserialization in place, and make sure it's safe
checkpointAssert(entry == target->_checkpointDynamicTypeIndex(),
"Trying to deserialize in place over a mismatched type");
}

return entry;
}
};

/**
* \struct ReconstructAsVirtualIfNeeded
*
* \brief Do a static trait test on type to check for virtual
* serializability. If virtually serializable, we need to perform some extra
Expand All @@ -80,64 +136,52 @@ namespace checkpoint {
* and serializing what the pointer points to.
*/
template <typename T, typename SerializerT, typename _enabled = void>
struct SerializeAsVirtualIfNeeded;
struct ReconstructAsVirtualIfNeeded;

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize and
not std::is_same<SerializerT, checkpoint::Footprinter>::value
>
> {
static void apply(SerializerT& s, T*& target) {
static T* apply(SerializerT& s, dispatch::vrt::TypeIdx entry) {
// no type idx needed in this case, static construction in default case
auto t = std::allocator<T>{}.allocate(1);
target = dispatch::Reconstructor<T>::construct(t);
return dispatch::Reconstructor<T>::construct(t);
}
};

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize and
std::is_same<SerializerT, checkpoint::Footprinter>::value
>
> {
static void apply(SerializerT& s, T*& target) { }
static T* apply(SerializerT& s, dispatch::vrt::TypeIdx entry) { return nullptr; }
};

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_virtual_serialize
>
> {
static void apply(SerializerT& s, T*& target) {
using dispatch::vrt::TypeIdx;

TypeIdx entry = dispatch::vrt::no_type_idx;

if (not s.isUnpacking()) {
entry = target->_checkpointDynamicTypeIndex();
}

s | entry;

if (s.isUnpacking()) {
using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t<T>;

// use type idx here, registration needed for proper type re-construction
auto t = dispatch::vrt::objregistry::allocateConcreteType<BaseT>(entry);
target = static_cast<T*>(
dispatch::vrt::objregistry::constructConcreteType<BaseT>(entry, t)
);
}
static T* apply(SerializerT& s, dispatch::vrt::TypeIdx entry) {
using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t<T>;

// use type idx here, registration needed for proper type re-construction
auto t = dispatch::vrt::objregistry::allocateConcreteType<BaseT>(entry);
return static_cast<T*>(
dispatch::vrt::objregistry::constructConcreteType<BaseT>(entry, t)
);
}
};

Expand All @@ -157,21 +201,30 @@ struct SerializeAsVirtualIfNeeded<
* template <typename SerializerT>
* void serialize(SerializerT& s) {
* T* raw = elm.get();
* checkpoint::allocateConstructForPointer(s, raw);
* checkpoint::reconstructPointedToObjectIfNeeded(s, raw);
* if (s.isUnpacking()) {
* a = std::shared_ptr<T>(raw);
* }
* s | *a;
* }
* };
*
*
* \param[in] s the serializer
* \param[in] target a reference to a pointer to the target object
*/
template <typename SerializerT, typename T>
void allocateConstructForPointer(SerializerT& s, T*& target) {
SerializeAsVirtualIfNeeded<T, SerializerT>::apply(s, target);
void reconstructPointedToObjectIfNeeded(SerializerT& s, T*& target) {
auto entry = SerializeVirtualTypeIfNeeded<T, SerializerT>::apply(s, target);

if (target != nullptr) {
// Support deserialization in place; assumes matching of virtual
// types was checked in serializeDynamicTypeIndex()
return;
}

if (s.isUnpacking()) {
target = ReconstructAsVirtualIfNeeded<T, SerializerT>::apply(s, entry);
}
}

} /* end namespace checkpoint */
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/test_unique_ptr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ struct UserObject2 {
EXPECT_EQ(vec[0], vec_val);
}

private:
int x = 0, y = 0;
std::vector<int> vec;
};
Expand All @@ -101,19 +100,21 @@ struct UserObject1 {

template <typename Serializer>
void serialize(Serializer& s) {
s | z | obj | obj_null;
s | z | obj | obj_null | obj_reset_null;
}

void check() {
EXPECT_EQ(z, z_val);
EXPECT_NE(obj, nullptr);
obj->check();
EXPECT_EQ(obj_null, nullptr);
EXPECT_EQ(obj_reset_null, nullptr);
}

int z = 0;
std::unique_ptr<UserObject2> obj = nullptr;
std::unique_ptr<UserObject2> obj_null = nullptr;
std::unique_ptr<UserObject2> obj_reset_null = nullptr;
};

TEST_F(TestUniquePtr, test_unique_ptr_1) {
Expand All @@ -122,6 +123,13 @@ TEST_F(TestUniquePtr, test_unique_ptr_1) {
auto ret = checkpoint::serialize(t);
auto out = checkpoint::deserialize<UserObject1>(ret->getBuffer());
out->check();

UserObject1 u{UserObject1::MakeTag{}};
u.obj_reset_null = std::make_unique<UserObject2>();
u.obj->x = 1;
u.obj->vec.clear();
checkpoint::deserializeInPlace(ret->getBuffer(), &u);
u.check();
}

}}} // end namespace checkpoint::tests::unit
46 changes: 41 additions & 5 deletions tests/unit/test_virtual_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ struct TestWrapper {

for (auto&& elm : vec) {
TestBase* base = elm.get();
checkpoint::allocateConstructForPointer(s, base);
checkpoint::reconstructPointedToObjectIfNeeded(s, base);
if (s.isUnpacking()) {
elm = std::shared_ptr<TestBase>(base);
}
Expand All @@ -359,7 +359,7 @@ struct TestWrapper {

for (auto&& elm : vec_derived) {
TestDerived2* derived = elm.get();
checkpoint::allocateConstructForPointer(s, derived);
checkpoint::reconstructPointedToObjectIfNeeded(s, derived);
if (s.isUnpacking()) {
elm = std::shared_ptr<TestDerived2>(derived);
}
Expand Down Expand Up @@ -643,11 +643,11 @@ template <typename ObjT>
struct HolderBasic final : HolderObjBase<ObjT> {
checkpoint_virtual_serialize_derived_from(HolderObjBase<ObjT>)

ObjT* get() override { return obj_; }
ObjT* obj_ = nullptr;
ObjT* get() override { return obj_.get(); }
std::unique_ptr<ObjT> obj_ = nullptr;

template <typename Serializer>
void serialize(Serializer& s) {}
void serialize(Serializer& s) { s | obj_; }
};

TEST_F(TestVirtualSerializeTemplated, test_virtual_serialize_templated) {
Expand All @@ -658,4 +658,40 @@ TEST_F(TestVirtualSerializeTemplated, test_virtual_serialize_templated) {
checkpoint::deserialize<TestType>(std::move(ret));
}

// Test for virtual serialize of a raw pointer

using TestVirtualSerializeRaw = TestHarness;

struct Owner {
Owner(int i) {
HolderBasic<int> *ptr_orig = new HolderBasic<int>;
ptr_orig->obj_ = std::make_unique<int>(i);
ptr_raw_base_ = ptr_orig;
}
Owner() = default;
~Owner() { delete ptr_raw_base_; }

template <typename SerializerT>
void serialize(SerializerT &s) {
// Mimic the snippet given in dispatch_virtual.h
bool is_null = ptr_raw_base_ == nullptr;
s | is_null;
if (!is_null) {
checkpoint::reconstructPointedToObjectIfNeeded(s, ptr_raw_base_);
s | *ptr_raw_base_;
}
}

HolderObjBase<int> *ptr_raw_base_ = nullptr;
};

TEST_F(TestVirtualSerializeRaw, test_virtual_serialize_raw) {
Owner a{10};

auto buf = checkpoint::serialize(a);
auto b = checkpoint::deserialize<Owner>(std::move(buf));

EXPECT_EQ(*(b->ptr_raw_base_->get()), 10);
}

}}} // end namespace checkpoint::tests::unit

0 comments on commit f64650d

Please sign in to comment.