Skip to content

Commit

Permalink
Do not store the RHS of multiplication with no side effects.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 29, 2023
1 parent a3415b8 commit 78cb225
Show file tree
Hide file tree
Showing 39 changed files with 914 additions and 1,700 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ namespace clad {
/// need to decide what needs to be stored on tape in reverse mode.
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

/// Returns true if a call expression (`clang::CallExpr`) is found while
/// traversing the statement \p E, and false otherwise.
bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/Sema/Scope.h"

#include "llvm/ADT/DenseMap.h"
#include <unordered_map>

namespace clang {
class Stmt;
Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -617,5 +618,20 @@ namespace clad {

return false;
}

/// Reports whether a `CallExpr` is encountered while traversing \p S.
///
/// \param[in] S statement (or expression) to inspect.
/// \returns true if at least one call expression was visited, false otherwise.
bool ContainsFunctionCalls(const clang::Stmt* S) {
  // Local visitor that records whether any call expression was seen.
  struct CallDetector : RecursiveASTVisitor<CallDetector> {
    bool found = false;

    bool VisitCallExpr(CallExpr* /*Call*/) {
      found = true;
      // One hit is enough; returning false asks the visitor to stop.
      return false;
    }
  };

  CallDetector detector;
  // TraverseStmt takes a non-const pointer, hence the const_cast.
  detector.TraverseStmt(const_cast<Stmt*>(S));
  return detector.found;
}
} // namespace utils
} // namespace clad
27 changes: 18 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2184,29 +2184,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).
if (utils::ContainsFunctionCalls(R) || R->HasSideEffects(m_Context)) {
RDelayed = std::unique_ptr<DelayedStoreResult>(new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}

Expr* dl = nullptr;
if (dfdx()) {
dl = BuildOp(BO_Mul, dfdx(), RResult.getExpr_dx());
dl = BuildOp(BO_Mul, dfdx(), RResult.getRevSweepAsExpr());
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/dxr = xl
// df/dxr += df/dxi * dxi/dxr = df/dxi * xl
// Store left multiplier and assign it with L.
Expr* LStored = Ldiff.getExpr();
// RDelayed.isConstant == true implies that R is a constant expression,
// therefore we can skip visiting it.
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (RDelayed || !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx()) {
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
RDelayed.Finalize(Rdiff.getExpr());
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr());
} else if (opCode == BO_Div) {
Expand Down Expand Up @@ -2239,7 +2248,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getRevSweepAsExpr(), RResult.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getExpr(), RResult.getExpr());
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2966,7 +2975,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this,
StmtDiff{Push, Pop},
StmtDiff{Push, Pop, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/ true};
}
Expand Down
45 changes: 18 additions & 27 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ float helper(float x) {
}

// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * _t0;
// CHECK-NEXT: float _r0 = _d_y * x;
// CHECK-NEXT: float _r1 = 2 * _d_y;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: }
Expand Down Expand Up @@ -208,30 +206,26 @@ double func4(double x) {
}

//CHECK: void func4_grad(double x, clad::array_ref<double> _d_x) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: double arr[3] = {x, 2 * _t0, x * _t1};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t2; _t2--) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -243,10 +237,10 @@ double func4(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * _t0;
//CHECK-NEXT: double _r0 = _d_arr[1] * x;
//CHECK-NEXT: double _r1 = 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: double _r2 = _d_arr[2] * _t1;
//CHECK-NEXT: double _r2 = _d_arr[2] * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: double _r3 = x * _d_arr[2];
//CHECK-NEXT: * _d_x += _r3;
Expand Down Expand Up @@ -334,15 +328,14 @@ double func6(double seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<int> _t1 = {};
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, seed * clad::push(_t1, i), seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: double arr[3] = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -351,7 +344,7 @@ double func6(double seed) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -362,7 +355,7 @@ double func6(double seed) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * clad::pop(_t1);
//CHECK-NEXT: double _r0 = _d_arr[1] * i;
//CHECK-NEXT: * _d_seed += _r0;
//CHECK-NEXT: double _r1 = seed * _d_arr[1];
//CHECK-NEXT: _d_i += _r1;
Expand All @@ -379,15 +372,13 @@ double inv_square(double *params) {

//CHECK: void inv_square_pullback(double *params, double _d_y, clad::array_ref<double> _d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: _t1 = params[0];
//CHECK-NEXT: _t0 = (params[0] * _t1);
//CHECK-NEXT: _t0 = (params[0] * params[0]);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * _t1;
//CHECK-NEXT: double _r2 = _r1 * params[0];
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = params[0] * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
Expand Down
12 changes: 3 additions & 9 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,20 @@ double const_dot_product(double x, double y, double z) {
//CHECK: void const_dot_product_grad(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) {
//CHECK-NEXT: clad::array<double> _d_vars(3UL);
//CHECK-NEXT: clad::array<double> _d_consts(3UL);
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double consts[3] = {1, 2, 3};
//CHECK-NEXT: _t0 = consts[0];
//CHECK-NEXT: _t1 = consts[1];
//CHECK-NEXT: _t2 = consts[2];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * _t0;
//CHECK-NEXT: double _r0 = 1 * consts[0];
//CHECK-NEXT: _d_vars[0] += _r0;
//CHECK-NEXT: double _r1 = vars[0] * 1;
//CHECK-NEXT: _d_consts[0] += _r1;
//CHECK-NEXT: double _r2 = 1 * _t1;
//CHECK-NEXT: double _r2 = 1 * consts[1];
//CHECK-NEXT: _d_vars[1] += _r2;
//CHECK-NEXT: double _r3 = vars[1] * 1;
//CHECK-NEXT: _d_consts[1] += _r3;
//CHECK-NEXT: double _r4 = 1 * _t2;
//CHECK-NEXT: double _r4 = 1 * consts[2];
//CHECK-NEXT: _d_vars[2] += _r4;
//CHECK-NEXT: double _r5 = vars[2] * 1;
//CHECK-NEXT: _d_consts[2] += _r5;
Expand Down
76 changes: 35 additions & 41 deletions test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,79 +35,73 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _t3;
//CHECK-NEXT: double _t4;
//CHECK-NEXT: double _t5;
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double _t7;
//CHECK-NEXT: double _t8;
//CHECK-NEXT: double _t9;
//CHECK-NEXT: double _t10;
//CHECK-NEXT: double t = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < dim; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t += (x[i] - p[i]) * clad::push(_t2, (x[i] - p[i]));
//CHECK-NEXT: t += (x[i] - p[i]) * (x[i] - p[i]);
//CHECK-NEXT: }
//CHECK-NEXT: _t3 = t;
//CHECK-NEXT: _t6 = sigma;
//CHECK-NEXT: _t5 = sigma;
//CHECK-NEXT: _t4 = (2 * _t6 * _t5);
//CHECK-NEXT: t = -t / _t4;
//CHECK-NEXT: _t9 = 3.1415926535897931;
//CHECK-NEXT: _t10 = 2.;
//CHECK-NEXT: _t8 = std::pow(sigma, -0.5);
//CHECK-NEXT: _t7 = std::exp(t);
//CHECK-NEXT: _t2 = t;
//CHECK-NEXT: _t3 = (2 * sigma * sigma);
//CHECK-NEXT: t = -t / _t3;
//CHECK-NEXT: _t6 = 2.;
//CHECK-NEXT: _t5 = std::pow(sigma, -0.5);
//CHECK-NEXT: _t4 = std::exp(t);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r8 = 1 * _t7;
//CHECK-NEXT: double _r9 = _r8 * _t8;
//CHECK-NEXT: double _r8 = 1 * _t4;
//CHECK-NEXT: double _r9 = _r8 * _t5;
//CHECK-NEXT: double _grad0 = 0.;
//CHECK-NEXT: double _grad1 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * _t9, -dim / _t10, _r9, &_grad0, &_grad1);
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / _t6, _r9, &_grad0, &_grad1);
//CHECK-NEXT: double _r10 = _grad0;
//CHECK-NEXT: double _r11 = _r10 * _t9;
//CHECK-NEXT: double _r12 = 2 * _r10;
//CHECK-NEXT: double _r13 = _grad1;
//CHECK-NEXT: double _r14 = _r13 / _t10;
//CHECK-NEXT: _d_dim += -_r14;
//CHECK-NEXT: double _r15 = _r13 * --dim / (_t10 * _t10);
//CHECK-NEXT: double _r16 = std::pow(2 * _t9, -dim / _t10) * _r8;
//CHECK-NEXT: double _r11 = _r10 * 3.1415926535897931;
//CHECK-NEXT: double _r12 = _grad1;
//CHECK-NEXT: double _r13 = _r12 / _t6;
//CHECK-NEXT: * _d_dim += -_r13;
//CHECK-NEXT: double _r14 = _r12 * --dim / (_t6 * _t6);
//CHECK-NEXT: double _r15 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _r8;
//CHECK-NEXT: double _grad2 = 0.;
//CHECK-NEXT: double _grad3 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _r16, &_grad2, &_grad3);
//CHECK-NEXT: double _r17 = _grad2;
//CHECK-NEXT: _d_sigma += _r17;
//CHECK-NEXT: double _r18 = _grad3;
//CHECK-NEXT: double _r19 = std::pow(2 * _t9, -dim / _t10) * _t8 * 1;
//CHECK-NEXT: double _r20 = _r19 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r20;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _r15, &_grad2, &_grad3);
//CHECK-NEXT: double _r16 = _grad2;
//CHECK-NEXT: * _d_sigma += _r16;
//CHECK-NEXT: double _r17 = _grad3;
//CHECK-NEXT: double _r18 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _t5 * 1;
//CHECK-NEXT: double _r19 = _r18 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r19;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t = _t3;
//CHECK-NEXT: t = _t2;
//CHECK-NEXT: double _r_d1 = _d_t;
//CHECK-NEXT: double _r2 = _r_d1 / _t4;
//CHECK-NEXT: double _r2 = _r_d1 / _t3;
//CHECK-NEXT: _d_t += -_r2;
//CHECK-NEXT: double _r3 = _r_d1 * --t / (_t4 * _t4);
//CHECK-NEXT: double _r4 = _r3 * _t5;
//CHECK-NEXT: double _r5 = _r4 * _t6;
//CHECK-NEXT: double _r3 = _r_d1 * --t / (_t3 * _t3);
//CHECK-NEXT: double _r4 = _r3 * sigma;
//CHECK-NEXT: double _r5 = _r4 * sigma;
//CHECK-NEXT: double _r6 = 2 * _r4;
//CHECK-NEXT: _d_sigma += _r6;
//CHECK-NEXT: double _r7 = 2 * _t6 * _r3;
//CHECK-NEXT: _d_sigma += _r7;
//CHECK-NEXT: * _d_sigma += _r6;
//CHECK-NEXT: double _r7 = 2 * sigma * _r3;
//CHECK-NEXT: * _d_sigma += _r7;
//CHECK-NEXT: _d_t -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: t = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_t;
//CHECK-NEXT: _d_t += _r_d0;
//CHECK-NEXT: double _r0 = _r_d0 * clad::pop(_t2);
//CHECK-NEXT: double _r0 = _r_d0 * (x[i] - p[i]);
//CHECK-NEXT: _d_x[i] += _r0;
//CHECK-NEXT: _d_p[i] += -_r0;
//CHECK-NEXT: double _r1 = (x[i] - p[i]) * _r_d0;
//CHECK-NEXT: _d_x[i] += _r1;
//CHECK-NEXT: _d_p[i] += -_r1;
//CHECK-NEXT: _d_t -= _r_d0;
//CHECK-NEXT: }
Expand Down
4 changes: 1 addition & 3 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ double foo(double x, double y){
}

// CHECK: void foo_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: _t0 = y;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * _t0;
// CHECK-NEXT: double _r0 = 1 * y;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: double _r1 = x * 1;
// CHECK-NEXT: * _d_y += _r1;
Expand Down
Loading

0 comments on commit 78cb225

Please sign in to comment.