Skip to content

Commit

Permalink
Add support for non-differentiable attribute in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed Jul 6, 2024
1 parent 98abc84 commit 34c325b
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 125 deletions.
53 changes: 40 additions & 13 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE));
}

SourceLocation validLoc{CE->getBeginLoc()};

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));

Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);

// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
Expand Down Expand Up @@ -2755,19 +2776,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
QualType QT = VD->getType();
if (!QT->isPointerType()) {
auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
if (QT->isPointerType())
QT = QT->getPointeeType();

auto* typeDecl = QT->getAsCXXRecordDecl();
// We should also simply copy the original lambda. The differentiation
// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda ||
(typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Stmt* DSClone = BuildDeclStmt(decls);
return StmtDiff(DSClone, nullptr);
}
}

Expand Down Expand Up @@ -2954,6 +2977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
if (clad::utils::hasNonDifferentiableAttribute(ME))
return {clonedME, zero};
if (!baseDiff.getExpr_dx())
return {clonedME, nullptr};
MemberExpr* derivedME = utils::BuildMemberExpr(
Expand Down
209 changes: 97 additions & 112 deletions test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ double fn_s2_operator(double i, double j) {

#define TEST_CLASS(classname, name, i, j) \
auto d_##name = clad::gradient(&classname::name); \
double result_##name[2] = {}; \
double result_##name[2]; \
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f %.2f\n", result_##name[0], result_##name[1]);
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]);

#define TEST_FUNC(name, i, j) \
auto d_##name = clad::gradient(&name); \
double result_##name[2] = {}; \
double result_##name[2]; \
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f\n", result_##name[0], result_##name[1]);
printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]);

int main() {
INIT_EXPR(SimpleFunctions1);
Expand Down Expand Up @@ -111,113 +111,98 @@ int main() {

/*TEST_FUNC(fn_s2_operator, 3, 5) // CHECK-EXEC: 0.00*/

// CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += 1 * i;
// CHECK-NEXT: (*_d_this).y += 1 * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * 1;
// CHECK-NEXT: *_d_i += 1 * j * j;
// CHECK-NEXT: *_d_j += i * 1 * j;
// CHECK-NEXT: *_d_j += i * j * 1;
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 *_t0;
// CHECK-NEXT: _t0 = this;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 *_t0;
// CHECK-NEXT: _t0 = this;
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: SimpleFunctions1 *_t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: SimpleFunctions1 *_t3;
// CHECK-NEXT: _t1 = this;
// CHECK-NEXT: _t2 = this->mem_fn_2(i, j);
// CHECK-NEXT: _t3 = this;
// CHECK-NEXT: _t0 = this->mem_fn_1(i, j);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: double _r2 = 0;
// CHECK-NEXT: double _r3 = 0;
// CHECK-NEXT: _t3->mem_fn_1_pullback(i, j, _t2 * 1 * i, &(*_d_this), &_r2, &_r3);
// CHECK-NEXT: *_d_i += _r2;
// CHECK-NEXT: *_d_j += _r3;
// CHECK-NEXT: *_d_i += _t2 * _t0 * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 _t0;
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: _d_obj.x += 1 * obj.y;
// CHECK-NEXT: _d_obj.y += obj.x * 1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions2 _d_obj({});
// CHECK-NEXT: SimpleFunctions2 _t0;
// CHECK-NEXT: SimpleFunctions2 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: (*_d_this).y += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j * j;
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += 1 * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * 1;
// CHECK-NEXT: *_d_i += 1 * j * j;
// CHECK-NEXT: *_d_j += i * 1 * j;
// CHECK-NEXT: *_d_j += i * j * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 *_t0;
// CHECK-NEXT: _t0 = this;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: SimpleFunctions1 *_t2;
// CHECK-NEXT: _t1 = this->mem_fn_2(i, j);
// CHECK-NEXT: _t2 = this;
// CHECK-NEXT: _t0 = this->mem_fn_1(i, j);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t2->mem_fn_1_pullback(i, j, _t1 * 1 * i, &(*_d_this), &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += _t1 * _t0 * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 _t0;
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: _d_obj.x += 1 * obj.y;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions2 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j * j;
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

}

0 comments on commit 34c325b

Please sign in to comment.