From bba8991ee3e3ab013ab6970912353ae6ae0b1ec6 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 13 Aug 2024 19:14:25 +0200 Subject: [PATCH] Enable some cases of functor calls in custom pushforwards Previously, if a user wanted to provide a custom pushforward for a function that uses functors in it, it was impossible to use generated pushforwards for that functors' call operators. This commit aims to fix this for basic functors that don't have multiple call operator overloads. Fixes: #1023 --- .../Differentiator/BaseForwardModeVisitor.h | 2 + lib/Differentiator/BaseForwardModeVisitor.cpp | 60 +++++++++++++++++++ test/ForwardMode/Functors.C | 47 +++++++++++++++ test/ForwardMode/ReferenceArguments.C | 31 +++++++++- test/ForwardMode/UserDefinedTypes.C | 24 ++++---- test/Functors/Simple.C | 47 +++++++++++++++ 6 files changed, 197 insertions(+), 14 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index fa7c09ad6..6e5deeea9 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -47,6 +47,8 @@ class BaseForwardModeVisitor virtual void ExecuteInsidePushforwardFunctionBlock(); + virtual void DifferentiateCallOperatorIfFunctor(clang::QualType QT); + static bool IsDifferentiableType(clang::QualType T); virtual StmtDiff diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7a70f3bcb..d9c5bb504 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1029,6 +1029,8 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { } StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DifferentiateCallOperatorIfFunctor(DRE->getType()); + DeclRefExpr* clonedDRE = nullptr; // Check if referenced Decl was "replaced" with another identifier inside // the derivative @@ -1594,6 +1596,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { // If the DeclStmt is not empty, check the first declaration. if (declsBegin != declsEnd && isa(*declsBegin)) { auto* VD = dyn_cast(*declsBegin); + DifferentiateCallOperatorIfFunctor(VD->getType()); // Check for non-differentiable types. QualType QT = VD->getType(); if (QT->isPointerType()) @@ -2057,8 +2060,65 @@ StmtDiff BaseForwardModeVisitor::VisitBreakStmt(const BreakStmt* stmt) { return StmtDiff(Clone(stmt)); } +void BaseForwardModeVisitor::DifferentiateCallOperatorIfFunctor( + clang::QualType QT) { + // Identify if the constructed type is a functor. For functors, we need to + // differentiate their call operator once an object has been constructed, to + // allow user calls to pushforwards inside user-provided custom derivatives. + // FIXME: A much more scalable solution would be to create pushforwards once + // they're called from user-provided custom derivatives. This could then be + // applied to other operators besides operator() to avoid compilation errors + // in such cases. + if (auto* RD = QT->getAsCXXRecordDecl()) { + CXXRecordDecl* constructedType = RD->getDefinition(); + bool isFunctor = constructedType && !constructedType->isLambda(); + std::vector callMethods; + if (isFunctor) { + for (const auto* method : constructedType->methods()) { + if (const auto* cxxMethod = dyn_cast(method)) { + if (cxxMethod->isOverloadedOperator() && + cxxMethod->getOverloadedOperator() == OO_Call) { + callMethods.push_back(cxxMethod); + } + } + } + isFunctor = isFunctor && !callMethods.empty(); + } + + if (isFunctor) { + for (const auto* FD : callMethods) { + CXXScopeSpec SS; + bool hasCustomDerivative = + !m_Builder + .LookupCustomDerivativeOrNumericalDiff( + clad::utils::ComputeEffectiveFnName(FD) + + GetPushForwardFunctionSuffix(), + const_cast(FD->getDeclContext()), SS) + .empty(); + + if (!hasCustomDerivative) { + // Request Clad to diff it. + DiffRequest pushforwardFnRequest; + pushforwardFnRequest.Function = FD; + pushforwardFnRequest.Mode = GetPushForwardMode(); + pushforwardFnRequest.BaseFunctionName = + utils::ComputeEffectiveFnName(FD); + // Silence diag outputs in nested derivation process. + pushforwardFnRequest.VerboseDiags = false; + + // Check if request already derived in DerivedFunctions. + m_Builder.HandleNestedDiffRequest(pushforwardFnRequest); + } + } + } + } +} + StmtDiff BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) { + DifferentiateCallOperatorIfFunctor(CE->getType()); + + // Now continue differentiating the constructor itself: llvm::SmallVector clonedArgs, derivedArgs; for (auto arg : CE->arguments()) { auto argDiff = Visit(arg); diff --git a/test/ForwardMode/Functors.C b/test/ForwardMode/Functors.C index fdfcab92b..05d45e7f2 100644 --- a/test/ForwardMode/Functors.C +++ b/test/ForwardMode/Functors.C @@ -386,6 +386,36 @@ struct WidgetPointer { } }; +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double x, F& f, double d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double x, F& f) { + f(x); +} + +struct Foo { + double &y; + Foo(double &y): y(y) {} + + double operator()(double x) { + y = 2*x; + + return x; + } +}; + +double fn0(double x) { + Foo func = Foo{x}; + use_functor(x, func); + return x; +} + #define INIT(E, ARG)\ auto d_##E = clad::differentiate(&E, ARG);\ auto d_##E##Ref = clad::differentiate(E, ARG); @@ -504,4 +534,21 @@ int main() { TEST_2(W_Arr_5, 6, 5); // CHECK-EXEC: 6.00 6.00 TEST_2(W_Pointer_3, 6, 5); // CHECK-EXEC: 37.00 37.00 TEST_2(W_Pointer_5, 6, 5); // CHECK-EXEC: 51.00 51.00 + + auto dfn0 = clad::differentiate(fn0, "x"); + printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2 } + +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x); +// CHECK: double fn0_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: Foo _d_func = Foo{{[{]*_d_x[}]*}}; +// CHECK-NEXT: Foo func = Foo{{[{]*x[}]*}}; +// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func); +// CHECK-NEXT: return _d_x; +// CHECK-NEXT:} +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x) { +// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x; +// CHECK-NEXT: this->y = 2 * x; +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT:} \ No newline at end of file diff --git a/test/ForwardMode/ReferenceArguments.C b/test/ForwardMode/ReferenceArguments.C index 0800bdacb..4d2be654c 100644 --- a/test/ForwardMode/ReferenceArguments.C +++ b/test/ForwardMode/ReferenceArguments.C @@ -4,6 +4,31 @@ #include "clad/Differentiator/Differentiator.h" +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double &x, F& f) { + f(x); +} + +struct Foo { + double operator()(double& x) { + x = 2*x*x; + return x; + } +}; + +double fn0(double x, Foo& func) { + use_functor(x, func); + return x; +} + double fn1(double& i, double& j) { double res = i * i * j; return res; @@ -21,12 +46,14 @@ double fn1(double& i, double& j) { #define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__); #define TEST(fn, ...) \ - auto res = d_##fn.execute(__VA_ARGS__); \ - printf("{%.2f}\n", res) + printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__)) int main() { + INIT(fn0, "x"); INIT(fn1, "i"); double i = 3, j = 5; TEST(fn1, i, j); // CHECK-EXEC: {30.00} + Foo fff; + TEST(fn0, i, fff); // CHECK-EXEC: {12.00} } diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index 82ef5f54f..70340577d 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -420,6 +420,8 @@ Tensor fn5(double i, double j) { return T; } +// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val); + // CHECK: Tensor fn5_darg0(double i, double j) { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_j = 0; @@ -593,8 +595,6 @@ TensorD5 fn11(double i, double j) { return res1; } -// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val); - // CHECK: clad::ValueAndPushforward operator_subscript_pushforward(std::size_t idx, Tensor *_d_this, std::size_t _d_idx); // CHECK: clad::ValueAndPushforward, Tensor > operator_plus_pushforward(const Tensor &a, const Tensor &b, const Tensor &_d_a, const Tensor &_d_b); @@ -965,6 +965,16 @@ double fn18(double i, double j) { // CHECK-NEXT: return _d_v[0].mem; // CHECK-NEXT: } +// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val) { +// CHECK-NEXT: { +// CHECK-NEXT: unsigned int _d_i = 0; +// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) { +// CHECK-NEXT: _d_this->data[i] = _d_val; +// CHECK-NEXT: this->data[i] = val; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + template void print(const Tensor& t) { for (int i=0; i[[_M_value:[a-zA-Z_]+]],{{( __imag)?}} _d_this->[[_M_value:[a-zA-Z_]+]]}; // CHECK-NEXT: } -// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val) { -// CHECK-NEXT: { -// CHECK-NEXT: unsigned int _d_i = 0; -// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) { -// CHECK-NEXT: _d_this->data[i] = _d_val; -// CHECK-NEXT: this->data[i] = val; -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } - // CHECK: clad::ValueAndPushforward operator_subscript_pushforward(std::size_t idx, Tensor *_d_this, std::size_t _d_idx) { // CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]}; // CHECK-NEXT: } diff --git a/test/Functors/Simple.C b/test/Functors/Simple.C index d24dd7d17..21e455d31 100644 --- a/test/Functors/Simple.C +++ b/test/Functors/Simple.C @@ -58,6 +58,50 @@ float f(float x) { return x; } +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double x, F& f, double d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double x, F& f) { + f(x); +} + +struct Foo { + double &y; + Foo(double &y): y(y) {} + + double operator()(double x) { + y = 2*x; + + return x; + } +}; + +double fn0(double x) { + Foo func = Foo({x}); + use_functor(x, func); + return x; +} + +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x); +// CHECK: double fn0_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: Foo _d_func = Foo({_d_x}); +// CHECK-NEXT: Foo func = Foo({x}); +// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func); +// CHECK-NEXT: return _d_x; +// CHECK-NEXT:} +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x) { +// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x; +// CHECK-NEXT: this->y = 2 * x; +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT:} + int main() { AFunctor doubler; int x = doubler(5); @@ -73,5 +117,8 @@ int main() { auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1); printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9 + auto dfn0 = clad::differentiate(fn0, "x"); + printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2 + return 0; }