Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign of loop's body in reverse pass #835

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/build
/.vscode
/inst
27 changes: 19 additions & 8 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ namespace clad {
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool hasContStmt = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'hasContStmt' has protected visibility [cppcoreguidelines-non-private-member-variables-in-classes]

    bool hasContStmt = false;
         ^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: invalid case style for protected member 'hasContStmt' [readability-identifier-naming]

Suggested change
bool hasContStmt = false;
bool m_hasContStmt = false;

/// Output variable of vector-valued function
std::string outputArrayStr;
std::vector<Stmts> m_LoopBlock;
Expand Down Expand Up @@ -501,6 +503,15 @@ namespace clad {
}
};

/// Helper function to bring the cases created by a continue or break stmt
/// foward to the loop's body and append them correctly.
/// The statements that belong to the main body of the loop are added directly
/// to the current block, while the cases followed by with their corresponding stmts
/// are stored in a separate vector.
void AppendCaseStmts(llvm::SmallVectorImpl<clang::Stmt*>& curBlock,
llvm::SmallVectorImpl<clang::Stmt*>& cases, clang::Stmt* S,
bool& afterCase);

/// Helper function to differentiate a loop body.
///
///\param[in] body body of the loop
Expand Down Expand Up @@ -542,12 +553,6 @@ namespace clad {
/// PopBreakContStmtHandler();
/// ```
class BreakContStmtHandler {
/// Keeps track of all the created switch cases. It is required
/// because we need to register all the switch cases later with the
/// switch statement that will be used to manage the control flow in
/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;

/// `m_ControlFlowTape` tape keeps track of which `break`/`continue`
/// statement was hit in which iteration.
/// \note `m_ControlFlowTape` is only initialized if the body contains
Expand All @@ -560,8 +565,6 @@ namespace clad {
/// `break`/`continue` statement.
std::size_t m_CaseCounter = 0;

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
Expand All @@ -577,6 +580,14 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
/// Keeps track of all the created switch cases. It is required
/// because we need to register all the switch cases later with the
/// switch statement that will be used to manage the control flow in
/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_SwitchCases' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
                                               ^


ReverseModeVisitor& m_RMV;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_RMV' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      ReverseModeVisitor& m_RMV;
                          ^


BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

Expand Down
229 changes: 189 additions & 40 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,9 +897,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Forward, Reverse);
};

llvm::SaveAndRestore<bool> SaveHasContStmt(hasContStmt);
hasContStmt = false;
StmtDiff thenDiff = VisitBranch(If->getThen());
llvm::SaveAndRestore<bool> SaveHasContStmtThen(hasContStmt);
hasContStmt = false;
StmtDiff elseDiff = VisitBranch(If->getElse());


// It is problematic to specify both condVarDecl and cond thorugh
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor.
Stmt* Forward = clad_compat::IfStmt_Create(m_Context,
Expand All @@ -920,18 +925,48 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
addToCurrentBlock(Reverse, direction::reverse);

// if neither then nor else block contains a continue statement,
// we can add the reverse block to the current block.
if (!SaveHasContStmtThen.get() && !hasContStmt){
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());

addToCurrentBlock(Reverse, direction::reverse);
}
// if both then and else block contain a continue statement,
// we need to add their cases to the current block.
else if (SaveHasContStmtThen.get() && hasContStmt){
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse);
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse);
}
// if only then block contains a continue statement, we need to add
// the then block to the current block and create an if stmt for the else block
// afterwards to ensure that in the reverse pass it will be included in the prior case
else if (SaveHasContStmtThen.get()) {
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse);
if (elseDiff.getStmt_dx()){
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc,
m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we invert the condition instead of using a null stmt?
e.g.

 if(!cond)
   *else stmt*

instead of

 if(cond)
   ;
 else
   *else stmt*

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this approach before but the way I believe this would turn out is checking for each binary condition operator (<=, >=, <, >, ==, !=) and replacing it accordingly. I will look into it more though, thanks for the comment.

addToCurrentBlock(Reverse, direction::reverse);
}
}
// if only else block contains a continue statement, we need to add
// the else block to the current block and create an if stmt for the then block
// afterwards to ensure that in the reverse pass it will be included in the prior case
else if (hasContStmt) {
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse);
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
nullptr);
addToCurrentBlock(Reverse, direction::reverse);
}

CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
endScope();
Expand Down Expand Up @@ -3487,6 +3522,88 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

void ReverseModeVisitor::AppendCaseStmts(llvm::SmallVectorImpl<Stmt*>& curBlock,
llvm::SmallVectorImpl<Stmt*>& cases,
Stmt* S, bool& afterCase) {
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto CS' can be declared as 'auto *CS' [llvm-qualified-auto]

Suggested change
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) {
if (auto *CS = dyn_cast_or_null<CompoundStmt>(S)) {

// create a new list to store the nested stmts
Stmts newBlock;
// This stmts is a compound and not a case
// so its nested stmts do not come immediately after a case.
// The whole compound though may belong to a case stmt,
// hence, we store the original flag's value
SaveAndRestore<bool> SaveAfterCase(afterCase);
afterCase = false;
for (auto stmt : CS->body())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]

Suggested change
for (auto stmt : CS->body())
for (auto *stmt : CS->body())

AppendCaseStmts(newBlock, cases, stmt, afterCase);
if (!newBlock.empty()){
auto Stmts_ref = clad_compat::makeArrayRef(newBlock.data(), newBlock.size());
auto newCS = clad_compat::CompoundStmt_Create(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto newCS' can be declared as 'auto *newCS' [llvm-qualified-auto]

Suggested change
auto newCS = clad_compat::CompoundStmt_Create(
auto *newCS = clad_compat::CompoundStmt_Create(

m_Context, Stmts_ref /**/
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(
FPOptionsOverride()),
noLoc, noLoc);
// if the compound belongs to a case, add it to the `cases` vector
// else add it to the main body of the for loop
if (SaveAfterCase.get())
cases.push_back(newCS);
else{
curBlock.push_back(newCS);
}
}
} else if (isa<CaseStmt>(S)) {
afterCase = true;
cases.push_back(S);
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto If' can be declared as 'auto *If' [llvm-qualified-auto]

Suggested change
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) {
} else if (auto *If = dyn_cast_or_null<IfStmt>(S)) {

if (auto IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto IfThenCS' can be declared as 'auto *IfThenCS' [llvm-qualified-auto]

Suggested change
if (auto IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {
if (auto *IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) {

Stmts thenBlock;
SaveAndRestore<bool> SaveAfterCase(afterCase);
afterCase = false;
for (auto stmt : IfThenCS->body())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]

Suggested change
for (auto stmt : IfThenCS->body())
for (auto *stmt : IfThenCS->body())

AppendCaseStmts(thenBlock, cases, stmt, afterCase);
auto Stmts_ref =
clad_compat::makeArrayRef(thenBlock.data(),
thenBlock.size());
auto newThenCS = clad_compat::CompoundStmt_Create(
m_Context,
Stmts_ref /**/
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(
FPOptionsOverride()),
noLoc, noLoc);
If->setThen(newThenCS);
}
if (auto IfElseCS = dyn_cast_or_null<CompoundStmt>(If->getElse())) {
Stmts elseBlock;
SaveAndRestore<bool> SaveAfterCase(afterCase);
afterCase = false;
for (auto stmt : IfElseCS->body())
AppendCaseStmts(elseBlock, cases, stmt, afterCase);
auto Stmts_ref =
clad_compat::makeArrayRef(elseBlock.data(),
elseBlock.size());
auto newElseCS = clad_compat::CompoundStmt_Create(
m_Context,
Stmts_ref /**/
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(
FPOptionsOverride()),
noLoc, noLoc);
If->setElse(newElseCS);
}
if (afterCase)
cases.push_back(If);
else
curBlock.push_back(If);
} else if (S) {
if (afterCase)
cases.push_back(S);
else
curBlock.push_back(S);
}
// No need to check fo other stmts that have a body,
// as while and for loops as well as do stmts have their own switch.
// Functions and class objects are also independent.
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Expand Down Expand Up @@ -3523,15 +3640,69 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// for forward-pass loop statement body
endScope();
}

Stmts revLoopBlock = m_LoopBlock.back();
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx());
if (!revLoopBlock.empty())

if (!activeBreakContHandler->m_SwitchCases.empty()) {
// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
// `break`/`continue` statements are hit.
auto* lastSC = activeBreakContHandler->GetNextCFCaseStmt();
auto* pushExprToCurrentCase =
activeBreakContHandler->CreateCFTapePushExprToCurrentCase();

Stmt* forwBlock = nullptr;
forwBlock = utils::AppendAndCreateCompoundStmt(
activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(),
pushExprToCurrentCase);
bodyDiff.updateStmt(forwBlock);

bool afterCase = false;
Stmts cases;
AppendCaseStmts(revLoopBlock, cases, bodyDiff.getStmt_dx(), afterCase);
revLoopBlock.append(cases.begin(), cases.end());
revLoopBlock.insert(revLoopBlock.begin(), lastSC);
Stmts revLoopBlockIndexed; // stores the correctly indexed version of the loop's body
bool betweenCase = false;
Stmts curBlockStmts;

// Add the Stmts between cases as SubStmt of the first CaseStmt
if (!revLoopBlock.empty()) {
CaseStmt* curCaseStmt = nullptr;
for (auto revLoopStmt : revLoopBlock) {
if (auto caseStmt = dyn_cast_or_null<CaseStmt>(revLoopStmt)) {
if (!betweenCase) {
betweenCase = true;
} else {
curBlockStmts.push_back(new (m_Context)
BreakStmt(Stmt::EmptyShell())); // compatible with all clang versions
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts));
curBlockStmts.clear();
}
curCaseStmt = caseStmt;
revLoopBlockIndexed.push_back(caseStmt);
} else {
curBlockStmts.push_back(revLoopStmt);
}
}
curBlockStmts.push_back(new (m_Context) BreakStmt(Stmt::EmptyShell()));
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts));
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlockIndexed));
}
}
else{
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx());
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock));
}
m_LoopBlock.pop_back();

// Increment statement in the for-loop is only executed if the iteration
// did not end with a break/continue statement. Therefore, forLoopIncDiff
// should be inside the last switch case in the reverse pass.
activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

// Increment statement in the for-loop should be executed in the beginning for
// every case, hence it should be added prior to the switch statement.
if (forLoopIncDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
Expand All @@ -3541,10 +3712,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

Expr* counterDecrement = loopCounter.getCounterDecrement();

// Create reverse pass loop body statements by arranging various
Expand All @@ -3568,6 +3735,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) {
hasContStmt = true;
beginBlock(direction::forward);
Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get();
auto* activeBreakContHandler = GetActiveBreakContStmtHandler();
Expand Down Expand Up @@ -3659,25 +3827,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt)
return;

// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
// `break`/`continue` statements are hit.
auto* lastSC = GetNextCFCaseStmt();
auto* pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase();

Stmt* forwBlock = nullptr;
Stmt* revBlock = nullptr;

forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context,
bodyDiff.getStmt(),
pushExprToCurrentCase);
revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context,
bodyDiff.getStmt_dx(),
lastSC);

bodyDiff = {forwBlock, revBlock};

auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(),
noLoc, m_ControlFlowTape->Pop,
Sema::ConditionKind::Switch);
Expand Down
Loading