From b9f8f31964898f46b5d335df90241141d323b86b Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 22 Feb 2024 16:37:53 +0100 Subject: [PATCH] Add support for C-style memory alloc and free in forward mode AD --- .../clad/Differentiator/BuiltinDerivatives.h | 26 +++++++++++ lib/Differentiator/CladUtils.cpp | 4 +- test/ForwardMode/Pointer.C | 43 +++++++++++++++++++ 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/include/clad/Differentiator/BuiltinDerivatives.h b/include/clad/Differentiator/BuiltinDerivatives.h index fac077f00..50bc60880 100644 --- a/include/clad/Differentiator/BuiltinDerivatives.h +++ b/include/clad/Differentiator/BuiltinDerivatives.h @@ -196,6 +196,32 @@ CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi, #endif } // namespace std + +// NOLINTBEGIN(cppcoreguidelines-no-malloc) +ValueAndPushforward inline malloc_pushforward(size_t sz, + size_t d_sz) { + return {malloc(sz), malloc(sz)}; +} + +ValueAndPushforward inline calloc_pushforward(size_t n, size_t sz, + size_t d_n, + size_t d_sz) { + return {calloc(n, sz), calloc(n, sz)}; +} + +ValueAndPushforward inline realloc_pushforward(void* ptr, + size_t sz, + void* d_ptr, + size_t d_sz) { + return {realloc(ptr, sz), realloc(d_ptr, sz)}; +} + +void free_pushforward(void* ptr, void* d_ptr) { + free(ptr); + free(d_ptr); +} +// NOLINTEND(cppcoreguidelines-no-malloc) + // These are required because C variants of mathematical functions are // defined in global namespace. using std::abs_pushforward; diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index cf68b10f4..e78bb146e 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -653,9 +653,7 @@ namespace clad { } bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) { - if (FD->getNameAsString() == "free") - return true; - return false; + return FD->getNameAsString() == "free"; } } // namespace utils } // namespace clad diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 25fce39b3..b5f278e70 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -133,6 +133,47 @@ double fn6 (double i) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn7(double i) { + double *p = (double*)malloc(8UL /*sizeof(double)*/); + *p = i; + T *t = (T*)calloc(1, sizeof(T)); + t->i = i; + double res = *p + t->i; + p = (double*)realloc(p, 2*sizeof(double)); + p[1] = 2*i; + res += p[1]; + free(t); + free(p); + return res; +} + +// CHECK: double fn7_darg0(double i) { +// CHECK-NEXT: double _d_i = 1; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = clad::custom_derivatives::malloc_pushforward(8UL, 0UL); +// CHECK-NEXT: double *_d_p = (double *)_t0.pushforward; +// CHECK-NEXT: double *p = (double *)_t0.value; +// CHECK-NEXT: *_d_p = _d_i; +// CHECK-NEXT: *p = i; +// CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives::calloc_pushforward(1, sizeof(T), 0, sizeof(T)); +// CHECK-NEXT: T *_d_t = (T *)_t1.pushforward; +// CHECK-NEXT: T *t = (T *)_t1.value; +// CHECK-NEXT: _d_t->i = _d_i; +// CHECK-NEXT: t->i = i; +// CHECK-NEXT: double _d_res = *_d_p + _d_t->i; +// CHECK-NEXT: double res = *p + t->i; +// CHECK-NEXT: unsigned long _t2 = sizeof(double); +// CHECK-NEXT: clad::ValueAndPushforward _t3 = clad::custom_derivatives::realloc_pushforward(p, 2 * _t2, _d_p, 0 * _t2 + 2 * sizeof(double)); +// CHECK-NEXT: _d_p = (double *)_t3.pushforward; +// CHECK-NEXT: p = (double *)_t3.value; +// CHECK-NEXT: _d_p[1] = 0 * i + 2 * _d_i; +// CHECK-NEXT: p[1] = 2 * i; +// CHECK-NEXT: _d_res += _d_p[1]; +// CHECK-NEXT: res += p[1]; +// CHECK-NEXT: clad::custom_derivatives::free_pushforward(t, _d_t); +// CHECK-NEXT: clad::custom_derivatives::free_pushforward(p, _d_p); +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + int main() { INIT_DIFFERENTIATE(fn1, "i"); INIT_DIFFERENTIATE(fn2, "i"); @@ -140,6 +181,7 @@ int main() { INIT_DIFFERENTIATE(fn4, "i"); INIT_DIFFERENTIATE(fn5, "i"); INIT_DIFFERENTIATE(fn6, "i"); + INIT_DIFFERENTIATE(fn7, "i"); TEST_DIFFERENTIATE(fn1, 3, 5); // CHECK-EXEC: {5.00} TEST_DIFFERENTIATE(fn2, 3, 5); // CHECK-EXEC: {5.00} @@ -147,4 +189,5 @@ int main() { TEST_DIFFERENTIATE(fn4, 3, 5); // CHECK-EXEC: {16.00} TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00} TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00} + TEST_DIFFERENTIATE(fn7, 3); // CHECK-EXEC: {4.00} }