Skip to content

Commit

Permalink
Differentiate for loop conditions in reverse mode (vgvassilev#746)
Browse files Browse the repository at this point in the history
Currently loop conditions are not differentiated in reverse mode.
This change differentiates the loop condition expression.
  • Loading branch information
rohanjulka19 committed Jul 2, 2024
1 parent 645d2b6 commit c5407a1
Show file tree
Hide file tree
Showing 14 changed files with 1,601 additions and 603 deletions.
9 changes: 6 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ namespace clad {
/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
clang::Expr* m_CurrentBreakFlagExpr;

unsigned outputArrayCursor = 0;
unsigned numParams = 0;
// FIXME: Should we make this an object instead of a pointer?
Expand Down Expand Up @@ -561,9 +563,6 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);

/// Initialise the `m_ControlFlowTape`.
Expand All @@ -576,6 +575,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
const bool m_IsInvokedBySwitchStmt = false;

BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand All @@ -598,6 +599,8 @@ namespace clad {
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();

clang::Expr* CreateCFTapeBackExprForCurrentCase();

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
116 changes: 98 additions & 18 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt(), noLoc,
Expand Down Expand Up @@ -992,6 +991,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
const Stmt* init = FS->getInit();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt();
Expand All @@ -1000,7 +1002,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Save the isInsideLoop value (we may be inside another loop).
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;

StmtDiff condVarRes;
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
Expand All @@ -1011,11 +1012,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
if (FS->getCond())
cond = Visit(FS->getCond());
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond()) {
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
}

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1054,16 +1057,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

const Stmt* body = FS->getBody();
StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter,
condVarRes.getStmt_dx(),
incDiff.getStmt_dx(),
/*isForLoop=*/true);
StmtDiff BodyDiff = DifferentiateLoopBody(
body, loopCounter, condVarRes.getStmt_dx(), incDiff.getStmt_dx(),
/*isForLoop=*/true);

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1073,8 +1075,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();

beginBlock();
if (utils::unwrapIfSingleStmt(condDiff.getStmt())) {
addToCurrentBlock(condDiff.getStmt());
}

Stmt* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/
BuildOp(clang::UnaryOperatorKind::UO_LNot, BuildParens(forwardCond)),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/breakStmt,
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(IfStmt);

Stmt* forwardCondStmts = endBlock();
if (BodyDiff.getStmt()) {
BodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt(), forwardCondStmts));
} else {
BodyDiff.updateStmt(utils::unwrapIfSingleStmt(forwardCondStmts));
}

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), nullptr, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1084,12 +1112,44 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CounterCondition = loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
beginBlock(direction::reverse);
Stmt* RevIfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot, CounterCondition),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/Clone(breakStmt),
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);

if (condDiff.getStmt_dx()) {
if (m_CurrentBreakFlagExpr) {
Expr* loopBreakFlagCond =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(UnaryOperatorKind::UO_LNot, CounterCondition),
m_CurrentBreakFlagExpr);
auto* RevIfStmt = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, loopBreakFlagCond, noLoc,
noLoc, condDiff.getStmt_dx(), noLoc, nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);
} else {
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
}
}

Stmt* revPassCondStmts = endBlock(direction::reverse);
if (BodyDiff.getStmt_dx()) {
BodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt_dx(), revPassCondStmts));
} else {
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2391,8 +2451,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
Expand Down Expand Up @@ -3546,6 +3606,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) {
Expr* tapeBackExprForCurrentCase =
activeBreakContHandler->CreateCFTapeBackExprForCurrentCase();
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
addToCurrentBlock(newBS);
return {endBlock(direction::forward), CFCaseStmt};
Expand Down Expand Up @@ -3607,6 +3679,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CS;
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeBackExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Last(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, m_CaseCounter));
}

Stmt* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapePushExprToCurrentCase() {
if (!m_ControlFlowTape)
Expand Down
Loading

0 comments on commit c5407a1

Please sign in to comment.