Skip to content

Commit

Permalink
Use different break cond for each loop and remove the now unnecessary…
Browse files Browse the repository at this point in the history
… tape size check
  • Loading branch information
kchristin22 authored and vgvassilev committed Oct 29, 2024
1 parent 9ffefc1 commit 1bae2cd
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 66 deletions.
5 changes: 0 additions & 5 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
return of.back();
}

/// Return the size of the tape.
template <typename T> CUDA_HOST_DEVICE std::size_t size(tape<T>& of) {
return of.size();
}

/// The purpose of this function is to initialize adjoints
/// (or all of its differentiable fields) with 0.
// FIXME: Add support for objects.
Expand Down
6 changes: 0 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,6 @@ namespace clad {
/// (clad::back(Ref)). Since it is required only rarely, it is built on
/// demand in the method.
clang::Expr* Last();
/// A request to get the size of the tape (clad::size(Ref)).
clang::Expr* Size();
};

/// Make a clad::tape to store variables.
Expand Down Expand Up @@ -666,10 +664,6 @@ namespace clad {
/// by their actual values respectively
clang::Expr* CreateCFTapeBackExprForCurrentCase();

/// Builds and returns `clad::size(TapeRef) != 0` expression,
/// where `TapeRef` is replaced by its actual value
clang::Expr* CreateCFTapeSizeExprForCurrentCase();

/// Does final modifications on forward and reverse blocks
/// so that `break` and `continue` statements are handled
/// accurately.
Expand Down
3 changes: 1 addition & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,9 @@ namespace clad {
clang::TemplateDecl* GetCladTapeDecl();
/// Perform a lookup into clad namespace for an entity with given name.
clang::LookupResult LookupCladTapeMethod(llvm::StringRef name);
/// Perform lookup into clad namespace for push/pop/back/size. Returns
/// Perform lookup into clad namespace for push/pop/back. Returns
/// LookupResult, which is will be resolved later (which is handy since they
/// are templates).
clang::LookupResult& GetCladTapeSize();
clang::LookupResult& GetCladTapePush();
clang::LookupResult& GetCladTapePop();
clang::LookupResult& GetCladTapeBack();
Expand Down
34 changes: 7 additions & 27 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Call;
}

Expr* ReverseModeVisitor::CladTapeResult::Size() {
LookupResult& TapeSize = V.GetCladTapeSize();
CXXScopeSpec CSS;
CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc);
Expr* SizeDRE = V.m_Sema
.BuildDeclarationNameExpr(CSS, TapeSize,
/*AcceptInvalidDecl=*/false)
.get();
Expr* Call =
V.m_Sema.ActOnCallExpr(V.getCurrentScope(), SizeDRE, noLoc, Ref, noLoc)
.get();
return Call;
}

ReverseModeVisitor::CladTapeResult
ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) {
assert(E && "must be provided");
Expand Down Expand Up @@ -3808,6 +3794,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;

Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr);
const VarDecl* condVarDecl = WS->getConditionVariable();
Expand Down Expand Up @@ -3866,6 +3855,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

llvm::SaveAndRestore<bool> SaveIsInsideLoop(isInsideLoop);
isInsideLoop = true;
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;

Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr);

Expand Down Expand Up @@ -4203,11 +4195,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr,
tapeBackExprForCurrentCase);
} else {
Expr* tapeSizeExprForCurrentCase =
activeBreakContHandler->CreateCFTapeSizeExprForCurrentCase();
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LAnd, tapeSizeExprForCurrentCase,
tapeBackExprForCurrentCase);
m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase;
}
}
addToCurrentBlock(pushExprToCurrentCase);
Expand Down Expand Up @@ -4286,14 +4274,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CreateCFTapePushExpr(m_CaseCounter);
}

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeSizeExprForCurrentCase() {
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, /*val=*/0));
}

void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks(
StmtDiff& bodyDiff) {
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt)
Expand Down
7 changes: 0 additions & 7 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,13 +580,6 @@ namespace clad {
return clad_compat::llvm_Optional_GetValue(Result);
}

LookupResult& VisitorBase::GetCladTapeSize() {
static clad_compat::llvm_Optional<LookupResult> Result{};
if (!Result)
Result = LookupCladTapeMethod("size");
return clad_compat::llvm_Optional_GetValue(Result);
}

QualType VisitorBase::GetCladTapeOfType(QualType T) {
return InstantiateTemplate(GetCladTapeDecl(), {T});
}
Expand Down
37 changes: 18 additions & 19 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ double fn16(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))
// CHECK-NEXT: --ii;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -1439,13 +1439,12 @@ double fn17(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) {
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1))
// CHECK-NEXT: --ii;
// CHECK-NEXT: --ii;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
// CHECK-NEXT: ;
Expand Down Expand Up @@ -1561,7 +1560,7 @@ double fn18(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2))
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 2))
// CHECK-NEXT: --counter;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -1891,7 +1890,7 @@ double fn23(double i, double j) {
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res = 0.;
Expand All @@ -1901,7 +1900,7 @@ double fn23(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2003,7 +2002,7 @@ double fn25(double i, double j) {
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) {
// CHECK-NEXT: _d_res += 0;
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
Expand All @@ -2013,7 +2012,7 @@ double fn25(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t3)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2076,7 +2075,7 @@ double fn26(double i, double j) {
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_i += _r_d0 * j;
Expand Down Expand Up @@ -2151,7 +2150,7 @@ double fn27(double i, double j) {
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_i += _r_d0 * j;
Expand All @@ -2160,7 +2159,7 @@ double fn27(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2477,7 +2476,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::back(_t8) != 1))) {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: *_d_i += _r_d0 * j;
Expand All @@ -2486,7 +2485,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))
//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::back(_t8) != 1))
//CHECK-NEXT: --c;
//CHECK-NEXT: switch (clad::pop(_t8)) {
//CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand All @@ -2507,7 +2506,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: {
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) {
//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1))) {
//CHECK-NEXT: res = clad::pop(_t4);
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: *_d_i += _r_d1 * j;
Expand All @@ -2516,7 +2515,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: if (!clad::back(_t2))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))
//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1))
//CHECK-NEXT: --d;
//CHECK-NEXT: switch (clad::pop(_t6)) {
//CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2629,7 +2628,7 @@ double fn33(double i, double j) {
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2))) {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res = 0.;
Expand All @@ -2639,7 +2638,7 @@ double fn33(double i, double j) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2))
//CHECK-NEXT: --c;
//CHECK-NEXT: switch (clad::pop(_t4)) {
//CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -3238,7 +3237,7 @@ double fn41(double u, double v) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))
//CHECK-NEXT: i--;
//CHECK-NEXT: switch (clad::pop(_t2)) {
//CHECK-NEXT: case {{2U|2UL}}:
Expand Down

0 comments on commit 1bae2cd

Please sign in to comment.