Skip to content

Commit

Permalink
#344 Workaround for NVCC bug
Browse files Browse the repository at this point in the history
For NVCC < 11.7.1, the templated static constexpr member variables give
the wrong results (has_trait_v replaced with has_trait::value).
  • Loading branch information
Matthew-Whitlock committed Sep 24, 2024
1 parent 14b22f3 commit 5bfb72b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
14 changes: 7 additions & 7 deletions examples/checkpoint_example_user_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ namespace test {

TestObj() {}

template<typename SerT, typename std::enable_if_t<SerT::template has_not_traits<shallow_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_not_traits<shallow_trait>::type = nullptr>
void serialize(SerT& s){
if constexpr(SerT::template has_traits<checkpoint_trait>::value){
if constexpr(SerT::template has_traits_v<checkpoint_trait>){
if(s.isSizing()) printf("Customizing serialization for checkpoint\n");
s | a;
} else {
Expand All @@ -26,13 +26,13 @@ namespace test {
}

namespace test {
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<random_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<random_trait>::type = nullptr>
void serialize(SerT& s, TestObj& myObj){
if(s.isSizing()) printf("Inserting random extra object serialization step! ");
myObj.serialize(s);
}

template<typename SerT, typename std::enable_if_t<SerT::template has_traits<shallow_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<shallow_trait>::type = nullptr>
void serialize(SerT& s, TestObj& myObj){
if(s.isSizing()) printf("Removing shallow trait before passing along!\n");
auto newS = s.template withoutTraits<shallow_trait>();
Expand All @@ -41,23 +41,23 @@ namespace test {
}

namespace misc {
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<test::random_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<test::random_trait>::type = nullptr>
void serialize(SerT& s, test::TestObj& myObj){
if(s.isSizing()) printf("Serializers in other namespaces don't usually get found ");
myObj.serialize(s);
}


const struct namespace_trait {} NamespaceTrait;
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<namespace_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<namespace_trait>::type = nullptr>
void serialize(SerT& s, test::TestObj& myObj){
if(s.isSizing()) printf("A misc:: trait means we can serialize from misc:: too: ");
myObj.serialize(s);
}


const struct hook_all_trait {} HookAllTrait;
template<typename SerT, typename T, typename std::enable_if_t<SerT::template has_traits<hook_all_trait>::value>* = nullptr>
template<typename SerT, typename T, typename SerT::template has_traits<hook_all_trait>::type = nullptr>
void serialize(SerT& s, T& myObj){
if(s.isSizing()) printf("We can even add on a generic pre-serialize hook: ");
auto newS = s.template withoutTraits<hook_all_trait>();
Expand Down
47 changes: 22 additions & 25 deletions src/checkpoint/serializers/serializer_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@

namespace checkpoint {

namespace {
// Cuda does not play nicely with templated static constexpr
// member variables in the SerializerRef class, so we make a
// helper struct to hold non-templated static constexpr members.
template<bool B>
struct bool_enable_if {
static constexpr bool value = false;
};

template<>
struct bool_enable_if<true> {
static constexpr bool value = true;
using type = void*;
};

}

template<typename SerT, typename UserTraits = UserTraitHolder<>>
struct SerializerRef
{
Expand Down Expand Up @@ -112,42 +129,22 @@ struct SerializerRef

//Big block of helpers for conveniently checking traits in different contexts.
template<typename... Traits>
using has_traits = typename TraitHolder::template has<Traits...>;
using has_traits = bool_enable_if<TraitHolder::template has<Traits...>::value>;
template<typename... Traits>
using has_any_traits = typename TraitHolder::template has_any<Traits...>;
using has_any_traits = bool_enable_if<TraitHolder::template has_any<Traits...>::value>;

template<typename... Traits>
using has_not_traits = std::integral_constant<bool, !(has_traits<Traits...>::value)>;
template<typename... Traits>
using has_not_any_traits = std::integral_constant<bool, !(has_any_traits<Traits...>::value)>;

template<typename... Traits>
static constexpr bool has_traits_v = has_traits<Traits...>::value;
template<typename... Traits>
static constexpr bool has_any_traits_v = has_any_traits<Traits...>::value;
template<typename... Traits>
static constexpr bool has_not_traits_v = has_not_traits<Traits...>::value;
using has_not_traits = bool_enable_if<!(has_traits<Traits...>::value)>;
template<typename... Traits>
static constexpr bool has_not_any_traits_v = has_not_any_traits<Traits...>::value;

template<typename... Traits>
using has_traits_t = std::enable_if_t<has_traits_v<Traits...>>;
template<typename... Traits>
using has_any_traits_t = std::enable_if_t<has_any_traits_v<Traits...>>;
template<typename... Traits>
using has_not_traits_t = std::enable_if_t<has_not_traits_v<Traits...>>;
template<typename... Traits>
using has_not_any_traits_t = std::enable_if_t<has_not_any_traits_v<Traits...>>;


using has_not_any_traits = bool_enable_if<!(has_any_traits<Traits...>::value)>;

//Helpers for converting between traits
using TraitlessT = SerializerRef<SerT>;

//Returns a new reference with traits in addition to this reference's traits.
template<typename Trait, typename... Traits>
auto withTraits(UserTraitHolder<Trait, Traits...> = {}){
using NewTraitHolder = typename TraitHolder::template with<Trait, Traits...>;
//return setTraits(NewTraitHolder {});
return SerializerRef<SerT, NewTraitHolder>(*this);
}

Expand Down
22 changes: 14 additions & 8 deletions tests/unit/test_user_traits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,13 @@ struct UserObjectA {

template <typename S>
void serialize(S& s) {
std::cout << "A: serializing with type "
<< abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr)
<< std::endl;
EXPECT_FALSE((S::template has_traits_v<ShallowTrait>));

EXPECT_EQ(
(S::template has_traits_v<TraitPairA>),
(S::template has_traits_v<TraitPairA>),
(S::template has_traits_v<TraitPairB>)
);

Expand All @@ -93,20 +96,20 @@ struct UserObjectA {
};

template<
typename SerT,
typename = typename SerT::template has_traits_t<CheckpointTraitNonintrusive>
typename S,
typename = typename S::template has_traits<CheckpointTraitNonintrusive>::type
>
void serialize(SerT& s, UserObjectA& obj){
void serialize(S& s, UserObjectA& obj){
s | obj.name;
obj.serialize(s);
}

namespace CheckpointNamespace {
template<
typename SerT,
typename = typename SerT::template has_traits_t<CheckpointTraitNamespaced>
typename S,
typename = typename S::template has_traits<CheckpointTraitNamespaced>::type
>
void serialize(SerT& s, UserObjectA& obj){
void serialize(S& s, UserObjectA& obj){
s | obj.name;
obj.serialize(s);
}
Expand Down Expand Up @@ -150,8 +153,11 @@ struct UserObjectB : public UserObjectA {

template <typename S>
void serialize(S& s) {
std::cout << "B: serializing with type "
<< abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr)
<< std::endl;
auto new_s = s.template withoutTraits<ShallowTrait>();
if constexpr(S::template has_traits_v<TraitPairA>){
if (S::template has_traits<TraitPairA>::value){
auto newer_s = new_s.template withTraits<TraitPairB>();
UserObjectA::serialize(newer_s);
} else {
Expand Down

0 comments on commit 5bfb72b

Please sign in to comment.