From ec62356bf415e59cbd9b4bfd67191f5eb6614207 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar <vaibhav.thakkar.22.12.99@gmail.com> Date: Thu, 11 Jul 2024 09:51:44 +0200 Subject: [PATCH] Fix pointer dereference in fwd mode fixes #972 --- lib/Differentiator/BaseForwardModeVisitor.cpp | 12 ++++++++--- test/ForwardMode/Pointer.C | 20 +++++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 465dfa128..559125b8a 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1005,8 +1005,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { } // Is not a variable or is a reference to something unrelated to independent // variable. Derivative is 0. - auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); - return StmtDiff(clonedDRE, zero); + // If DRE is of type pointer, then the derivative is a null pointer. + if (clonedDRE->getType()->isPointerType()) + return StmtDiff(clonedDRE, nullptr); + return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, /*val=*/0)); } StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { @@ -1315,7 +1318,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { opKind == UnaryOperatorKind::UO_Imag) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_Deref) { - return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); + if (Expr* dx = diff.getExpr_dx()) + return StmtDiff(op, BuildOp(opKind, dx)); + return StmtDiff(op, ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, /*val=*/0)); } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 317cd28a5..69f050516 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -192,8 +192,15 @@ double fn8(double* params) { // CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward; // CHECK-NEXT: } -// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) { -// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)}; +double fn9(double* params, const double *constants) { + double c0 = *constants; + return params[0] * c0; +} + +// CHECK: double fn9_darg0_0(double *params, const double *constants) { +// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: double c0 = *constants; +// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } @@ -218,4 +225,13 @@ int main() { auto fn8_dx = clad::differentiate(fn8, "params[0]"); double d_param = fn8_dx.execute(params); printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00} + + double constants[] = {5.0}; + auto fn9_dx = clad::differentiate(fn9, "params[0]"); + d_param = fn9_dx.execute(params, constants); + printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00} } + +// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) { +// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)}; +// CHECK-NEXT: }