From ec62356bf415e59cbd9b4bfd67191f5eb6614207 Mon Sep 17 00:00:00 2001
From: Vaibhav Thakkar <vaibhav.thakkar.22.12.99@gmail.com>
Date: Thu, 11 Jul 2024 09:51:44 +0200
Subject: [PATCH] Fix pointer dereference in fwd mode

fixes #972
---
 lib/Differentiator/BaseForwardModeVisitor.cpp | 12 ++++++++---
 test/ForwardMode/Pointer.C                    | 20 +++++++++++++++++--
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp
index 465dfa128..559125b8a 100644
--- a/lib/Differentiator/BaseForwardModeVisitor.cpp
+++ b/lib/Differentiator/BaseForwardModeVisitor.cpp
@@ -1005,8 +1005,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
   }
   // Is not a variable or is a reference to something unrelated to independent
   // variable. Derivative is 0.
-  auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
-  return StmtDiff(clonedDRE, zero);
+  // If DRE is of type pointer, then the derivative is a null pointer.
+  if (clonedDRE->getType()->isPointerType())
+    return StmtDiff(clonedDRE, nullptr);
+  return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
+                                 m_Context.IntTy, m_Context, /*val=*/0));
 }
 
 StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
@@ -1315,7 +1318,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
            opKind == UnaryOperatorKind::UO_Imag) {
     return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
   } else if (opKind == UnaryOperatorKind::UO_Deref) {
-    return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
+    if (Expr* dx = diff.getExpr_dx())
+      return StmtDiff(op, BuildOp(opKind, dx));
+    return StmtDiff(op, ConstantFolder::synthesizeLiteral(
+                            m_Context.IntTy, m_Context, /*val=*/0));
   } else if (opKind == UnaryOperatorKind::UO_AddrOf) {
     return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
   } else if (opKind == UnaryOperatorKind::UO_LNot) {
diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C
index 317cd28a5..69f050516 100644
--- a/test/ForwardMode/Pointer.C
+++ b/test/ForwardMode/Pointer.C
@@ -192,8 +192,15 @@ double fn8(double* params) {
 // CHECK-NEXT:     return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward;
 // CHECK-NEXT: }
 
-// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
-// CHECK-NEXT:     return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
+double fn9(double* params, const double *constants) {
+  double c0 = *constants;
+  return params[0] * c0;
+}
+
+// CHECK: double fn9_darg0_0(double *params, const double *constants) {
+// CHECK-NEXT:     double _d_c0 = 0;
+// CHECK-NEXT:     double c0 = *constants;
+// CHECK-NEXT:     return 1. * c0 + params[0] * _d_c0;
 // CHECK-NEXT: }
 
 
@@ -218,4 +225,13 @@ int main() {
   auto fn8_dx = clad::differentiate(fn8, "params[0]");
   double d_param = fn8_dx.execute(params);
   printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00}
+
+  double constants[] = {5.0};
+  auto fn9_dx = clad::differentiate(fn9, "params[0]");
+  d_param = fn9_dx.execute(params, constants);
+  printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00}
 }
+
+// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
+// CHECK-NEXT:     return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
+// CHECK-NEXT: }