From 530a10d3ba9c9cd188b42b8ae3fd230f9e30d13b Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Fri, 5 Jul 2024 16:19:53 +0300 Subject: [PATCH] Add support for non-differentiable attribute in reverse mode --- lib/Differentiator/ReverseModeVisitor.cpp | 58 ++++-- test/Gradient/NonDifferentiable.C | 209 ++++++++++------------ 2 files changed, 139 insertions(+), 128 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6394ee9dd..1f363cac2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1369,6 +1369,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + SourceLocation validLoc{CE->getBeginLoc()}; + + // If the function is non_differentiable, return zero derivative. + if (clad::utils::hasNonDifferentiableAttribute(CE)) { + // Calling the function without computing derivatives + llvm::SmallVector ClonedArgs; + for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) + ClonedArgs.push_back(Clone(CE->getArg(i))); + + Expr* Call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, ClonedArgs, validLoc) + .get(); + // Creating a zero derivative + auto* zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + + // Returning the function call and zero derivative + return StmtDiff(Call, zero); + } + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -2753,22 +2774,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool isLambda = false; const auto* declsBegin = DS->decls().begin(); if (declsBegin != DS->decls().end() && isa(*declsBegin)) { - auto* VD = dyn_cast(*declsBegin); - QualType QT = VD->getType(); - if (!QT->isPointerType()) { - auto* typeDecl = QT->getAsCXXRecordDecl(); - // We should also simply copy the original lambda. The differentiation - // of lambdas is happening in the `VisitCallExpr`. For now, only the - // declarations with lambda expressions without captures are supported. - isLambda = typeDecl && typeDecl->isLambda(); - if (isLambda) { - for (auto* D : DS->decls()) - if (auto* VD = dyn_cast(D)) - decls.push_back(VD); - Stmt* DSClone = BuildDeclStmt(decls); - return StmtDiff(DSClone, nullptr); - } - } + auto* VD = dyn_cast(*declsBegin); + QualType QT = VD->getType(); + if(QT->isPointerType()) + QT = QT->getPointeeType(); + + auto* typeDecl = QT->getAsCXXRecordDecl(); + // We should also simply copy the original lambda. The differentiation + // of lambdas is happening in the `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda || (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); + } } // For each variable declaration v, create another declaration _d_v to @@ -2954,6 +2976,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "CXXMethodDecl nodes not supported yet!"); MemberExpr* clonedME = utils::BuildMemberExpr( m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + auto zero = + ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); + if (clad::utils::hasNonDifferentiableAttribute(ME)) + return {clonedME, zero}; if (!baseDiff.getExpr_dx()) return {clonedME, nullptr}; MemberExpr* derivedME = utils::BuildMemberExpr( diff --git a/test/Gradient/NonDifferentiable.C b/test/Gradient/NonDifferentiable.C index 7c889a070..077eddf58 100644 --- a/test/Gradient/NonDifferentiable.C +++ b/test/Gradient/NonDifferentiable.C @@ -74,15 +74,15 @@ double fn_s2_operator(double i, double j) { #define TEST_CLASS(classname, name, i, j) \ auto d_##name = clad::gradient(&classname::name); \ - double result_##name[2] = {}; \ + double result_##name[2]; \ d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ - printf("%.2f %.2f\n", result_##name[0], result_##name[1]); + printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); #define TEST_FUNC(name, i, j) \ auto d_##name = clad::gradient(&name); \ - double result_##name[2] = {}; \ + double result_##name[2]; \ d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ - printf("%.2f\n", result_##name[0], result_##name[1]); + printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); int main() { INIT_EXPR(SimpleFunctions1); @@ -111,113 +111,98 @@ int main() { /*TEST_FUNC(fn_s2_operator, 3, 5) // CHECK-EXEC: 0.00*/ - // CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { - // CHECK-NEXT: { - // CHECK-NEXT: (*_d_this).x += 1 * i; - // CHECK-NEXT: (*_d_this).y += 1 * i; - // CHECK-NEXT: *_d_i += (this->x + this->y) * 1; - // CHECK-NEXT: *_d_i += 1 * j * j; - // CHECK-NEXT: *_d_j += i * 1 * j; - // CHECK-NEXT: *_d_j += i * j * 1; - // CHECK-NEXT: } - - // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); - - // CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { - // CHECK-NEXT: SimpleFunctions1 *_t0; - // CHECK-NEXT: _t0 = this; - // CHECK-NEXT: { - // CHECK-NEXT: double _r0 = 0; - // CHECK-NEXT: double _r1 = 0; - // CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1); - // CHECK-NEXT: *_d_i += _r0; - // CHECK-NEXT: *_d_j += _r1; - // CHECK-NEXT: *_d_i += 1 * j; - // CHECK-NEXT: *_d_j += i * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { - // CHECK-NEXT: SimpleFunctions1 *_t0; - // CHECK-NEXT: _t0 = this; - // CHECK-NEXT: { - // CHECK-NEXT: *_d_i += _r0; - // CHECK-NEXT: *_d_j += _r1; - // CHECK-NEXT: *_d_i += 1 * j; - // CHECK-NEXT: *_d_j += i * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { - // CHECK-NEXT: double _t0; - // CHECK-NEXT: SimpleFunctions1 *_t1; - // CHECK-NEXT: double _t2; - // CHECK-NEXT: SimpleFunctions1 *_t3; - // CHECK-NEXT: _t1 = this; - // CHECK-NEXT: _t2 = this->mem_fn_2(i, j); - // CHECK-NEXT: _t3 = this; - // CHECK-NEXT: _t0 = this->mem_fn_1(i, j); - // CHECK-NEXT: { - // CHECK-NEXT: *_d_i += _r0; - // CHECK-NEXT: *_d_j += _r1; - // CHECK-NEXT: double _r2 = 0; - // CHECK-NEXT: double _r3 = 0; - // CHECK-NEXT: _t3->mem_fn_1_pullback(i, j, _t2 * 1 * i, &(*_d_this), &_r2, &_r3); - // CHECK-NEXT: *_d_i += _r2; - // CHECK-NEXT: *_d_j += _r3; - // CHECK-NEXT: *_d_i += _t2 * _t0 * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { - // CHECK-NEXT: SimpleFunctions1 _d_obj({}); - // CHECK-NEXT: SimpleFunctions1 _t0; - // CHECK-NEXT: SimpleFunctions1 obj(2, 3); - // CHECK-NEXT: _t0 = obj; - // CHECK-NEXT: { - // CHECK-NEXT: double _r0 = 0; - // CHECK-NEXT: double _r1 = 0; - // CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); - // CHECK-NEXT: *_d_i += _r0; - // CHECK-NEXT: *_d_j += _r1; - // CHECK-NEXT: *_d_i += 1 * j; - // CHECK-NEXT: *_d_j += i * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { - // CHECK-NEXT: SimpleFunctions1 _d_obj({}); - // CHECK-NEXT: SimpleFunctions1 obj(2, 3); - // CHECK-NEXT: { - // CHECK-NEXT: _d_obj.x += 1 * obj.y; - // CHECK-NEXT: _d_obj.y += obj.x * 1; - // CHECK-NEXT: *_d_i += 1 * j; - // CHECK-NEXT: *_d_j += i * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { - // CHECK-NEXT: SimpleFunctions2 _d_obj({}); - // CHECK-NEXT: SimpleFunctions2 _t0; - // CHECK-NEXT: SimpleFunctions2 obj(2, 3); - // CHECK-NEXT: _t0 = obj; - // CHECK-NEXT: { - // CHECK-NEXT: *_d_i += _r0; - // CHECK-NEXT: *_d_j += _r1; - // CHECK-NEXT: *_d_i += 1 * j; - // CHECK-NEXT: *_d_j += i * 1; - // CHECK-NEXT: } - // CHECK-NEXT: } - - // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { - // CHECK-NEXT: { - // CHECK-NEXT: (*_d_this).x += _d_y * i; - // CHECK-NEXT: (*_d_this).y += _d_y * i; - // CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; - // CHECK-NEXT: *_d_i += _d_y * j * j; - // CHECK-NEXT: *_d_j += i * _d_y * j; - // CHECK-NEXT: *_d_j += i * j * _d_y; - // CHECK-NEXT: } - // CHECK-NEXT: } + // CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: { + // CHECK-NEXT: (*_d_this).x += 1 * i; + // CHECK-NEXT: *_d_i += (this->x + this->y) * 1; + // CHECK-NEXT: *_d_i += 1 * j * j; + // CHECK-NEXT: *_d_j += i * 1 * j; + // CHECK-NEXT: *_d_j += i * j * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); + + // CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 *_t0; + // CHECK-NEXT: _t0 = this; + // CHECK-NEXT: { + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1); + // CHECK-NEXT: *_d_i += _r0; + // CHECK-NEXT: *_d_j += _r1; + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: { + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: double _t0; + // CHECK-NEXT: double _t1; + // CHECK-NEXT: SimpleFunctions1 *_t2; + // CHECK-NEXT: _t1 = this->mem_fn_2(i, j); + // CHECK-NEXT: _t2 = this; + // CHECK-NEXT: _t0 = this->mem_fn_1(i, j); + // CHECK-NEXT: { + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t2->mem_fn_1_pullback(i, j, _t1 * 1 * i, &(*_d_this), &_r0, &_r1); + // CHECK-NEXT: *_d_i += _r0; + // CHECK-NEXT: *_d_j += _r1; + // CHECK-NEXT: *_d_i += _t1 * _t0 * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 _d_obj({}); + // CHECK-NEXT: SimpleFunctions1 _t0; + // CHECK-NEXT: SimpleFunctions1 obj(2, 3); + // CHECK-NEXT: _t0 = obj; + // CHECK-NEXT: { + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); + // CHECK-NEXT: *_d_i += _r0; + // CHECK-NEXT: *_d_j += _r1; + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 _d_obj({}); + // CHECK-NEXT: SimpleFunctions1 obj(2, 3); + // CHECK-NEXT: { + // CHECK-NEXT: _d_obj.x += 1 * obj.y; + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions2 obj(2, 3); + // CHECK-NEXT: { + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: { + // CHECK-NEXT: (*_d_this).x += _d_y * i; + // CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; + // CHECK-NEXT: *_d_i += _d_y * j * j; + // CHECK-NEXT: *_d_j += i * _d_y * j; + // CHECK-NEXT: *_d_j += i * j * _d_y; + // CHECK-NEXT: } + // CHECK-NEXT: } }