Skip to content

Commit

Permalink
Fix pointer dereference in fwd mode
Browse files Browse the repository at this point in the history
fixes #972
  • Loading branch information
vaithak committed Jul 11, 2024
1 parent f309b61 commit ec62356
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
12 changes: 9 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
}
// Is not a variable or is a reference to something unrelated to independent
// variable. Derivative is 0.
auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
return StmtDiff(clonedDRE, zero);
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
}

StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1315,7 +1318,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
if (Expr* dx = diff.getExpr_dx())
return StmtDiff(op, BuildOp(opKind, dx));
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
Expand Down
20 changes: 18 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,15 @@ double fn8(double* params) {
// CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward;
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
double fn9(double* params, const double *constants) {
double c0 = *constants;
return params[0] * c0;
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }


Expand All @@ -218,4 +225,13 @@ int main() {
auto fn8_dx = clad::differentiate(fn8, "params[0]");
double d_param = fn8_dx.execute(params);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00}

double constants[] = {5.0};
auto fn9_dx = clad::differentiate(fn9, "params[0]");
d_param = fn9_dx.execute(params, constants);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00}
}

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
// CHECK-NEXT: }

0 comments on commit ec62356

Please sign in to comment.