Skip to content

Commit

Permalink
Differentiate for loop condition expression (vgvassilev#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanjulka19 committed Jun 5, 2024
1 parent bfe6d65 commit b6cbe7f
Showing 4 changed files with 596 additions and 30 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
@@ -518,6 +518,7 @@ namespace clad {
LoopCounter& loopCounter,
clang::Stmt* condVarDifff = nullptr,
clang::Stmt* forLoopIncDiff = nullptr,
clang::Stmt* condDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop/switch
75 changes: 53 additions & 22 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -985,11 +985,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
if (FS->getCond())
cond = Visit(FS->getCond());
StmtDiff condDiff;
StmtDiff condExprDiff;
StmtDiff condDiffOuter;
StmtDiff condExprDiffOuter;
if (FS->getCond()) {
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
std::tie(condDiffOuter, condExprDiffOuter) =
DifferentiateSingleExpr(FS->getCond());
}

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

@@ -1017,27 +1029,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
// Otherwise, join all exprs by comma operator.
else if (incExprDiff.getExpr()) {
auto CommaJoin = [this](Expr* Acc, Stmt* S) {
Expr* E = cast<Expr>(S);
return BuildOp(BO_Comma, E, BuildParens(Acc));
};
incResult = std::accumulate(Additional->body_rbegin(),
Additional->body_rend(),
incExprDiff.getExpr(),
CommaJoin);
}

const Stmt* body = FS->getBody();
StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter,
condVarRes.getStmt_dx(),
incDiff.getStmt_dx(),
/*isForLoop=*/true);
StmtDiff BodyDiff =
DifferentiateLoopBody(body, loopCounter, condVarRes.getStmt_dx(),
incDiff.getStmt_dx(), condDiff.getStmt_dx(),
/*isForLoop=*/true);

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
@@ -1047,8 +1055,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

auto* AdditionalStmts = cast<CompoundStmt>(condDiff.getStmt());
Expr* condResult =
std::accumulate(AdditionalStmts->body_rbegin(),
AdditionalStmts->body_rend(), forwardCond, CommaJoin);

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), condResult, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
@@ -1061,10 +1074,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
Stmt* Reverse = new (m_Context) ForStmt(m_Context,
nullptr,
CounterCondition,
nullptr,
CounterDecrement,
ReverseResult,
noLoc,
noLoc,
noLoc);
addToCurrentBlock(unwrapIfSingleStmt(condDiffOuter.getStmt()));
addToCurrentBlock(Forward, direction::forward);
Forward = endBlock(direction::forward);
addToCurrentBlock(loopCounter.getPop(), direction::reverse);
addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(condDiffOuter.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();
@@ -2359,8 +2383,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else {
@@ -3348,11 +3372,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Stmt* forLoopIncDiff,
bool isForLoop) {
StmtDiff ReverseModeVisitor::DifferentiateLoopBody(
const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff,
Stmt* forLoopIncDiff, Stmt* condDiff, bool isForLoop) {
Expr* counterIncrement = loopCounter.getCounterIncrement();
auto* activeBreakContHandler = PushBreakContStmtHandler();
activeBreakContHandler->BeginCFSwitchStmtScope();
@@ -3402,6 +3424,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (condDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), condDiff));
} else {
bodyDiff.updateStmtDx(condDiff);
}
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();
12 changes: 6 additions & 6 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
@@ -858,12 +858,12 @@ double f21 (double x, double y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y -= _r_d0;
//CHECK-NEXT: *_d_y += 0;
//CHECK-NEXT: y--;
//CHECK-NEXT: *_d_x += _r_d0;
// CHECK-NEXT: y = _t0;
// CHECK-NEXT: double _r_d0 = * _d_y;
// CHECK-NEXT: * _d_y -= _r_d0;
// CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: * _d_y += 0;
// CHECK-NEXT: y--;
//CHECK-NEXT: }
//CHECK-NEXT: }

Loading

0 comments on commit b6cbe7f

Please sign in to comment.