Skip to content

Commit

Permalink
#151: Split up virtual serialization logic and better support in-plac…
Browse files Browse the repository at this point in the history
…e deserialization
  • Loading branch information
PhilMiller committed Oct 30, 2020
1 parent 318d5bc commit 0347b36
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 42 deletions.
19 changes: 15 additions & 4 deletions src/checkpoint/container/unique_ptr_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,25 @@ void serialize(Serializer& s, std::unique_ptr<T, Deleter>& ptr) {
s | is_null;
}

if (not is_null) {
T* t = ptr.get();
allocateConstructForPointer(s, t);
if (is_null) {
if (s.isUnpacking()) {
ptr.reset();
}
return;
}

T* t = ptr.get();
auto entry = serializeDynamicTypeIndex(s, t);
allocateConstructForPointer(s, t, entry);

if (s.isUnpacking()) {
// Support deserialization in place
if (ptr.get() != t) {
ptr = std::unique_ptr<T, Deleter>(t);
}
s | *ptr;
}

s | *ptr;
}

} /* end namespace checkpoint */
Expand Down
30 changes: 27 additions & 3 deletions src/checkpoint/dispatch/dispatch_virtual.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
* ----------------------------- API usage: ----------------------------------
* ---------------------------------------------------------------------------
*
* ----------------------------- Declarations --------------------------------
* === Option 1 ===
*
* - Make your virtual class hierarchy you want to serialize all inherit from
Expand All @@ -71,15 +72,37 @@
* - \c checkpoint_virtual_serialize_root()
* - \c checkpoint_virtual_serialize_derived_from(ParentT)
*
* Invoking the virtual serializer:
* --------------------------- Invocation ------------------------------------
*
* - If you have a \c std::unique_ptr<T>, where T is virtually serializable (by
* using the macros or inheriting from \c SerializableBase<T> and
* \c SerializableDervived<T,BaseT> ), they will automatically be virtually
* serialized.
*
* Example using unique_ptr:
* template <typename T>
* struct MyObjectWithUniquePointer {
* unique_ptr<T> ptr;
*
* template <typename SerializerT>
* void serialize(SerializerT& s) {
* s | ptr;
* }
* };
*
*
* - If you have a raw pointer, \c Teuchos::RCP, or \c std::shared_ptr<T>,
* you must invoke: \c checkpoint::allocateConstructForPointer<SerializerT,T>
* you must manually serialize the dynamic type and call for its
* reconstruction
*
* Serializing the dynamic type's registered index:
*
* T* t;
* auto entry = checkpoint::serializeDynamicTypeIndex(s, t);
*
* Invoking the virtual class reconstructor:
*
* checkpoint::allocateConstructForPointer(s, t, entry);
*
* Example with raw pointer:
*
Expand All @@ -94,7 +117,8 @@
* if (!is_null) {
* // During size/pack, save the actual derived type of raw_ptr;
* // During unpack, allocate/construct raw_ptr with correct virtual type
* checkpoint::allocateConstructForPointer(s, raw_ptr);
* auto entry = checkpoint::serializeDynamicTypeIndex(s, raw_ptr);
* checkpoint::allocateConstructForPointer(s, raw_ptr, entry);
* s | *raw_ptr;
* }
* }
Expand Down
119 changes: 88 additions & 31 deletions src/checkpoint/dispatch/vrt/virtual_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,63 @@ void virtualSerialize(T*& base, SerializerT& s) {
namespace checkpoint {

/**
* \struct SerializeAsVirtualIfNeeded
* \struct SerializeVirtualTypeIfNeeded
*
* \brief Do a static trait test on type to check for virtual
* serializability. If virtually serializable, we need to perform some
* extra work to register the type, and serialize its index.
*/
template <typename T, typename SerializerT, typename _enabled = void>
struct SerializeVirtualTypeIfNeeded;

template <typename T, typename SerializerT>
struct SerializeVirtualTypeIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize
>
>
{
static dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) {
// no type idx needed in this case
return dispatch::vrt::no_type_idx;
}
};

template <typename T, typename SerializerT>
struct SerializeVirtualTypeIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_virtual_serialize
>
>
{
static typename dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) {
dispatch::vrt::TypeIdx entry = dispatch::vrt::no_type_idx;

if (not s.isUnpacking()) {
entry = target->_checkpointDynamicTypeIndex();
} else {
if (target != nullptr) {
// Support deserialization in place, and make sure it's safe
checkpointAssert(entry == target->_checkpointDynamicTypeIndex(),
"Trying to deserialize in place over a mismatched type");
}
}

// entry doesn't count as part of a footprint
if (not s.isFootprinting()) {
s | entry;
}

return entry;
}
};

/**
* \struct ReconstructAsVirtualIfNeeded
*
* \brief Do a static trait test on type to check for virtual
* serializability. If virtually serializable, we need to perform some extra
Expand All @@ -80,64 +136,52 @@ namespace checkpoint {
* and serializing what the pointer points to.
*/
template <typename T, typename SerializerT, typename _enabled = void>
struct SerializeAsVirtualIfNeeded;
struct ReconstructAsVirtualIfNeeded;

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize and
not std::is_same<SerializerT, checkpoint::Footprinter>::value
>
> {
static void apply(SerializerT& s, T*& target) {
static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) {
// no type idx needed in this case, static construction in default case
auto t = std::allocator<T>{}.allocate(1);
target = dispatch::Reconstructor<T>::construct(t);
}
};

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_not_virtual_serialize and
std::is_same<SerializerT, checkpoint::Footprinter>::value
>
> {
static void apply(SerializerT& s, T*& target) { }
static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) { }
};

template <typename T, typename SerializerT>
struct SerializeAsVirtualIfNeeded<
struct ReconstructAsVirtualIfNeeded<
T,
SerializerT,
typename std::enable_if_t<
dispatch::vrt::VirtualSerializeTraits<T>::has_virtual_serialize
>
> {
static void apply(SerializerT& s, T*& target) {
using dispatch::vrt::TypeIdx;

TypeIdx entry = dispatch::vrt::no_type_idx;

if (not s.isUnpacking()) {
entry = target->_checkpointDynamicTypeIndex();
}

s | entry;

if (s.isUnpacking()) {
using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t<T>;

// use type idx here, registration needed for proper type re-construction
auto t = dispatch::vrt::objregistry::allocateConcreteType<BaseT>(entry);
target = static_cast<T*>(
dispatch::vrt::objregistry::constructConcreteType<BaseT>(entry, t)
);
}
static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) {
using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t<T>;

// use type idx here, registration needed for proper type re-construction
auto t = dispatch::vrt::objregistry::allocateConcreteType<BaseT>(entry);
target = static_cast<T*>(
dispatch::vrt::objregistry::constructConcreteType<BaseT>(entry, t)
);
}
};

Expand All @@ -157,21 +201,34 @@ struct SerializeAsVirtualIfNeeded<
* template <typename SerializerT>
* void serialize(SerializerT& s) {
* T* raw = elm.get();
* checkpoint::allocateConstructForPointer(s, raw);
* auto entry = serializeDynamicTypeIndex(s, raw);
* checkpoint::allocateConstructForPointer(s, raw, entry);
* if (s.isUnpacking()) {
* a = std::shared_ptr<T>(raw);
* }
* s | *a;
* }
* };
*
*
* \param[in] s the serializer
* \param[in] target a reference to a pointer to the target object
*/
template <typename SerializerT, typename T>
void allocateConstructForPointer(SerializerT& s, T*& target) {
SerializeAsVirtualIfNeeded<T, SerializerT>::apply(s, target);
void allocateConstructForPointer(SerializerT& s, T*& target, const dispatch::vrt::TypeIdx entry) {
if (target != nullptr) {
// Support deserialization in place; assumes matching of virtual
// types was checked in serializeDynamicTypeIndex()
return;
}

if (s.isUnpacking()) {
ReconstructAsVirtualIfNeeded<T, SerializerT>::apply(s, target, entry);
}
}

template <typename SerializerT, typename T>
dispatch::vrt::TypeIdx serializeDynamicTypeIndex(SerializerT& s, T* target) {
return SerializeVirtualTypeIfNeeded<T, SerializerT>::apply(s, target);
}

} /* end namespace checkpoint */
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/test_unique_ptr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ struct UserObject2 {
EXPECT_EQ(vec[0], vec_val);
}

private:
int x = 0, y = 0;
std::vector<int> vec;
};
Expand All @@ -101,19 +100,21 @@ struct UserObject1 {

template <typename Serializer>
void serialize(Serializer& s) {
s | z | obj | obj_null;
s | z | obj | obj_null | obj_reset_null;
}

void check() {
EXPECT_EQ(z, z_val);
EXPECT_NE(obj, nullptr);
obj->check();
EXPECT_EQ(obj_null, nullptr);
EXPECT_EQ(obj_reset_null, nullptr);
}

int z = 0;
std::unique_ptr<UserObject2> obj = nullptr;
std::unique_ptr<UserObject2> obj_null = nullptr;
std::unique_ptr<UserObject2> obj_reset_null = nullptr;
};

TEST_F(TestUniquePtr, test_unique_ptr_1) {
Expand All @@ -122,6 +123,13 @@ TEST_F(TestUniquePtr, test_unique_ptr_1) {
auto ret = checkpoint::serialize(t);
auto out = checkpoint::deserialize<UserObject1>(ret->getBuffer());
out->check();

UserObject1 u{UserObject1::MakeTag{}};
u.obj_reset_null = std::make_unique<UserObject2>();
u.obj->x = 1;
u.obj->vec.clear();
checkpoint::deserializeInPlace(ret->getBuffer(), &u);
u.check();
}

}}} // end namespace checkpoint::tests::unit
6 changes: 4 additions & 2 deletions tests/unit/test_virtual_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ struct TestWrapper {

for (auto&& elm : vec) {
TestBase* base = elm.get();
checkpoint::allocateConstructForPointer(s, base);
auto entry = checkpoint::serializeDynamicTypeIndex(s, base);
checkpoint::allocateConstructForPointer(s, base, entry);
if (s.isUnpacking()) {
elm = std::shared_ptr<TestBase>(base);
}
Expand All @@ -359,7 +360,8 @@ struct TestWrapper {

for (auto&& elm : vec_derived) {
TestDerived2* derived = elm.get();
checkpoint::allocateConstructForPointer(s, derived);
auto entry = checkpoint::serializeDynamicTypeIndex(s, derived);
checkpoint::allocateConstructForPointer(s, derived, entry);
if (s.isUnpacking()) {
elm = std::shared_ptr<TestDerived2>(derived);
}
Expand Down

0 comments on commit 0347b36

Please sign in to comment.