From 81d6eca0f7a07e2b1fff9d2946d86e57e9a3be1c Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 22 Feb 2024 15:42:11 +0100 Subject: [PATCH] Add support for C-style memory allocations in reverse mode AD --- include/clad/Differentiator/CladUtils.h | 2 ++ lib/Differentiator/CladUtils.cpp | 10 +++++++ lib/Differentiator/ReverseModeVisitor.cpp | 19 ++++++++++++ test/Gradient/Pointers.C | 35 +++++++++++++++++++++++ 4 files changed, 66 insertions(+) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 5690c3913..1b3a149e1 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -328,6 +328,8 @@ namespace clad { void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); bool IsLiteral(const clang::Expr* E); + + bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD); } // namespace utils } // namespace clad diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index fbddd535b..b350aa6d6 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -641,5 +641,15 @@ namespace clad { isa(E) || isa(E) || isa(E); } + + bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) { + if (FD->getNameAsString() == "malloc") + return true; + if (FD->getNameAsString() == "calloc") + return true; + if (FD->getNameAsString() == "realloc") + return true; + return false; + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b792c29a8..49c8c2c80 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1441,6 +1441,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Stores tape decl and pushes for multiarg numerically differentiated // calls. llvm::SmallVector NumericalDiffMultiArg{}; + + // For calls to C-style memory allocation functions, we do not need to + // differentiate the call. We just need to visit the arguments to the + // function. + if (utils::IsMemoryAllocationFunction(FD)) { + for (const Expr* Arg : CE->arguments()) { + StmtDiff ArgDiff = Visit(Arg, dfdx()); + CallArgs.push_back(ArgDiff.getExpr()); + } + Expr* call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(CallArgs), + noLoc) + .get(); + return StmtDiff(call, call); + } + // If the result does not depend on the result of the call, just clone // the call and visit arguments (since they may contain side-effects like // f(x = y)) diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index cb2b66ee9..dd488b0fe 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -430,6 +430,35 @@ double structPointer (double x) { // CHECK-NEXT: delete _d_t; // CHECK-NEXT: } +double cStyleMemoryAlloc(double x, size_t n) { + T* t = (T*)malloc(n * sizeof(T)); + t->x = x; + double res = t->x; + return res; +} + +// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, clad::array_ref _d_x) { +// CHECK-NEXT: size_t _d_n = 0; +// CHECK-NEXT: T *_d_t = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: _d_t = (T *)malloc(n * sizeof(T)); +// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T)); +// CHECK-NEXT: _t0 = t->x; +// CHECK-NEXT: t->x = x; +// CHECK-NEXT: double res = t->x; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: _d_t->x += _d_res; +// CHECK-NEXT: { +// CHECK-NEXT: t->x = _t0; +// CHECK-NEXT: double _r_d0 = _d_t->x; +// CHECK-NEXT: _d_t->x -= _r_d0; +// CHECK-NEXT: * _d_x += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -533,4 +562,10 @@ int main() { auto d_structPointer = clad::gradient(structPointer); double d_x = 0; d_structPointer.execute(5, &d_x); + printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 + + auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x"); + d_x = 0; + d_cStyleMemoryAlloc.execute(5, 7, &d_x); + printf("%.2f\n", d_x); // CHECK-EXEC: 1.00 } \ No newline at end of file