diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index e78149d3d..e5eea9cdc 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -366,6 +366,17 @@ namespace clad { /// \returns the DeclRefExpr for the given declaration. clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, const clang::CXXScopeSpec* SS = nullptr); + /// Builds a DeclRefExpr to a given Decl, adding proper nested name + /// qualifiers. \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] NNS The nested name specifier to use + /// \param[in] FoundD Found decl that can later be accessed from the + /// DeclRefExpr with the getFoundDecl() method \param[in] TemplateArgs + /// Template arguments, can be left nullptr \returns the DeclRefExpr for the + /// given declaration. + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, + clang::NamedDecl* FoundD, + const clang::TemplateArgumentListInfo* TemplateArgs = nullptr); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 9635e19f8..67458a065 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1044,8 +1044,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + NestedNameSpecifier* NNS = DRE->getQualifier(); auto referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); + clonedDRE = BuildDeclRef(referencedDecl, NNS, clonedDRE->getFoundDecl()); } } else clonedDRE = cast(Clone(DRE)); @@ -1060,7 +1061,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto dVarDRE = dyn_cast(dExpr)) { auto dVar = cast(dVarDRE->getDecl()); if (dVar->getDeclContext() != m_Sema.CurContext) - dExpr = BuildDeclRef(dVar); + dExpr = + BuildDeclRef(dVar, DRE->getQualifier(), dVarDRE->getFoundDecl()); } return StmtDiff(clonedDRE, dExpr); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c1e5442f2..3e589682d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1504,8 +1504,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. - if (VD->getDeclContext() != m_Sema.CurContext) - clonedDRE = cast(BuildDeclRef(VD)); + if (VD->getDeclContext() != m_Sema.CurContext) { + auto* ccDRE = dyn_cast(clonedDRE); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto referencedDecl = cast(ccDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS, ccDRE->getFoundDecl()); + } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. @@ -1812,9 +1816,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* OverloadedDerivedFn = nullptr; - // If the function has a single arg and does not returns a reference or take + // If the function has a single arg and does not return a reference or take // arg by reference, we look for a derivative w.r.t. to this arg using the - // forward mode(it is unlikely that we need gradient of a one-dimensional' + // forward mode(it is unlikely that we need gradient of a one-dimensional // function). bool asGrad = true; @@ -2109,8 +2113,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff argDiff = Visit(arg); CallArgs.push_back(argDiff.getExpr_dx()); } - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, Loc); } else { @@ -2127,6 +2130,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. + + if (const auto* OCE = dyn_cast(CE)) { + CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); + call = CXXOperatorCallExpr::Create( + m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, + FD->getCallResultType(), VK_LValue, Loc, + CE->getFPFeaturesInEffect(LangOptions())); + return StmtDiff(call); + } + call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, CallArgs, Loc) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index f739abd2f..16318e1ea 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -243,6 +243,43 @@ namespace clad { m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); } + DeclRefExpr* VisitorBase::BuildDeclRef( + DeclaratorDecl* D, NestedNameSpecifier* NNS, NamedDecl* FoundD, + const TemplateArgumentListInfo* TemplateArgs /*=nullptr*/) { + QualType T = D->getType(); + T = T.getNonReferenceType(); + + std::vector NNChain; + CXXScopeSpec CSS; + while (NNS) { + NNChain.push_back(NNS); + NNS = NNS->getPrefix(); + } + + std::reverse(NNChain.begin(), NNChain.end()); + + for (size_t i = 0; i < NNChain.size(); ++i) { + NNS = NNChain[i]; + if (NNS->getKind() == NestedNameSpecifier::Namespace) { + NamespaceDecl* NS = NNS->getAsNamespace(); + CSS.Extend(m_Context, NS, noLoc, noLoc); + } else if (NNS->getKind() == NestedNameSpecifier::TypeSpec) { + const Type* T = NNS->getAsType(); + if (auto* RT = const_cast(T->getAs())) { + RecordDecl* RD = RT->getDecl(); + CSS.Extend(m_Context, RD->getIdentifier(), noLoc, noLoc); + } + } + } + + DeclarationNameInfo NameInfo(D->getDeclName(), D->getBeginLoc()); + auto NNLoc = (CSS.isNotEmpty() && CSS.isValid()) + ? CSS.getWithLocInContext(m_Context) + : NestedNameSpecifierLoc(); + return cast(clad_compat::GetResult( + m_Sema.BuildDeclRefExpr(D, T, VK_LValue, NameInfo, NNLoc, FoundD))); + } + IdentifierInfo* VisitorBase::CreateUniqueIdentifier(llvm::StringRef nameBase) { // For intermediate variables, use numbered names (_t0), for everything diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index b39e6dc4a..f5691e64b 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -40,7 +40,7 @@ double f2(double i, double j) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = operator()(i + j, i); +// CHECK-NEXT: double x = _f(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C new file mode 100644 index 000000000..b789a57a5 --- /dev/null +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -0,0 +1,80 @@ +// XFAIL: asserts +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oValidCodeGen.out 2>&1 | %filecheck %s +// RUN: ./ValidCodeGen.out | %filecheck_exec %s +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oValidCodeGenWithTBR.out +// RUN: ./ValidCodeGenWithTBR.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include "../TestUtils.h" +#include "../PrintOverloads.h" + +namespace TN { + struct Test { + static int multiplier; + }; + int Test::multiplier = 3; + + template + struct Test2 { + T operator[](T x) { + return 4*x; + } + }; +} + +namespace clad { +namespace custom_derivatives { +namespace class_functions { + template + void operator_subscript_pullback(::TN::Test2* obj, T x, T d_u, ::TN::Test2* d_obj, T* d_x) { + (*d_x) += 4*d_u; + } +}}} + +double fn(double x) { + // fwd and rvs mode test + return x*TN::Test::multiplier; // in this test, it's important that this nested name is copied into the generated code properly in both modes +} + +double fn2(double x, double y) { + // rvs mode test + TN::Test2 t; + auto q = t[x]; // in this test, it's important that this operator call is copied into the generated code properly and that the pullback function is called with all the needed namespace prefixes + return q; +} + +int main() { + double dx, dy; + INIT_DIFFERENTIATE(fn, "x"); + INIT_GRADIENT(fn); + INIT_GRADIENT(fn2); + + TEST_GRADIENT(fn, /*numOfDerivativeArgs=*/1, 3, &dx); // CHECK-EXEC: {3.00} + TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, 3, 4, &dx, &dy); // CHECK-EXEC: {4.00, 0.00} + TEST_DIFFERENTIATE(fn, 3) // CHECK-EXEC: {3.00} +} + +//CHECK: double fn_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: return _d_x * TN::Test::multiplier + x * 0; +//CHECK-NEXT: } + +//CHECK: void fn_grad(double x, double *_d_x) { +//CHECK-NEXT: *_d_x += 1 * TN::Test::multiplier; +//CHECK-NEXT: } + +//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: TN::Test2 _d_t({}); +//CHECK-NEXT: TN::Test2 t; +//CHECK-NEXT: TN::Test2 _t0 = t; +//CHECK-NEXT: double _d_q = 0.; +//CHECK-NEXT: double q = t[x]; +//CHECK-NEXT: _d_q += 1; +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t0, x, _d_q, &_d_t, &_r0); +//CHECK-NEXT: *_d_x += _r0; +//CHECK-NEXT: } +//CHECK-NEXT: }