Don't pass derived args to forw pass unless they are references or pointers.

``_reverse_forward`` functions contain only the forward pass, which can affect derivatives only through references/pointers. For other argument types, there is no point in passing the adjoint. Moreover, doing so is often incorrect, because the derived arguments are generated for the reverse pass, e.g.
```
// forward pass
operator_subscript_reverse_forw(&vec, 0, &_d_vec, _r0); // `_r0` is declared later
...
// reverse pass
size_type _r0 = 0UL;
operator_subscript_pullback(&_t2, 0, 1, &_d_vec, &_r0);
```
In this example, replacing the first occurrence of `_r0` with `0` is still correct.
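
The reasoning above can be checked with a small hand-written sketch. It does not use Clad: `ValueAndAdjoint` here is a hypothetical stand-in for `clad::ValueAndAdjoint`, and `Vec2`, `at_reverse_forw`, `at_pullback` and `grad_first_element` are names made up for illustration. The assumption is that a forward-pass function only needs the adjoints of reference/pointer arguments (to pair the returned reference with its adjoint slot), so whatever is passed for a by-value adjoint has no effect on the gradient.

```cpp
#include <cassert>

// Hypothetical stand-in for clad::ValueAndAdjoint; not Clad's actual definition.
template <typename T, typename U> struct ValueAndAdjoint {
  T value;
  U adjoint;
};

// Illustrative container with two elements.
struct Vec2 {
  double data[2];
};

// Hand-written forward pass for `double& at(Vec2& v, unsigned long i)`.
// It pairs the returned reference with the matching adjoint slot in `*d_v`.
// `d_i` is the adjoint of a by-value argument: it is never read or written.
ValueAndAdjoint<double&, double&> at_reverse_forw(Vec2& v, unsigned long i,
                                                  Vec2* d_v,
                                                  unsigned long d_i) {
  (void)d_i;
  return {v.data[i], d_v->data[i]};
}

// Hand-written pullback: propagates the incoming adjoint `d_y` to `*d_v`.
// The index is not differentiable, so `*d_i` is left untouched.
void at_pullback(Vec2& v, unsigned long i, double d_y, Vec2* d_v,
                 unsigned long* d_i) {
  (void)v;
  (void)d_i;
  d_v->data[i] += d_y;
}

// Gradient of f(v) = 2 * at(v, 0) with respect to v.data[0].
// `forw_index_adjoint` is what gets passed as the by-value adjoint in the
// forward pass; the result should not depend on it.
double grad_first_element(unsigned long forw_index_adjoint) {
  Vec2 v{{3.0, 4.0}};
  Vec2 d_v{{0.0, 0.0}};

  // forward pass
  ValueAndAdjoint<double&, double&> t =
      at_reverse_forw(v, 0, &d_v, forw_index_adjoint);
  double result = 2.0 * t.value;
  assert(result == 6.0);

  // reverse pass: the real adjoint variable for the index is declared here
  unsigned long r0 = 0UL;
  at_pullback(v, 0, 2.0, &d_v, &r0);
  return d_v.data[0];
}
```

Calling `grad_first_element(0UL)` and `grad_first_element(12345UL)` gives the same gradient, which is why a zero literal is a safe argument for the forward pass in this sketch.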
PetroZarytskyi committed Jan 4, 2025
1 parent 825396d commit 483b0ce
Showing 3 changed files with 50 additions and 74 deletions.
40 changes: 8 additions & 32 deletions lib/Differentiator/ReverseModeVisitor.cpp
@@ -2009,6 +2009,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
  Expr* call = nullptr;

  QualType returnType = FD->getReturnType();
+ // Stores the dx of the call arguments for the function to be derived
+ for (std::size_t i = 0, e = CE->getNumArgs() - isMethodOperatorCall; i != e;
+ ++i) {
+ const Expr* arg = CE->getArg(i + isMethodOperatorCall);
+ if (!utils::IsReferenceOrPointerArg(arg))
+ CallArgDx[i] = getZeroInit(arg->getType());
+ }
  if (baseDiff.getExpr_dx() &&
  !baseDiff.getExpr_dx()->getType()->isPointerType())
  CallArgDx.insert(CallArgDx.begin(), BuildOp(UnaryOperatorKind::UO_AddrOf,
@@ -2041,38 +2048,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
  assert(calleeFnForwPassFD &&
  "Clad failed to generate callee function forward pass function");

- // FIXME: We are using the derivatives in forward pass here
- // If `expr_dx()` is only meant to be used in reverse pass,
- // (for example, `clad::pop(...)` expression and a corresponding
- // `clad::push(...)` in the forward pass), then this can result in
- // incorrect derivative or crash at runtime. Ideally, we should have
- // a separate routine to use derivative in the forward pass.
-
- // We cannot reuse the derivatives previously computed because
- // they might contain 'clad::pop(..)` expression.
- if (baseDiff.getExpr_dx()) {
- Expr* derivedBase = baseDiff.getExpr_dx();
- // FIXME: We may need this if-block once we support pointers, and
- // passing pointers-by-reference if
- // (isCladArrayType(derivedBase->getType()))
- // CallArgs.push_back(derivedBase);
- // else
- // Currently derivedBase `*d_this` can never be CladArrayType
- CallArgs.push_back(
- BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, Loc));
- }
-
- for (std::size_t i = static_cast<std::size_t>(isMethodOperatorCall),
- e = CE->getNumArgs();
- i != e; ++i) {
- const Expr* arg = CE->getArg(i);
- StmtDiff argDiff = Visit(arg);
- // Has to be removed once nondifferentiable arguments are handeled
- if (argDiff.getStmt_dx())
- CallArgs.push_back(argDiff.getExpr_dx());
- else
- CallArgs.push_back(getZeroInit(arg->getType()));
- }
+ CallArgs.insert(CallArgs.end(), CallArgDx.begin(), CallArgDx.end());
  if (Expr* baseE = baseDiff.getExpr()) {
  call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
  CallArgs, Loc);
4 changes: 2 additions & 2 deletions test/Gradient/MemberFunctions.C
@@ -435,7 +435,7 @@ double fn2(SimpleFunctions& sf, double i) {

  // CHECK: void fn2_grad(SimpleFunctions &sf, double i, SimpleFunctions *_d_sf, double *_d_i) {
  // CHECK-NEXT: SimpleFunctions _t0 = sf;
- // CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), *_d_i);
+ // CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = _t0.ref_mem_fn_forw(i, &(*_d_sf), 0.);
  // CHECK-NEXT: {
  // CHECK-NEXT: double _r0 = 0.;
  // CHECK-NEXT: _t0.ref_mem_fn_pullback(i, 1, &(*_d_sf), &_r0);
@@ -459,7 +459,7 @@ double fn5(SimpleFunctions& v, double value) {

  // CHECK: void fn5_grad(SimpleFunctions &v, double value, SimpleFunctions *_d_v, double *_d_value) {
  // CHECK-NEXT: SimpleFunctions _t0 = v;
- // CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), *_d_value);
+ // CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t1 = _t0.operator_plus_equal_forw(value, &(*_d_v), 0.);
  // CHECK-NEXT: (*_d_v).x += 1;
  // CHECK-NEXT: {
  // CHECK-NEXT: double _r0 = 0.;