From b0fa87aaee317e01eb155e0de6c8568d9022bfa4 Mon Sep 17 00:00:00 2001 From: Baidyanath Kundu Date: Fri, 20 Aug 2021 13:57:21 +0530 Subject: [PATCH] Modify MoveData in Tape.h to make it CUDA compatible * Add addressof function make the call to std::address CUDA compatible * Change fprintf to printf * Add trap function to replace exit function --- include/clad/Differentiator/CladConfig.h | 36 ++++++++++++++++++++++++ include/clad/Differentiator/Tape.h | 6 ++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 21ddd1216..7d2fa028f 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -3,10 +3,46 @@ #ifndef CLAD_CONFIG_H #define CLAD_CONFIG_H +#include +#include + +// Define CUDA_HOST_DEVICE attribute for adding CUDA support to +// clad functions #ifdef __CUDACC__ #define CUDA_HOST_DEVICE __host__ __device__ #else #define CUDA_HOST_DEVICE #endif +// Define trap function that is a CUDA compatible replacement for +// exit(int code) function +#ifdef __CUDACC__ +__device__ void trap(int code) { + asm("trap;"); +} +__host__ void trap(int code) { + exit(code); +} +#else +void trap(int code) { + exit(code); +} +#endif + +#ifdef __CUDACC__ +template +__device__ T* addressof(T& r) { + return __builtin_addressof(r); +} +template +__host__ T* addressof(T& r) { + return std::addressof(r); +} +#else +template +T* addressof(T& r) { + return std::addressof(r); +} +#endif + #endif // CLAD_CONFIG_H diff --git a/include/clad/Differentiator/Tape.h b/include/clad/Differentiator/Tape.h index f421b8b8c..beb715e75 100644 --- a/include/clad/Differentiator/Tape.h +++ b/include/clad/Differentiator/Tape.h @@ -100,13 +100,13 @@ namespace clad { // allocation properly. for (; first != last; ++first, (void)++current) { auto new_data = ::new (const_cast( - static_cast(std::addressof(*current)))) + static_cast(addressof(current)))) T(std::move(*first)); if (!new_data) { // clean up the memory mess just in case! destroy(d_first, current); - fprintf(stderr, "Allocation failure during tape resize! Aborting."); - exit(EXIT_FAILURE); + printf("Allocation failure during tape resize! Aborting."); + trap(EXIT_FAILURE); } } }