Skip to content

Commit

Permalink
Add support for C-style memory alloc and free in forward mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 22, 2024
1 parent dbcf6f7 commit a1f9cd2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 3 deletions.
23 changes: 23 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,29 @@ CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi,
#endif

} // namespace std

// NOLINTBEGIN(cppcoreguidelines-no-malloc)
CUDA_HOST_DEVICE ValueAndPushforward<void*, void*>
malloc_pushforward(size_t sz, size_t d_sz) {
return {malloc(sz), malloc(sz)};
}

CUDA_HOST_DEVICE ValueAndPushforward<void*, void*>
calloc_pushforward(size_t n, size_t sz, size_t d_n, size_t d_sz) {
return {calloc(n, sz), calloc(n, sz)};
}

CUDA_HOST_DEVICE ValueAndPushforward<void*, void*>
realloc_pushforward(void* ptr, size_t sz, void* d_ptr, size_t d_sz) {
return {realloc(ptr, sz), realloc(d_ptr, sz)};
}

CUDA_HOST_DEVICE void free_pushforward(void* ptr, void* d_ptr) {
free(ptr);
free(d_ptr);
}
// NOLINTEND(cppcoreguidelines-no-malloc)

// These are required because C variants of mathematical functions are
// defined in global namespace.
using std::abs_pushforward;
Expand Down
4 changes: 1 addition & 3 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,7 @@ namespace clad {
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
if (FD->getNameAsString() == "free")
return true;
return false;
return FD->getNameAsString() == "free";
}
} // namespace utils
} // namespace clad
43 changes: 43 additions & 0 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,61 @@ double fn6 (double i) {
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn7(double i) {
double *p = (double*)malloc(8UL /*sizeof(double)*/);
*p = i;
T *t = (T*)calloc(1, sizeof(T));
t->i = i;
double res = *p + t->i;
p = (double*)realloc(p, 2*sizeof(double));
p[1] = 2*i;
res += p[1];
free(t);
free(p);
return res;
}

// CHECK: double fn7_darg0(double i) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: clad::ValueAndPushforward<void *, void *> _t0 = clad::custom_derivatives::malloc_pushforward(8UL, 0UL);
// CHECK-NEXT: double *_d_p = (double *)_t0.pushforward;
// CHECK-NEXT: double *p = (double *)_t0.value;
// CHECK-NEXT: *_d_p = _d_i;
// CHECK-NEXT: *p = i;
// CHECK-NEXT: clad::ValueAndPushforward<void *, void *> _t1 = clad::custom_derivatives::calloc_pushforward(1, sizeof(T), 0, sizeof(T));
// CHECK-NEXT: T *_d_t = (T *)_t1.pushforward;
// CHECK-NEXT: T *t = (T *)_t1.value;
// CHECK-NEXT: _d_t->i = _d_i;
// CHECK-NEXT: t->i = i;
// CHECK-NEXT: double _d_res = *_d_p + _d_t->i;
// CHECK-NEXT: double res = *p + t->i;
// CHECK-NEXT: unsigned long _t2 = sizeof(double);
// CHECK-NEXT: clad::ValueAndPushforward<void *, void *> _t3 = clad::custom_derivatives::realloc_pushforward(p, 2 * _t2, _d_p, 0 * _t2 + 2 * sizeof(double));
// CHECK-NEXT: _d_p = (double *)_t3.pushforward;
// CHECK-NEXT: p = (double *)_t3.value;
// CHECK-NEXT: _d_p[1] = 0 * i + 2 * _d_i;
// CHECK-NEXT: p[1] = 2 * i;
// CHECK-NEXT: _d_res += _d_p[1];
// CHECK-NEXT: res += p[1];
// CHECK-NEXT: clad::custom_derivatives::free_pushforward(t, _d_t);
// CHECK-NEXT: clad::custom_derivatives::free_pushforward(p, _d_p);
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

int main() {
INIT_DIFFERENTIATE(fn1, "i");
INIT_DIFFERENTIATE(fn2, "i");
INIT_DIFFERENTIATE(fn3, "i");
INIT_DIFFERENTIATE(fn4, "i");
INIT_DIFFERENTIATE(fn5, "i");
INIT_DIFFERENTIATE(fn6, "i");
INIT_DIFFERENTIATE(fn7, "i");

TEST_DIFFERENTIATE(fn1, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fn2, 3, 5); // CHECK-EXEC: {5.00}
TEST_DIFFERENTIATE(fn3, 3, 5); // CHECK-EXEC: {6.00}
TEST_DIFFERENTIATE(fn4, 3, 5); // CHECK-EXEC: {16.00}
TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00}
TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00}
TEST_DIFFERENTIATE(fn7, 3); // CHECK-EXEC: {4.00}
}

0 comments on commit a1f9cd2

Please sign in to comment.