-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Redesign of loop's body in reverse pass #835
base: master
Are you sure you want to change the base?
Changes from all commits
06c9a9f
88ef7e8
a825e10
b00e030
215e7f7
bdd422a
7185b6e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
/build | ||
/.vscode | ||
/inst |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -58,6 +58,8 @@ namespace clad { | |||||
std::set<clang::SourceLocation> m_ToBeRecorded; | ||||||
/// A flag indicating if the Stmt we are currently visiting is inside loop. | ||||||
bool isInsideLoop = false; | ||||||
/// A flag indicating if the Stmt we are currently visiting is inside loop. | ||||||
bool hasContStmt = false; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: invalid case style for protected member 'hasContStmt' [readability-identifier-naming]
Suggested change
|
||||||
/// Output variable of vector-valued function | ||||||
std::string outputArrayStr; | ||||||
std::vector<Stmts> m_LoopBlock; | ||||||
|
@@ -501,6 +503,15 @@ namespace clad { | |||||
} | ||||||
}; | ||||||
|
||||||
/// Helper function to bring the cases created by a continue or break stmt | ||||||
/// foward to the loop's body and append them correctly. | ||||||
/// The statements that belong to the main body of the loop are added directly | ||||||
/// to the current block, while the cases followed by with their corresponding stmts | ||||||
/// are stored in a separate vector. | ||||||
void AppendCaseStmts(llvm::SmallVectorImpl<clang::Stmt*>& curBlock, | ||||||
llvm::SmallVectorImpl<clang::Stmt*>& cases, clang::Stmt* S, | ||||||
bool& afterCase); | ||||||
|
||||||
/// Helper function to differentiate a loop body. | ||||||
/// | ||||||
///\param[in] body body of the loop | ||||||
|
@@ -542,12 +553,6 @@ namespace clad { | |||||
/// PopBreakContStmtHandler(); | ||||||
/// ``` | ||||||
class BreakContStmtHandler { | ||||||
/// Keeps track of all the created switch cases. It is required | ||||||
/// because we need to register all the switch cases later with the | ||||||
/// switch statement that will be used to manage the control flow in | ||||||
/// the reverse block. | ||||||
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases; | ||||||
|
||||||
/// `m_ControlFlowTape` tape keeps track of which `break`/`continue` | ||||||
/// statement was hit in which iteration. | ||||||
/// \note `m_ControlFlowTape` is only initialized if the body contains | ||||||
|
@@ -560,8 +565,6 @@ namespace clad { | |||||
/// `break`/`continue` statement. | ||||||
std::size_t m_CaseCounter = 0; | ||||||
|
||||||
ReverseModeVisitor& m_RMV; | ||||||
|
||||||
const bool m_IsInvokedBySwitchStmt = false; | ||||||
/// Builds and returns a literal expression of type `std::size_t` with | ||||||
/// `value` as value. | ||||||
|
@@ -577,6 +580,14 @@ namespace clad { | |||||
clang::Expr* CreateCFTapePushExpr(std::size_t value); | ||||||
|
||||||
public: | ||||||
/// Keeps track of all the created switch cases. It is required | ||||||
/// because we need to register all the switch cases later with the | ||||||
/// switch statement that will be used to manage the control flow in | ||||||
/// the reverse block. | ||||||
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: member variable 'm_SwitchCases' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes] llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
^ |
||||||
|
||||||
ReverseModeVisitor& m_RMV; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: member variable 'm_RMV' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes] ReverseModeVisitor& m_RMV;
^ |
||||||
|
||||||
BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false) | ||||||
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {} | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -897,9 +897,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
return StmtDiff(Forward, Reverse); | ||||||
}; | ||||||
|
||||||
llvm::SaveAndRestore<bool> SaveHasContStmt(hasContStmt); | ||||||
hasContStmt = false; | ||||||
StmtDiff thenDiff = VisitBranch(If->getThen()); | ||||||
llvm::SaveAndRestore<bool> SaveHasContStmtThen(hasContStmt); | ||||||
hasContStmt = false; | ||||||
StmtDiff elseDiff = VisitBranch(If->getElse()); | ||||||
|
||||||
|
||||||
// It is problematic to specify both condVarDecl and cond thorugh | ||||||
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. | ||||||
Stmt* Forward = clad_compat::IfStmt_Create(m_Context, | ||||||
|
@@ -920,18 +925,48 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
addToCurrentBlock(PushCond, direction::forward); | ||||||
reverseCond = PopCond; | ||||||
} | ||||||
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context, | ||||||
noLoc, | ||||||
If->isConstexpr(), | ||||||
initResult.getStmt_dx(), | ||||||
condVarClone, | ||||||
reverseCond, | ||||||
noLoc, | ||||||
noLoc, | ||||||
thenDiff.getStmt_dx(), | ||||||
noLoc, | ||||||
elseDiff.getStmt_dx()); | ||||||
addToCurrentBlock(Reverse, direction::reverse); | ||||||
|
||||||
// if neither then nor else block contains a continue statement, | ||||||
// we can add the reverse block to the current block. | ||||||
if (!SaveHasContStmtThen.get() && !hasContStmt){ | ||||||
Stmt* Reverse = clad_compat::IfStmt_Create( | ||||||
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), | ||||||
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, | ||||||
elseDiff.getStmt_dx()); | ||||||
|
||||||
addToCurrentBlock(Reverse, direction::reverse); | ||||||
} | ||||||
// if both then and else block contain a continue statement, | ||||||
// we need to add their cases to the current block. | ||||||
else if (SaveHasContStmtThen.get() && hasContStmt){ | ||||||
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); | ||||||
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); | ||||||
} | ||||||
// if only then block contains a continue statement, we need to add | ||||||
// the then block to the current block and create an if stmt for the else block | ||||||
// afterwards to ensure that in the reverse pass it will be included in the prior case | ||||||
else if (SaveHasContStmtThen.get()) { | ||||||
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); | ||||||
if (elseDiff.getStmt_dx()){ | ||||||
Stmt* Reverse = clad_compat::IfStmt_Create( | ||||||
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), | ||||||
condVarClone, reverseCond, noLoc, noLoc, | ||||||
m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx()); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we invert the condition instead of using a null stmt?
instead of
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about this approach before but the way I believe this would turn out is checking for each binary condition operator (<=, >=, <, >, ==, !=) and replacing it accordingly. I will look into it more though, thanks for the comment. |
||||||
addToCurrentBlock(Reverse, direction::reverse); | ||||||
} | ||||||
} | ||||||
// if only else block contains a continue statement, we need to add | ||||||
// the else block to the current block and create an if stmt for the then block | ||||||
// afterwards to ensure that in the reverse pass it will be included in the prior case | ||||||
else if (hasContStmt) { | ||||||
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); | ||||||
Stmt* Reverse = clad_compat::IfStmt_Create( | ||||||
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), | ||||||
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, | ||||||
nullptr); | ||||||
addToCurrentBlock(Reverse, direction::reverse); | ||||||
} | ||||||
|
||||||
CompoundStmt* ForwardBlock = endBlock(direction::forward); | ||||||
CompoundStmt* ReverseBlock = endBlock(direction::reverse); | ||||||
endScope(); | ||||||
|
@@ -3487,6 +3522,88 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
return {endBlock(direction::forward), endBlock(direction::reverse)}; | ||||||
} | ||||||
|
||||||
void ReverseModeVisitor::AppendCaseStmts(llvm::SmallVectorImpl<Stmt*>& curBlock, | ||||||
llvm::SmallVectorImpl<Stmt*>& cases, | ||||||
Stmt* S, bool& afterCase) { | ||||||
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto CS' can be declared as 'auto *CS' [llvm-qualified-auto]
Suggested change
|
||||||
// create a new list to store the nested stmts | ||||||
Stmts newBlock; | ||||||
// This stmts is a compound and not a case | ||||||
// so its nested stmts do not come immediately after a case. | ||||||
// The whole compound though may belong to a case stmt, | ||||||
// hence, we store the original flag's value | ||||||
SaveAndRestore<bool> SaveAfterCase(afterCase); | ||||||
afterCase = false; | ||||||
for (auto stmt : CS->body()) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]
Suggested change
|
||||||
AppendCaseStmts(newBlock, cases, stmt, afterCase); | ||||||
if (!newBlock.empty()){ | ||||||
auto Stmts_ref = clad_compat::makeArrayRef(newBlock.data(), newBlock.size()); | ||||||
auto newCS = clad_compat::CompoundStmt_Create( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto newCS' can be declared as 'auto *newCS' [llvm-qualified-auto]
Suggested change
|
||||||
m_Context, Stmts_ref /**/ | ||||||
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( | ||||||
FPOptionsOverride()), | ||||||
noLoc, noLoc); | ||||||
// if the compound belongs to a case, add it to the `cases` vector | ||||||
// else add it to the main body of the for loop | ||||||
if (SaveAfterCase.get()) | ||||||
cases.push_back(newCS); | ||||||
else{ | ||||||
curBlock.push_back(newCS); | ||||||
} | ||||||
} | ||||||
} else if (isa<CaseStmt>(S)) { | ||||||
afterCase = true; | ||||||
cases.push_back(S); | ||||||
} else if (auto If = dyn_cast_or_null<IfStmt>(S)) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto If' can be declared as 'auto *If' [llvm-qualified-auto]
Suggested change
|
||||||
if (auto IfThenCS = dyn_cast_or_null<CompoundStmt>(If->getThen())) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto IfThenCS' can be declared as 'auto *IfThenCS' [llvm-qualified-auto]
Suggested change
|
||||||
Stmts thenBlock; | ||||||
SaveAndRestore<bool> SaveAfterCase(afterCase); | ||||||
afterCase = false; | ||||||
for (auto stmt : IfThenCS->body()) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto stmt' can be declared as 'auto *stmt' [llvm-qualified-auto]
Suggested change
|
||||||
AppendCaseStmts(thenBlock, cases, stmt, afterCase); | ||||||
auto Stmts_ref = | ||||||
clad_compat::makeArrayRef(thenBlock.data(), | ||||||
thenBlock.size()); | ||||||
auto newThenCS = clad_compat::CompoundStmt_Create( | ||||||
m_Context, | ||||||
Stmts_ref /**/ | ||||||
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( | ||||||
FPOptionsOverride()), | ||||||
noLoc, noLoc); | ||||||
If->setThen(newThenCS); | ||||||
} | ||||||
if (auto IfElseCS = dyn_cast_or_null<CompoundStmt>(If->getElse())) { | ||||||
Stmts elseBlock; | ||||||
SaveAndRestore<bool> SaveAfterCase(afterCase); | ||||||
afterCase = false; | ||||||
for (auto stmt : IfElseCS->body()) | ||||||
AppendCaseStmts(elseBlock, cases, stmt, afterCase); | ||||||
auto Stmts_ref = | ||||||
clad_compat::makeArrayRef(elseBlock.data(), | ||||||
elseBlock.size()); | ||||||
auto newElseCS = clad_compat::CompoundStmt_Create( | ||||||
m_Context, | ||||||
Stmts_ref /**/ | ||||||
CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( | ||||||
FPOptionsOverride()), | ||||||
noLoc, noLoc); | ||||||
If->setElse(newElseCS); | ||||||
} | ||||||
if (afterCase) | ||||||
cases.push_back(If); | ||||||
else | ||||||
curBlock.push_back(If); | ||||||
} else if (S) { | ||||||
if (afterCase) | ||||||
cases.push_back(S); | ||||||
else | ||||||
curBlock.push_back(S); | ||||||
} | ||||||
// No need to check fo other stmts that have a body, | ||||||
// as while and for loops as well as do stmts have their own switch. | ||||||
// Functions and class objects are also independent. | ||||||
} | ||||||
|
||||||
StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, | ||||||
LoopCounter& loopCounter, | ||||||
Stmt* condVarDiff, | ||||||
|
@@ -3523,15 +3640,69 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
// for forward-pass loop statement body | ||||||
endScope(); | ||||||
} | ||||||
|
||||||
Stmts revLoopBlock = m_LoopBlock.back(); | ||||||
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); | ||||||
if (!revLoopBlock.empty()) | ||||||
|
||||||
if (!activeBreakContHandler->m_SwitchCases.empty()) { | ||||||
// Add case statement in the beginning of the reverse block | ||||||
// and corresponding push expression for this case statement | ||||||
// at the end of the forward block to cover the case when no | ||||||
// `break`/`continue` statements are hit. | ||||||
auto* lastSC = activeBreakContHandler->GetNextCFCaseStmt(); | ||||||
auto* pushExprToCurrentCase = | ||||||
activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); | ||||||
|
||||||
Stmt* forwBlock = nullptr; | ||||||
forwBlock = utils::AppendAndCreateCompoundStmt( | ||||||
activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(), | ||||||
pushExprToCurrentCase); | ||||||
bodyDiff.updateStmt(forwBlock); | ||||||
|
||||||
bool afterCase = false; | ||||||
Stmts cases; | ||||||
AppendCaseStmts(revLoopBlock, cases, bodyDiff.getStmt_dx(), afterCase); | ||||||
revLoopBlock.append(cases.begin(), cases.end()); | ||||||
revLoopBlock.insert(revLoopBlock.begin(), lastSC); | ||||||
Stmts revLoopBlockIndexed; // stores the correctly indexed version of the loop's body | ||||||
bool betweenCase = false; | ||||||
Stmts curBlockStmts; | ||||||
|
||||||
// Add the Stmts between cases as SubStmt of the first CaseStmt | ||||||
if (!revLoopBlock.empty()) { | ||||||
CaseStmt* curCaseStmt = nullptr; | ||||||
for (auto revLoopStmt : revLoopBlock) { | ||||||
if (auto caseStmt = dyn_cast_or_null<CaseStmt>(revLoopStmt)) { | ||||||
if (!betweenCase) { | ||||||
betweenCase = true; | ||||||
} else { | ||||||
curBlockStmts.push_back(new (m_Context) | ||||||
BreakStmt(Stmt::EmptyShell())); // compatible with all clang versions | ||||||
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); | ||||||
curBlockStmts.clear(); | ||||||
} | ||||||
curCaseStmt = caseStmt; | ||||||
revLoopBlockIndexed.push_back(caseStmt); | ||||||
} else { | ||||||
curBlockStmts.push_back(revLoopStmt); | ||||||
} | ||||||
} | ||||||
curBlockStmts.push_back(new (m_Context) BreakStmt(Stmt::EmptyShell())); | ||||||
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); | ||||||
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlockIndexed)); | ||||||
} | ||||||
} | ||||||
else{ | ||||||
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); | ||||||
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); | ||||||
} | ||||||
m_LoopBlock.pop_back(); | ||||||
|
||||||
// Increment statement in the for-loop is only executed if the iteration | ||||||
// did not end with a break/continue statement. Therefore, forLoopIncDiff | ||||||
// should be inside the last switch case in the reverse pass. | ||||||
activeBreakContHandler->EndCFSwitchStmtScope(); | ||||||
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); | ||||||
PopBreakContStmtHandler(); | ||||||
|
||||||
// Increment statement in the for-loop should be executed in the beginning for | ||||||
// every case, hence it should be added prior to the switch statement. | ||||||
if (forLoopIncDiff) { | ||||||
if (bodyDiff.getStmt_dx()) { | ||||||
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( | ||||||
|
@@ -3541,10 +3712,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
} | ||||||
} | ||||||
|
||||||
activeBreakContHandler->EndCFSwitchStmtScope(); | ||||||
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); | ||||||
PopBreakContStmtHandler(); | ||||||
|
||||||
Expr* counterDecrement = loopCounter.getCounterDecrement(); | ||||||
|
||||||
// Create reverse pass loop body statements by arranging various | ||||||
|
@@ -3568,6 +3735,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
} | ||||||
|
||||||
StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { | ||||||
hasContStmt = true; | ||||||
beginBlock(direction::forward); | ||||||
Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); | ||||||
auto* activeBreakContHandler = GetActiveBreakContStmtHandler(); | ||||||
|
@@ -3659,25 +3827,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, | |||||
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) | ||||||
return; | ||||||
|
||||||
// Add case statement in the beginning of the reverse block | ||||||
// and corresponding push expression for this case statement | ||||||
// at the end of the forward block to cover the case when no | ||||||
// `break`/`continue` statements are hit. | ||||||
auto* lastSC = GetNextCFCaseStmt(); | ||||||
auto* pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase(); | ||||||
|
||||||
Stmt* forwBlock = nullptr; | ||||||
Stmt* revBlock = nullptr; | ||||||
|
||||||
forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context, | ||||||
bodyDiff.getStmt(), | ||||||
pushExprToCurrentCase); | ||||||
revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context, | ||||||
bodyDiff.getStmt_dx(), | ||||||
lastSC); | ||||||
|
||||||
bodyDiff = {forwBlock, revBlock}; | ||||||
|
||||||
auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(), | ||||||
noLoc, m_ControlFlowTape->Pop, | ||||||
Sema::ConditionKind::Switch); | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: member variable 'hasContStmt' has protected visibility [cppcoreguidelines-non-private-member-variables-in-classes]