diff --git a/src/checkpoint/container/unique_ptr_serialize.h b/src/checkpoint/container/unique_ptr_serialize.h index af97a53f..bfcb95af 100644 --- a/src/checkpoint/container/unique_ptr_serialize.h +++ b/src/checkpoint/container/unique_ptr_serialize.h @@ -60,14 +60,25 @@ void serialize(Serializer& s, std::unique_ptr& ptr) { s | is_null; } - if (not is_null) { - T* t = ptr.get(); - allocateConstructForPointer(s, t); + if (is_null) { if (s.isUnpacking()) { + ptr.reset(); + } + return; + } + + T* t = ptr.get(); + auto entry = serializeDynamicTypeIndex(s, t); + allocateConstructForPointer(s, t, entry); + + if (s.isUnpacking()) { + // Support deserialization in place + if (ptr.get() != t) { ptr = std::unique_ptr(t); } - s | *ptr; } + + s | *ptr; } } /* end namespace checkpoint */ diff --git a/src/checkpoint/dispatch/dispatch_virtual.h b/src/checkpoint/dispatch/dispatch_virtual.h index 660803bd..75ad7951 100644 --- a/src/checkpoint/dispatch/dispatch_virtual.h +++ b/src/checkpoint/dispatch/dispatch_virtual.h @@ -58,6 +58,7 @@ * ----------------------------- API usage: ---------------------------------- * --------------------------------------------------------------------------- * + * ----------------------------- Declarations -------------------------------- * === Option 1 === * * - Make your virtual class hierarchy you want to serialize all inherit from @@ -71,15 +72,37 @@ * - \c checkpoint_virtual_serialize_root() * - \c checkpoint_virtual_serialize_derived_from(ParentT) * - * Invoking the virtual serializer: + * --------------------------- Invocation ------------------------------------ * * - If you have a \c std::unique_ptr, where T is virtually serializable (by * using the macros or inheriting from \c SerializableBase and * \c SerializableDervived ), they will automatically be virtually * serialized. * + * Example using unique_ptr: + * template + * struct MyObjectWithUniquePointer { + * unique_ptr ptr; + * + * template + * void serialize(SerializerT& s) { + * s | ptr; + * } + * }; + * + * * - If you have a raw pointer, \c Teuchos::RCP, or \c std::shared_ptr, - * you must invoke: \c checkpoint::allocateConstructForPointer + * you must manually serialize the dynamic type and call for its + * reconstruction + * + * Serializing the dynamic type's registered index: + * + * T* t; + * auto entry = checkpoint::serializeDynamicTypeIndex(s, t); + * + * Invoking the virtual class reconstructor: + * + * checkpoint::allocateConstructForPointer(s, t, entry); * * Example with raw pointer: * @@ -94,7 +117,8 @@ * if (!is_null) { * // During size/pack, save the actual derived type of raw_ptr; * // During unpack, allocate/construct raw_ptr with correct virtual type - * checkpoint::allocateConstructForPointer(s, raw_ptr); + * auto entry = checkpoint::serializeDynamicTypeIndex(s, raw_ptr); + * checkpoint::allocateConstructForPointer(s, raw_ptr, entry); * s | *raw_ptr; * } * } diff --git a/src/checkpoint/dispatch/vrt/virtual_serialize.h b/src/checkpoint/dispatch/vrt/virtual_serialize.h index 8d0de571..c99a7eee 100644 --- a/src/checkpoint/dispatch/vrt/virtual_serialize.h +++ b/src/checkpoint/dispatch/vrt/virtual_serialize.h @@ -71,7 +71,63 @@ void virtualSerialize(T*& base, SerializerT& s) { namespace checkpoint { /** - * \struct SerializeAsVirtualIfNeeded + * \struct SerializeVirtualTypeIfNeeded + * + * \brief Do a static trait test on type to check for virtual + * serializability. If virtually serializable, we need to perform some + * extra work to register the type, and serialize its index. + */ +template +struct SerializeVirtualTypeIfNeeded; + +template +struct SerializeVirtualTypeIfNeeded< + T, + SerializerT, + typename std::enable_if_t< + dispatch::vrt::VirtualSerializeTraits::has_not_virtual_serialize + > +> +{ + static dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) { + // no type idx needed in this case + return dispatch::vrt::no_type_idx; + } +}; + +template +struct SerializeVirtualTypeIfNeeded< + T, + SerializerT, + typename std::enable_if_t< + dispatch::vrt::VirtualSerializeTraits::has_virtual_serialize + > +> +{ + static typename dispatch::vrt::TypeIdx apply(SerializerT& s, T* target) { + dispatch::vrt::TypeIdx entry = dispatch::vrt::no_type_idx; + + if (not s.isUnpacking()) { + entry = target->_checkpointDynamicTypeIndex(); + } else { + if (target != nullptr) { + // Support deserialization in place, and make sure it's safe + checkpointAssert(entry == target->_checkpointDynamicTypeIndex(), + "Trying to deserialize in place over a mismatched type"); + } + } + + // entry doesn't count as part of a footprint + if (not s.isFootprinting()) { + s | entry; + } + + return entry; + } +}; + +/** + * \struct ReconstructAsVirtualIfNeeded * * \brief Do a static trait test on type to check for virtual * serializability. If virtually serializable, we need to perform some extra @@ -80,10 +136,10 @@ namespace checkpoint { * and serializing what the pointer points to. */ template -struct SerializeAsVirtualIfNeeded; +struct ReconstructAsVirtualIfNeeded; template -struct SerializeAsVirtualIfNeeded< +struct ReconstructAsVirtualIfNeeded< T, SerializerT, typename std::enable_if_t< @@ -91,7 +147,7 @@ struct SerializeAsVirtualIfNeeded< not std::is_same::value > > { - static void apply(SerializerT& s, T*& target) { + static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) { // no type idx needed in this case, static construction in default case auto t = std::allocator{}.allocate(1); target = dispatch::Reconstructor::construct(t); @@ -99,7 +155,7 @@ struct SerializeAsVirtualIfNeeded< }; template -struct SerializeAsVirtualIfNeeded< +struct ReconstructAsVirtualIfNeeded< T, SerializerT, typename std::enable_if_t< @@ -107,37 +163,25 @@ struct SerializeAsVirtualIfNeeded< std::is_same::value > > { - static void apply(SerializerT& s, T*& target) { } + static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) { } }; template -struct SerializeAsVirtualIfNeeded< +struct ReconstructAsVirtualIfNeeded< T, SerializerT, typename std::enable_if_t< dispatch::vrt::VirtualSerializeTraits::has_virtual_serialize > > { - static void apply(SerializerT& s, T*& target) { - using dispatch::vrt::TypeIdx; - - TypeIdx entry = dispatch::vrt::no_type_idx; - - if (not s.isUnpacking()) { - entry = target->_checkpointDynamicTypeIndex(); - } - - s | entry; - - if (s.isUnpacking()) { - using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t; - - // use type idx here, registration needed for proper type re-construction - auto t = dispatch::vrt::objregistry::allocateConcreteType(entry); - target = static_cast( - dispatch::vrt::objregistry::constructConcreteType(entry, t) - ); - } + static void apply(SerializerT& s, T*& target, dispatch::vrt::TypeIdx entry) { + using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t; + + // use type idx here, registration needed for proper type re-construction + auto t = dispatch::vrt::objregistry::allocateConcreteType(entry); + target = static_cast( + dispatch::vrt::objregistry::constructConcreteType(entry, t) + ); } }; @@ -157,7 +201,8 @@ struct SerializeAsVirtualIfNeeded< * template * void serialize(SerializerT& s) { * T* raw = elm.get(); - * checkpoint::allocateConstructForPointer(s, raw); + * auto entry = serializeDynamicTypeIndex(s, raw); + * checkpoint::allocateConstructForPointer(s, raw, entry); * if (s.isUnpacking()) { * a = std::shared_ptr(raw); * } @@ -165,13 +210,25 @@ struct SerializeAsVirtualIfNeeded< * } * }; * - * * \param[in] s the serializer * \param[in] target a reference to a pointer to the target object */ template -void allocateConstructForPointer(SerializerT& s, T*& target) { - SerializeAsVirtualIfNeeded::apply(s, target); +void allocateConstructForPointer(SerializerT& s, T*& target, const dispatch::vrt::TypeIdx entry) { + if (target != nullptr) { + // Support deserialization in place; assumes matching of virtual + // types was checked in serializeDynamicTypeIndex() + return; + } + + if (s.isUnpacking()) { + ReconstructAsVirtualIfNeeded::apply(s, target, entry); + } +} + +template +dispatch::vrt::TypeIdx serializeDynamicTypeIndex(SerializerT& s, T* target) { + return SerializeVirtualTypeIfNeeded::apply(s, target); } } /* end namespace checkpoint */ diff --git a/tests/unit/test_unique_ptr.cc b/tests/unit/test_unique_ptr.cc index 9476eae0..1aabb184 100644 --- a/tests/unit/test_unique_ptr.cc +++ b/tests/unit/test_unique_ptr.cc @@ -84,7 +84,6 @@ struct UserObject2 { EXPECT_EQ(vec[0], vec_val); } -private: int x = 0, y = 0; std::vector vec; }; @@ -101,7 +100,7 @@ struct UserObject1 { template void serialize(Serializer& s) { - s | z | obj | obj_null; + s | z | obj | obj_null | obj_reset_null; } void check() { @@ -109,11 +108,13 @@ struct UserObject1 { EXPECT_NE(obj, nullptr); obj->check(); EXPECT_EQ(obj_null, nullptr); + EXPECT_EQ(obj_reset_null, nullptr); } int z = 0; std::unique_ptr obj = nullptr; std::unique_ptr obj_null = nullptr; + std::unique_ptr obj_reset_null = nullptr; }; TEST_F(TestUniquePtr, test_unique_ptr_1) { @@ -122,6 +123,13 @@ TEST_F(TestUniquePtr, test_unique_ptr_1) { auto ret = checkpoint::serialize(t); auto out = checkpoint::deserialize(ret->getBuffer()); out->check(); + + UserObject1 u{UserObject1::MakeTag{}}; + u.obj_reset_null = std::make_unique(); + u.obj->x = 1; + u.obj->vec.clear(); + checkpoint::deserializeInPlace(ret->getBuffer(), &u); + u.check(); } }}} // end namespace checkpoint::tests::unit diff --git a/tests/unit/test_virtual_serialize.cc b/tests/unit/test_virtual_serialize.cc index b9c58296..7faa6bc3 100644 --- a/tests/unit/test_virtual_serialize.cc +++ b/tests/unit/test_virtual_serialize.cc @@ -350,7 +350,8 @@ struct TestWrapper { for (auto&& elm : vec) { TestBase* base = elm.get(); - checkpoint::allocateConstructForPointer(s, base); + auto entry = checkpoint::serializeDynamicTypeIndex(s, base); + checkpoint::allocateConstructForPointer(s, base, entry); if (s.isUnpacking()) { elm = std::shared_ptr(base); } @@ -359,7 +360,8 @@ struct TestWrapper { for (auto&& elm : vec_derived) { TestDerived2* derived = elm.get(); - checkpoint::allocateConstructForPointer(s, derived); + auto entry = checkpoint::serializeDynamicTypeIndex(s, derived); + checkpoint::allocateConstructForPointer(s, derived, entry); if (s.isUnpacking()) { elm = std::shared_ptr(derived); }