diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 9f858ba35..0aa613f38 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -664,8 +664,9 @@ namespace clad { /// null otherwise. clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, - bool forCustomDerv = true, bool namespaceShouldExist = true); + clang::Scope* S, const clang::FunctionDecl* originalFD, + bool forCustomDerv = true, bool namespaceShouldExist = true, + llvm::SmallVectorImpl* block = nullptr); bool noOverloadExists(clang::Expr* UnresolvedLookup, llvm::MutableArrayRef ARargs); }; diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index b67ccf7d3..d0e53581a 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1095,8 +1095,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { std::string customPushforward = clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); Expr* callDiff = BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, customDerivativeArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); + customPushforward, customDerivativeArgs, getCurrentScope(), FD); // Check if it is a recursive call. if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b54592abb..5e1dac638 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1670,8 +1670,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pushforwardCallArgs.push_back(ConstantFolder::synthesizeLiteral( DerivedCallArgs.front()->getType(), m_Context, 1)); OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, pushforwardCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); + customPushforward, pushforwardCallArgs, getCurrentScope(), FD); if (OverloadedDerivedFn) asGrad = false; } @@ -1766,8 +1765,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::string customPullback = clad::utils::ComputeEffectiveFnName(FD) + "_pullback"; OverloadedDerivedFn = BuildCallToCustomDerivativeOrNumericalDiff( - customPullback, pullbackCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); + customPullback, pullbackCallArgs, getCurrentScope(), FD, + /*forCustomDerv=*/true, + /*namespaceShouldExist=*/true, + /*block=*/&PreCallStmts); if (baseDiff.getExpr()) pullbackCallArgs.erase(pullbackCallArgs.begin()); } @@ -2042,8 +2043,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, NumDiffArgs.push_back(args[i]); } std::string Name = "central_difference"; + const FunctionDecl* FD = nullptr; + if (auto* DRE = dyn_cast(targetFuncCall->IgnoreImplicit())) + FD = dyn_cast(DRE->getDecl()); return BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + Name, NumDiffArgs, getCurrentScope(), /*OriginalFD=*/FD, /*forCustomDerv=*/false, /*namespaceShouldExist=*/false); } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index cc6d84e4a..5b1f5d625 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -734,8 +734,11 @@ namespace clad { NumDiffArgs.insert(NumDiffArgs.end(), args.begin(), args.begin() + numArgs); // Return the found overload. std::string Name = "forward_central_difference"; + const FunctionDecl* FD = nullptr; + if (auto* DRE = dyn_cast(targetFuncCall->IgnoreImplicit())) + FD = dyn_cast(DRE->getDecl()); return BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + Name, NumDiffArgs, getCurrentScope(), /*OriginalFD=*/FD, /*forCustomDerv=*/false, /*namespaceShouldExist=*/false); } @@ -818,8 +821,12 @@ namespace clad { Expr* VisitorBase::BuildCallToCustomDerivativeOrNumericalDiff( const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, - bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { + clang::Scope* S, const clang::FunctionDecl* originalFD, + bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/, + llvm::SmallVectorImpl* block /*=nullptr*/) { + DeclContext* originalFnDC = nullptr; + if (originalFD) + originalFnDC = const_cast(originalFD->getDeclContext()); NamespaceDecl* NSD = nullptr; std::string namespaceID; if (forCustomDerv) { @@ -841,7 +848,7 @@ namespace clad { NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist); if (!forCustomDerv && !NSD) { diag(DiagnosticsEngine::Warning, noLoc, - "Numerical differentiation is diabled using the " + "Numerical differentiation is disabled using the " "-DCLAD_NO_NUM_DIFF " "flag, this means that every try to numerically differentiate a " "function will fail! Remove the flag to revert to default " @@ -889,15 +896,40 @@ namespace clad { Expr* UnresolvedLookup = m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get(); - auto MARargs = llvm::MutableArrayRef(CallArgs); - - SourceLocation Loc; + llvm::SmallVector ExtendedCallArgs(CallArgs.begin(), + CallArgs.end()); + llvm::SmallVector DeclStmts; + // FIXME: for now, integer types are considered differentiable in the + // forward mode. + if (m_Mode != DiffMode::forward && + m_Mode != DiffMode::vector_forward_mode && + m_Mode != DiffMode::experimental_pushforward) + for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) { + QualType paramTy = originalFD->getParamDecl(i)->getType(); + if (!utils::IsDifferentiableType(paramTy)) { + QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema); + VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy)); + Expr* arg = BuildDeclRef(argDecl); + if (!utils::isArrayOrPointerType(argTy)) + arg = BuildOp(UO_AddrOf, arg); + ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 1, arg); + DeclStmts.push_back(BuildDeclStmt(argDecl)); + } + } + auto MARargs = llvm::MutableArrayRef(ExtendedCallArgs); if (noOverloadExists(UnresolvedLookup, MARargs)) return nullptr; OverloadedFn = - m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); + m_Sema.ActOnCallExpr(S, UnresolvedLookup, noLoc, MARargs, noLoc) + .get(); + if (!DeclStmts.empty()) { + if (!block) + block = &getCurrentBlock(); + for (Stmt* decl : DeclStmts) + block->push_back(decl); + } } return OverloadedFn; } diff --git a/test/FirstDerivative/BuiltinDerivatives.C b/test/FirstDerivative/BuiltinDerivatives.C index f291253ce..c154d0ce3 100644 --- a/test/FirstDerivative/BuiltinDerivatives.C +++ b/test/FirstDerivative/BuiltinDerivatives.C @@ -171,9 +171,9 @@ double f10(float x, int y) { // CHECK-NEXT: return _t0.pushforward; // CHECK-NEXT: } -void f10_grad(float x, int y, float *_d_x, int *_d_y); +void f10_grad_0(float x, int y, float *_d_x); -// CHECK: void f10_grad(float x, int y, float *_d_x, int *_d_y) { +// CHECK: void f10_grad_0(float x, int y, float *_d_x) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { @@ -181,7 +181,6 @@ void f10_grad(float x, int y, float *_d_x, int *_d_y); // CHECK-NEXT: int _r1 = 0; // CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1); // CHECK-NEXT: *_d_x += _r0; -// CHECK-NEXT: *_d_y += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -222,7 +221,6 @@ double f12(double a, double b) { return std::fma(a, b, b); } int main () { //expected-no-diagnostics float f_result[2]; double d_result[2]; - int i_result[1]; auto f1_darg0 = clad::differentiate(f1, 0); printf("Result is = %f\n", f1_darg0.execute(60)); // CHECK-EXEC: Result is = -0.952413 @@ -276,10 +274,9 @@ int main () { //expected-no-diagnostics printf("Result is = %f\n", f10_darg0.execute(3, 4)); //CHECK-EXEC: Result is = 108.000000 f_result[0] = f_result[1] = 0; - i_result[0] = 0; clad::gradient(f10); - f10_grad(3, 4, &f_result[0], &i_result[0]); - printf("Result is = {%f, %d}\n", f_result[0], i_result[0]); //CHECK-EXEC: Result is = {108.000000, 88} + f10_grad_0(3, 4, &f_result[0]); + printf("Result is = {%f}\n", f_result[0]); //CHECK-EXEC: Result is = {108.000000} INIT_GRADIENT(f11); diff --git a/test/NumericalDiff/NoNumDiff.C b/test/NumericalDiff/NoNumDiff.C index d958c23e1..cc83baecc 100644 --- a/test/NumericalDiff/NoNumDiff.C +++ b/test/NumericalDiff/NoNumDiff.C @@ -8,8 +8,8 @@ double func(double x) { return std::tanh(x); } -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. +//CHECK: warning: Numerical differentiation is disabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. +//CHECK: warning: Numerical differentiation is disabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. //CHECK: double func_darg0(double x) { //CHECK-NEXT: double _d_x = 1; //CHECK-NEXT: return 0;