Skip to content

Commit

Permalink
Differentiate for loop conditions in reverse mode (#746)
Browse files Browse the repository at this point in the history
This change differentiates the loop condition expression.
Additionaly if in forward pass a loop terminates pre-maturely due to break stmt.
The reverse pass should start differentiation with break statment and
not the loop condition differentiation. This change keeps track of whether the break
was called in forward pass and based on that in reverse mode it is decided
whether the loop differentiation is skipped for the first iteration or not.
  • Loading branch information
rohanjulka19 committed Jul 13, 2024
1 parent ef668c7 commit 1f4a7e6
Show file tree
Hide file tree
Showing 14 changed files with 1,721 additions and 509 deletions.
13 changes: 12 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ namespace clad {
/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
/// This expression checks if the forward pass loop was terminted due to
/// break. It is used to determine whether to run the loop cond
/// differentiation. One additional time.
clang::Expr* m_CurrentBreakFlagExpr;

unsigned outputArrayCursor = 0;
unsigned numParams = 0;
// FIXME: Should we make this an object instead of a pointer?
Expand Down Expand Up @@ -561,7 +566,6 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);
Expand All @@ -576,6 +580,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
bool m_IsInvokedBySwitchStmt = false;

BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand All @@ -598,6 +604,11 @@ namespace clad {
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();

/// Builds and return `clad::back(TapeRef) != m_CaseCounter`
/// expression, where `TapeRef` and `m_CaseCounter` are replaced
/// by their actual values respectively
clang::Expr* CreateCFTapeBackExprForCurrentCase();

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
109 changes: 96 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt(), noLoc,
Expand Down Expand Up @@ -992,6 +991,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
const Stmt* init = FS->getInit();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt();
Expand All @@ -1000,7 +1002,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Save the isInsideLoop value (we may be inside another loop).
llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;

StmtDiff condVarRes;
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
Expand All @@ -1011,11 +1012,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond())
cond = Visit(FS->getCond());
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());

const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1063,7 +1065,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand All @@ -1073,8 +1075,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (condVarRes.getExpr() != nullptr && isa<Expr>(condVarRes.getExpr()))
forwardCond = cast<Expr>(condVarRes.getExpr());

Stmt* breakStmt = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();

/// This part adds the forward pass of loop condition stmt in the body
/// In this first loop condition diff stmts execute then loop condition
/// is checked if and loop is terminated.
beginBlock();
if (utils::unwrapIfSingleStmt(condDiff.getStmt()))
addToCurrentBlock(condDiff.getStmt());

Stmt* IfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/
BuildOp(clang::UnaryOperatorKind::UO_LNot, BuildParens(forwardCond)),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/breakStmt,
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(IfStmt);

Stmt* forwardCondStmts = endBlock();
if (BodyDiff.getStmt()) {
BodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt(), forwardCondStmts));
} else {
BodyDiff.updateStmt(utils::unwrapIfSingleStmt(forwardCondStmts));
}

Stmt* Forward = new (m_Context)
ForStmt(m_Context, initResult.getStmt(), forwardCond, condVarClone,
ForStmt(m_Context, initResult.getStmt(), nullptr, condVarClone,
incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc);

// Create a condition testing counter for being zero, and its decrement.
Expand All @@ -1084,12 +1114,45 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CounterCondition = loopCounter.getCounterConditionResult().get().second;
Expr* CounterDecrement = loopCounter.getCounterDecrement();

Stmt* ReverseResult = BodyDiff.getStmt_dx();
if (!ReverseResult)
ReverseResult = new (m_Context) NullStmt(noLoc);
/// This part adds the reverse pass of loop condition stmt in the body
beginBlock(direction::reverse);
Stmt* RevIfStmt = clad_compat::IfStmt_Create(
/*Ctx=*/m_Context, /*IL=*/noLoc, /*IsConstexpr=*/false,
/*Init=*/nullptr, /*Var=*/nullptr,
/*Cond=*/BuildOp(clang::UnaryOperatorKind::UO_LNot, CounterCondition),
/*LPL=*/noLoc, /*RPL=*/noLoc,
/*Then=*/Clone(breakStmt),
/*EL=*/noLoc,
/*Else=*/nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);

if (condDiff.getStmt_dx()) {
if (m_CurrentBreakFlagExpr) {
Expr* loopBreakFlagCond =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(UnaryOperatorKind::UO_LNot, CounterCondition),
m_CurrentBreakFlagExpr);
auto* RevIfStmt = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, loopBreakFlagCond, noLoc,
noLoc, condDiff.getStmt_dx(), noLoc, nullptr);
addToCurrentBlock(RevIfStmt, direction::reverse);
} else {
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
}
}

Stmt* revPassCondStmts = endBlock(direction::reverse);
if (BodyDiff.getStmt_dx()) {
BodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, BodyDiff.getStmt_dx(), revPassCondStmts));
} else {
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement,
ReverseResult, noLoc, noLoc, noLoc);
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
Expand Down Expand Up @@ -2391,8 +2454,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_Comma) {
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
Ldiff = Visit(L, zero);
valueForRevPass = Ldiff.getRevSweepAsExpr();
ResultRef = Ldiff.getExpr();
} else if (opCode == BO_LAnd) {
Expand Down Expand Up @@ -3546,6 +3609,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
if (isInsideLoop && !activeBreakContHandler->m_IsInvokedBySwitchStmt) {
Expr* tapeBackExprForCurrentCase =
activeBreakContHandler->CreateCFTapeBackExprForCurrentCase();
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);

} else {
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
addToCurrentBlock(newBS);
return {endBlock(direction::forward), CFCaseStmt};
Expand Down Expand Up @@ -3607,6 +3682,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CS;
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeBackExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Last(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, m_CaseCounter));
}

Stmt* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapePushExprToCurrentCase() {
if (!m_ControlFlowTape)
Expand Down
Loading

0 comments on commit 1f4a7e6

Please sign in to comment.