From ad6809bfdd9835bf496c2048ac659116373515c0 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 29 Oct 2024 22:40:30 +0100 Subject: [PATCH] tmp + dbg --- include/clad/Differentiator/VisitorBase.h | 6 +- lib/Differentiator/ReverseModeVisitor.cpp | 72 +++++++++++++++++------ lib/Differentiator/VisitorBase.cpp | 14 ++--- 3 files changed, 64 insertions(+), 28 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 210f82112..7454a5e06 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -293,7 +293,7 @@ namespace clad { clang::Scope* scope, clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body. /// \param[in] Type The type of variable declaration to build. @@ -310,7 +310,7 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body. /// \param[in] Type The type of variable declaration to build. @@ -326,7 +326,7 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, bool pushCodeSynthCtxt = false); /// Builds variable declaration to be used inside the derivative /// body in the derivative function global scope. clang::VarDecl* diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 71b53d636..974a9ec6b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2133,6 +2133,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); const auto* MD = dyn_cast(FD); + bool isLambda = (MD ? isLambdaCallOperator(MD) : false); // Method operators have a base like methods do but it's included in the // call arguments so we have to shift the indexing of call arguments. bool isMethodOperatorCall = MD && isa(CE); @@ -2140,6 +2141,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (std::size_t i = static_cast(isMethodOperatorCall), e = CE->getNumArgs(); i != e; ++i) { + llvm::errs() << "i: " << i << '\n'; const Expr* arg = CE->getArg(i); const auto* PVD = FD->getParamDecl( i - static_cast(isMethodOperatorCall)); @@ -2159,7 +2161,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // same as the call expression as it is the type used to declare the // _gradX array QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema); - VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); + VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, true); PreCallStmts.push_back(BuildDeclStmt(dArgDecl)); CallArgDx.push_back(BuildDeclRef(dArgDecl)); // Visit using uninitialized reference. @@ -2261,13 +2263,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. if (MD) { - if (isLambdaCallOperator(MD)) { - QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( - FD->getDeclContext()->getOuterLexicalRecordContext())); - baseDiff = - StmtDiff(Clone(dyn_cast(CE)->getArg(0)), - new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc)); - } else if (MD->isInstance()) { + if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); @@ -2276,18 +2272,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseDiff = Visit(baseOriginalE); baseExpr = baseDiff.getExpr(); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - baseDiff.updateStmt(baseDiffStore); + if (!isLambda) { + Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + baseDiff.updateStmt(baseDiffStore); + } + + llvm::errs() << "diff base: "; + baseExpr->dumpPretty(m_Context); + llvm::errs() << " "; + baseDiff.getExpr_dx()->dumpPretty(m_Context); + llvm::errs() << " "; + llvm::errs() << "\n"; + Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) baseDerivative = BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative); - DerivedCallOutputArgs.push_back(baseDerivative); + if (!isLambda) + DerivedCallOutputArgs.push_back(baseDerivative); } } for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; + llvm::errs() << "i: " << idx << '\n'; QualType paramTy = FD->getParamDecl(idx)->getType(); if (!argDerivative || utils::isArrayOrPointerType(paramTy) || isCladArrayType(argDerivative->getType())) @@ -2358,12 +2366,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) .get(); - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, Loc, - pullbackCallArgs, Loc, CUDAExecConfig) - .get(); - } else { + OverloadedDerivedFn = m_Sema + .ActOnCallExpr(getCurrentScope(), selfRef, + Loc, pullbackCallArgs, Loc) + .get(); + } else if (!isLambda) { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( pullbackCallArgs, PreCallStmts, dfdx()); @@ -2446,7 +2453,36 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - if (OverloadedDerivedFn) { + if (isLambda) { + Stmts& block = getCurrentBlock(direction::reverse); + Stmts::iterator it = std::begin(block) + insertionPoint; + // Insert PreCallStmts + it = block.insert(it, PreCallStmts.begin(), PreCallStmts.end()); + it += PreCallStmts.size(); + // Insert the call + Expr* baseEdx = baseDiff.getExpr_dx(); // The pullback lambda + const CXXRecordDecl* EdxRD = baseEdx->getType()->getAsCXXRecordDecl(); + auto* CMD = const_cast(EdxRD->getLambdaCallOperator()); + NestedNameSpecifierLoc NNS(CMD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(CMD, CMD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(baseEdx), /*isArrow=*/false, Loc, NNS, noLoc, + CMD, DAP, CMD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + OverloadedDerivedFn = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + pullbackCallArgs, Loc) + .get(); + + // OverloadedDerivedFn = BuildCallExprToMemFn( + // baseEdx, FD->getName(), pullbackCallArgs, Loc); + it = block.insert(it, OverloadedDerivedFn); + it++; + } else if (OverloadedDerivedFn) { // Derivative was found. FunctionDecl* fnDecl = dyn_cast(OverloadedDerivedFn) ->getDirectCallee(); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 326973d9b..5b2fe0577 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -106,14 +106,14 @@ namespace clad { VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit, - TSI, IS); + TSI, IS, pushCodeSynthCtxt); } VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, Scope* Scope, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); auto* VD = VarDecl::Create( @@ -126,7 +126,7 @@ namespace clad { if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) isLambda = RD->isLambda(); } - if (isLambda) { + if (isLambda || pushCodeSynthCtxt) { clang::Sema::CodeSynthesisContext csc; m_Sema.pushCodeSynthesisContext(csc); } @@ -138,7 +138,7 @@ namespace clad { m_Sema.FinalizeDeclaration(VD); // Add the identifier to the scope and IdResolver m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false); - if (isLambda) + if (Init && (isLambda || pushCodeSynthCtxt)) m_Sema.popCodeSynthesisContext(); return VD; } @@ -152,9 +152,9 @@ namespace clad { VarDecl* VisitorBase::BuildVarDecl(QualType Type, llvm::StringRef prefix, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, bool pushCodeSynthCtxt) { return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), Init, DirectInit, - TSI, IS); + TSI, IS, pushCodeSynthCtxt); } VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,