Skip to content

Commit

Permalink
Differentiate for loop condition expression (#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanjulka19 committed Mar 15, 2024
1 parent 5736df6 commit 728c53e
Show file tree
Hide file tree
Showing 4 changed files with 452 additions and 25 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ namespace clad {
LoopCounter& loopCounter,
clang::Stmt* condVarDifff = nullptr,
clang::Stmt* forLoopIncDiff = nullptr,
clang::Stmt* condDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop/switch
Expand Down
60 changes: 41 additions & 19 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1056,11 +1056,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
if (FS->getCond())
cond = Visit(FS->getCond());
StmtDiff condDiff;
StmtDiff condExprDiff;
StmtDiff condDiffOuter;
StmtDiff condExprDiffOuter;
if (FS->getCond()) {
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
std::tie(condDiffOuter, condExprDiffOuter) =
DifferentiateSingleExpr(FS->getCond());
}

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1088,27 +1100,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// Otherwise, join all exprs by comma operator.
else if (incExprDiff.getExpr()) {
auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};
incResult = std::accumulate(Additional->body_rbegin(),
Additional->body_rend(),
incExprDiff.getExpr(),
CommaJoin);
}

const Stmt* body = FS->getBody();
StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter,
condVarRes.getStmt_dx(),
incDiff.getStmt_dx(),
/*isForLoop=*/true);
StmtDiff BodyDiff =
DifferentiateLoopBody(body, loopCounter, condVarRes.getStmt_dx(),
incDiff.getStmt_dx(), condDiff.getStmt_dx(),
/*isForLoop=*/true);

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1118,8 +1126,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

auto* AdditionalStmts = cast<CompoundStmt>(condDiff.getStmt());
Expr* condResult =
std::accumulate(AdditionalStmts->body_rbegin(),
AdditionalStmts->body_rend(), forwardCond, CommaJoin);

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), condResult, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1141,10 +1154,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
noLoc,
noLoc,
noLoc);
addToCurrentBlock(unwrapIfSingleStmt(condDiffOuter.getStmt()));
addToCurrentBlock(Forward, direction::forward);
Forward = endBlock(direction::forward);
addToCurrentBlock(loopCounter.getPop(), direction::reverse);
addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(condDiffOuter.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();
Expand Down Expand Up @@ -2580,8 +2595,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else {
Expand Down Expand Up @@ -3602,11 +3617,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Stmt* forLoopIncDiff,
bool isForLoop) {
StmtDiff ReverseModeVisitor::DifferentiateLoopBody(
const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff,
Stmt* forLoopIncDiff, Stmt* condDiff, bool isForLoop) {
Expr* counterIncrement = loopCounter.getCounterIncrement();
auto* activeBreakContHandler = PushBreakContStmtHandler();
activeBreakContHandler->BeginCFSwitchStmtScope();
Expand Down Expand Up @@ -3656,6 +3669,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (condDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), condDiff));
} else {
bodyDiff.updateStmtDx(condDiff);
}

Check warning on line 3678 in lib/Differentiator/ReverseModeVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/ReverseModeVisitor.cpp#L3678

Added line #L3678 was not covered by tests
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();
Expand Down
12 changes: 6 additions & 6 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,12 @@ double f21 (double x, double y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = * _d_y;
//CHECK-NEXT: * _d_y -= _r_d0;
//CHECK-NEXT: * _d_y += 0;
//CHECK-NEXT: y--;
//CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: y = _t0;
// CHECK-NEXT: double _r_d0 = * _d_y;
// CHECK-NEXT: * _d_y -= _r_d0;
// CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: * _d_y += 0;
// CHECK-NEXT: y--;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down
Loading

0 comments on commit 728c53e

Please sign in to comment.