Skip to content

Commit

Permalink
Add declarations to the reverse sweep.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jan 22, 2024
1 parent d8ef790 commit 1661a22
Show file tree
Hide file tree
Showing 46 changed files with 2,228 additions and 1,413 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActOnEndOfDerivedFnBody() override;
void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override;
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingLoopCondition() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class ExternalRMVSource {
/// branch in `VisitBranch` lambda in
virtual void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() {}

/// This is called just before differentiating loop conditions.
virtual void ActBeforeDifferentiatingLoopCondition() {}

/// This is called just before differentiating init statement of loops.
virtual void ActBeforeDifferentiatingLoopInitStmt() {}

Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopCondition() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
void ActBeforeDifferentiatingSingleStmtLoopBody() override;
void ActAfterProcessingSingleStmtBodyInVisitForLoop() override;
Expand Down
82 changes: 59 additions & 23 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ namespace clad {
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
std::vector<Stmts> m_Reverse;
/// Accumulates local variables for all visited blocks.
std::vector<std::map<clang::VarDecl*, clang::VarDecl*>> m_Locals;
/// Stores all expressions used as placeholders which have to be
/// reset later.
std::set<const clang::Expr*> m_Placeholders;
/// Stack is used to pass the arguments (dfdx) to further nodes
/// in the Visit method.
std::stack<clang::Expr*> m_Stack;
Expand All @@ -51,7 +56,7 @@ namespace clad {
//// A reference to the output parameter of the gradient function.
clang::Expr* m_Result;
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// tells isToBeRecorded whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
Expand Down Expand Up @@ -142,15 +147,18 @@ namespace clad {
}
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
if (d == direction::forward) {
m_Blocks.emplace_back();
else
m_Locals.emplace_back();
} else {
m_Reverse.emplace_back();
}
return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
clang::CompoundStmt* endBlock(direction d = direction::forward) {
if (d == direction::forward) {
EmitReverseSweepDeclarations();
auto* CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
m_Blocks.pop_back();
return CS;
Expand All @@ -164,12 +172,31 @@ namespace clad {

Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) {
auto blk = getCurrentBlock(d);
if (d == direction::forward)
if (d == direction::forward) {
m_Blocks.pop_back();
else
EmitReverseSweepDeclarations();
} else
m_Reverse.pop_back();
return blk;
}

clang::Expr* Clone(const clang::Expr* E) {
// Placeholders should not be cloned since otherwise we will not be able
// to find and replace them later.
if (m_Placeholders.find(E) != m_Placeholders.end())
return const_cast<clang::Expr*>(E);
return VisitorBase::Clone(E);
}

clang::Stmt* Clone(const clang::Stmt* E) { return VisitorBase::Clone(E); }

/// Emits declarations in the reverse sweep of variables that were declared
/// locally inside some blocks in the forward sweep. Needs to be called
/// after all of the usages of the variable in the reverse sweep since
/// the declaration should come before the usages when the statement order
/// is reversed.
void EmitReverseSweepDeclarations(bool endCurBlock = true);

/// Output a statement to the current block. If Stmt is null or is an unused
/// expression, it is not output and false is returned.
bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) {
Expand Down Expand Up @@ -222,6 +249,10 @@ namespace clad {
/// This is the central point for checkpointing.
bool ShouldRecompute(const clang::Expr* E);

/// For a given location, tells if the corresponding expression should
/// be stored in the forward pass and then restored in the reverse pass.
bool isToBeRecorded(clang::SourceLocation loc);

/// Builds a variable declaration and stores it in the function
/// global scope.
///
Expand Down Expand Up @@ -253,8 +284,11 @@ namespace clad {
bool force = false);
StmtDiff BuildPushPop(clang::Expr* E, clang::QualType Type,
llvm::StringRef prefix = "_t", bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff StoreAndRestore(StmtDiff SD, llvm::StringRef prefix = "_t",
clang::SourceLocation loc = {});
StmtDiff StoreAndRestore(StmtDiff SD, clang::QualType Type,
llvm::StringRef prefix = "_t",
clang::SourceLocation loc = {});

//// A type returned by DelayedGlobalStoreAndRef
/// .Result is a reference to the created (yet uninitialized) global
Expand All @@ -266,15 +300,15 @@ namespace clad {
struct DelayedStoreResult {
ReverseModeVisitor& V;
StmtDiff Result;
bool isConstant;
bool isInsideLoop;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
bool pIsInsideLoop, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), isInsideLoop(pIsInsideLoop),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(StmtDiff New);
};

/// Sometimes (e.g. when visiting multiplication/division operator), we
Expand All @@ -286,7 +320,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceNoRecompute = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand All @@ -313,6 +348,9 @@ namespace clad {
CladTapeResult MakeCladTapeFor(clang::Expr* E,
llvm::StringRef prefix = "_t");

CladTapeResult MakeCladTapeFor(clang::Expr* E, clang::QualType type,
llvm::StringRef prefix = "_t");

public:
ReverseModeVisitor(DerivativeBuilder& builder);
virtual ~ReverseModeVisitor();
Expand Down Expand Up @@ -468,20 +506,18 @@ namespace clad {
/// Helper function to differentiate a loop body.
///
///\param[in] body body of the loop
///\param[in] cond condition of the loop
///\param[in] loopCounter associated `LoopCounter` object of the loop.
///\param[in] condVarDiff derived statements of the condition
/// variable, if any.
///\param[in] forLoopIncDiff derived statements of the `for` loop
/// increment statement, if any.
///\param[in] isForLoop should be true if we are differentiating a `for`
/// loop body; otherwise false.
///\returns {forward pass statements, reverse pass statements} for the loop
/// body.
StmtDiff DifferentiateLoopBody(const clang::Stmt* body,
LoopCounter& loopCounter,
clang::Stmt* condVarDifff = nullptr,
clang::Stmt* forLoopIncDiff = nullptr,
bool isForLoop = false);
///\returns {differentiated loop body, differentiated loop condition}
std::pair<StmtDiff, StmtDiff>
DifferentiateLoopBody(const clang::Stmt* body, const clang::Stmt* cond,
LoopCounter& loopCounter,
clang::Stmt* forLoopIncDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
Expand Down
14 changes: 8 additions & 6 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,14 @@ namespace clad {
/// C style initalization.
/// \param[in] TSI The type source information of the variable declaration.
/// \returns The newly built variable declaration.
clang::VarDecl*
BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier,
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl* BuildVarDecl(clang::QualType Type,
clang::IdentifierInfo* Identifier,
clang::Expr* Init = nullptr,
bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit,
bool pushOnScopeChains = true);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ namespace clad {
CompoundStmt* AppendAndCreateCompoundStmt(ASTContext& C, Stmt* initial,
Stmt* S) {
llvm::SmallVector<Stmt*, 16> block;
assert(isa<CompoundStmt>(initial) &&
"initial should be of type `clang::CompoundStmt`");
CompoundStmt* CS = dyn_cast<CompoundStmt>(initial);
if (CS)
block.append(CS->body_begin(), CS->body_end());
else
block.push_back(initial);
block.push_back(S);
auto stmtsRef = clad_compat::makeArrayRef(block.begin(), block.end());
return clad_compat::CompoundStmt_Create(C, stmtsRef /**/ CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam1(CS), noLoc, noLoc);
Expand Down
4 changes: 4 additions & 0 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,10 @@ void ErrorEstimationHandler::ActAfterProcessingStmtInVisitCompoundStmt() {
EmitErrorEstimationStmts(direction::reverse);
}

void ErrorEstimationHandler::ActBeforeDifferentiatingLoopCondition() {
m_ShouldEmit.push(true);
}

void ErrorEstimationHandler::
ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() {
m_ShouldEmit.push(true);
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/MultiplexExternalRMVSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ void MultiplexExternalRMVSource::
}
}

void MultiplexExternalRMVSource::ActBeforeDifferentiatingLoopCondition() {
for (auto* source : m_Sources)
source->ActBeforeDifferentiatingLoopCondition();
}

void MultiplexExternalRMVSource::ActBeforeDifferentiatingLoopInitStmt() {
for (auto source : m_Sources) {
source->ActBeforeDifferentiatingLoopInitStmt();
Expand Down
Loading

0 comments on commit 1661a22

Please sign in to comment.