Skip to content

Commit

Permalink
Update comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jan 15, 2024
1 parent a726051 commit 44c754b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
12 changes: 7 additions & 5 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ namespace clad {
return VisitorBase::Clone(E);
}

/// Add more comments.
/// Emits declarations in the reverse sweep of variables that were declared
/// locally inside some blocks in the forward sweep. Needs to be called
/// after all of the usages of the variable in the reverse sweep since
/// the declaration should come before the usages when the statement order
/// is reversed.
void EmitRevSweepDecls();

/// Output a statement to the current block. If Stmt is null or is an unused
Expand Down Expand Up @@ -504,15 +508,13 @@ namespace clad {
/// Helper function to differentiate a loop body.
///
///\param[in] body body of the loop
///\param[in] cond condition of the loop
///\param[in] loopCounter associated `LoopCounter` object of the loop.
///\param[in] condVarDiff derived statements of the condition
/// variable, if any.
///\param[in] forLoopIncDiff derived statements of the `for` loop
/// increment statement, if any.
///\param[in] isForLoop should be true if we are differentiating a `for`
/// loop body; otherwise false.
///\returns {forward pass statements, reverse pass statements} for the loop
/// body.
///\returns {differentiated loop body, differentiated loop condition}
std::pair<StmtDiff, StmtDiff> DifferentiateLoopBody(
const clang::Stmt* body,
const clang::Stmt* cond,
Expand Down
26 changes: 12 additions & 14 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff SDiff = DifferentiateSingleStmt(S);
addToCurrentBlock(SDiff.getStmt(), direction::forward);

/// FIXME: Explain this.
/// FIXME: reverse sweep declarations are sometimes emitted when
/// differentiating return/break/continue statements and so wrapping
/// the result in the reverse sweep in a compound statement will
/// limit visibility of the emitted declarations. Therefore, we need
/// to unwrap it.
if ((isa<ReturnStmt>(S) || isa<BreakStmt>(S) || isa<ContinueStmt>(S))
&& isa<CompoundStmt>(SDiff.getStmt_dx())) {
auto* returnCS = cast<CompoundStmt>(SDiff.getStmt_dx());
Expand Down Expand Up @@ -2729,7 +2733,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
m_Variables.emplace(VDClone, derivedVDE);
/// FIXME: Add more comments.
/// If the declaration is not located in the function global scope,
/// we have to redeclare it in the reverse sweep since it won't be
/// visible otherwise.
/// FIXME: Is there a better way to determine non-function-global scope
/// then relying on the size of m_Blocks.
if (m_Blocks.size()>2) {
clang::QualType type = CloneType(VDClone->getType());
if (type->isArrayType()) {
Expand All @@ -2743,6 +2751,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clang::VarDecl* VDRevSweep = BuildVarDecl(type, VD->getIdentifier(), init, false, nullptr,
clang::VarDecl::InitializationStyle::CInit,
/*pushOnScopeChains=*/false);
/// Insert the reverse sweep declarations into m_Locals to emit it
/// to the reverse pass later.
m_Locals.back().emplace(VDClone, VDRevSweep);
}
return VarDeclDiff(VDClone, VDDerived);
Expand Down Expand Up @@ -3338,18 +3348,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock));
m_LoopBlock.pop_back();

// Increment statement in the for-loop is only executed if the iteration
// did not end with a break/continue statement. Therefore, forLoopIncDiff
// should be inside the last switch case in the reverse pass.
// if (forLoopIncDiff) {
// if (bodyDiff.getStmt_dx()) {
// bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
// m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff));
// } else {
// bodyDiff.updateStmtDx(forLoopIncDiff);
// }
// }

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();
Expand Down

0 comments on commit 44c754b

Please sign in to comment.