Skip to content

Commit

Permalink
Don't pass derived args to forw pass unless they are references
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jan 3, 2025
1 parent 825396d commit 5c77547
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 65 deletions.
31 changes: 8 additions & 23 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2009,6 +2009,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call = nullptr;

QualType returnType = FD->getReturnType();
// Stores the dx of the call arguments for the function to be derived
for (std::size_t i = 0, e = CE->getNumArgs() - isMethodOperatorCall; i != e;
++i) {
const Expr* arg = CE->getArg(i + isMethodOperatorCall);
if (!utils::IsReferenceOrPointerArg(arg))
CallArgDx[i] = getZeroInit(arg->getType());
}
if (baseDiff.getExpr_dx() &&
!baseDiff.getExpr_dx()->getType()->isPointerType())
CallArgDx.insert(CallArgDx.begin(), BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand Down Expand Up @@ -2050,29 +2057,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// We cannot reuse the derivatives previously computed because
// they might contain 'clad::pop(..)` expression.
if (baseDiff.getExpr_dx()) {
Expr* derivedBase = baseDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
// (isCladArrayType(derivedBase->getType()))
// CallArgs.push_back(derivedBase);
// else
// Currently derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, Loc));
}

for (std::size_t i = static_cast<std::size_t>(isMethodOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
StmtDiff argDiff = Visit(arg);
// Has to be removed once nondifferentiable arguments are handeled
if (argDiff.getStmt_dx())
CallArgs.push_back(argDiff.getExpr_dx());
else
CallArgs.push_back(getZeroInit(arg->getType()));
}
CallArgs.insert(CallArgs.end(), CallArgDx.begin(), CallArgDx.end());
if (Expr* baseE = baseDiff.getExpr()) {
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, Loc);
Expand Down
4 changes: 2 additions & 2 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ double fn2(SimpleFunctions& sf, double i) {

// CHECK: void fn2_grad(SimpleFunctions &sf, double i, SimpleFunctions *_d_sf, double *_d_i) {
// CHECK-NEXT: SimpleFunctions _t0 = sf;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), *_d_i);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), 0.);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: _t0.ref_mem_fn_pullback(i, 1, &(*_d_sf), &_r0);
Expand All @@ -459,7 +459,7 @@ double fn5(SimpleFunctions& v, double value) {

// CHECK: void fn5_grad(SimpleFunctions &v, double value, SimpleFunctions *_d_v, double *_d_value) {
// CHECK-NEXT: SimpleFunctions _t0 = v;
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), *_d_value);
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), 0.);
// CHECK-NEXT: (*_d_v).x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
Expand Down
Loading

0 comments on commit 5c77547

Please sign in to comment.