Skip to content

Commit

Permalink
fix DifferentiateVarDecl for constructors in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Sep 21, 2023
1 parent 4bc0d63 commit 3acc860
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
12 changes: 9 additions & 3 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2440,11 +2440,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType) {
Expr* derivedE = BuildDeclRef(VDDerived);
initDiff = VD->getInit() ? Visit(VD->getInit(), derivedE) : StmtDiff{};
initDiff = StmtDiff{};
if (VD->getInit()) {
if (isa<CXXConstructExpr>(VD->getInit()))
initDiff = Visit(VD->getInit());
else
initDiff = Visit(VD->getInit(), derivedE);
}

// If we are differentiating `VarDecl` corresponding to a local variable
// inside a loop, then we need to reset it to 0 at each iteration.
//
//
// for example, if defined inside a loop,
// ```
// double localVar = i;
Expand All @@ -2454,7 +2460,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// {
// *_d_i += _d_localVar;
// _d_localVar = 0;
// }
// }
if (isInsideLoop) {
Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign,
BuildDeclRef(VDDerived),
Expand Down
33 changes: 33 additions & 0 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,10 @@ double fn2(SimpleFunctions& sf, double i) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn3(double x, double y, double i, double j) {
SimpleFunctions sf(x, y);
return sf.mem_fn(i, j);
}

int main() {
auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn);
Expand Down Expand Up @@ -880,4 +884,33 @@ int main() {
// CHECK-NEXT: * _d_j += _r3;
// CHECK-NEXT: }
// CHECK-NEXT: }

auto d_fn3 = clad::gradient(fn3, "i,j");
result[0] = result[1] = 0;
d_fn3.execute(2, 3, 4, 5, &result[0], &result[1]);
printf("%.2f %.2f", result[0], result[1]); // CHECK-EXEC: 10.00 4.00

// CHECK: void fn3_grad_2_3(double x, double y, double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _d_x = 0;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: SimpleFunctions _d_sf({});
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: SimpleFunctions _t2;
// CHECK-NEXT: SimpleFunctions sf(x, y);
// CHECK-NEXT: _t0 = i;
// CHECK-NEXT: _t1 = j;
// CHECK-NEXT: _t2 = sf;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: double _grad1 = 0.;
// CHECK-NEXT: _t2.mem_fn_pullback(_t0, _t1, 1, &_d_sf, &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;
// CHECK-NEXT: double _r1 = _grad1;
// CHECK-NEXT: * _d_j += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }
}

0 comments on commit 3acc860

Please sign in to comment.