Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pointer dereference in fwd mode #975

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
}
// Is not a variable or is a reference to something unrelated to independent
// variable. Derivative is 0.
auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
return StmtDiff(clonedDRE, zero);
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
Comment on lines +1009 to +1012
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
QualType clonedDreTy = clonedDRE->getType();
if (clonedDreTy->isPointerType())
return StmtDiff(clonedDRE, nullptr);
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
clonedDreTy, m_Context, /*val=*/0));

}

StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1315,7 +1318,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
if (Expr* dx = diff.getExpr_dx())
return StmtDiff(op, BuildOp(opKind, dx));
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
Expand Down
20 changes: 18 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,15 @@ double fn8(double* params) {
// CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward;
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
double fn9(double* params, const double *constants) {
double c0 = *constants;
return params[0] * c0;
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }


Expand All @@ -218,4 +225,13 @@ int main() {
auto fn8_dx = clad::differentiate(fn8, "params[0]");
double d_param = fn8_dx.execute(params);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00}

double constants[] = {5.0};
auto fn9_dx = clad::differentiate(fn9, "params[0]");
d_param = fn9_dx.execute(params, constants);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00}
}

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
// CHECK-NEXT: }
Loading