From 111dff306a5325126dd1884251a32d6b1616c8ce Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 16 Oct 2024 10:32:14 +0200 Subject: [PATCH] Proper differentiation of lambdas (no captures) & code cleanup --- include/clad/Differentiator/Compatibility.h | 10 + .../clad/Differentiator/ReverseModeVisitor.h | 6 + lib/Differentiator/ReverseModeVisitor.cpp | 293 +++++++++++++++++- lib/Differentiator/VisitorBase.cpp | 11 + 4 files changed, 305 insertions(+), 15 deletions(-) diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index efd3d629c..1fb98efba 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -403,6 +403,16 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy, #define CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam clang::ParsedAttributesView::none(), #endif +#if CLANG_VERSION_MAJOR > 12 +#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \ + LAMBDACXXRECORDDECL) \ + LAMBDACXXRECORDDECL->getLambdaDependencyKind() +#else +#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \ + LAMBDACXXRECORDDECL) \ + LAMBDACXXRECORDDECL->isDependentLambda() +#endif + // Clang 12 add one extra param (FPO) that we get from Node in Create method of: // ImplicitCastExpr, CStyleCastExpr, CXXStaticCastExpr and CXXFunctionalCastExpr diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a73513f94..385d60ba3 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -357,6 +357,11 @@ namespace clad { llvm::SmallVectorImpl& outputArgs, clang::Expr* CUDAExecConfig = nullptr); + clang::CXXRecordDecl* + diffLambdaCXXRecordDecl(const clang::CXXRecordDecl* Original); + clang::CXXMethodDecl* + DifferentiateCallOperatorIfLambda(const clang::CXXRecordDecl* RD); + public: ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request); virtual ~ReverseModeVisitor(); @@ -383,6 +388,7 @@ namespace clad { DerivativeAndOverload DerivePullback(); StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); + StmtDiff VisitLambdaExpr(const clang::LambdaExpr* LE); StmtDiff VisitCallExpr(const clang::CallExpr* CE); virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 026164498..71b53d636 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -26,6 +26,7 @@ #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" +#include "clang/Sema/ScopeInfo.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaInternal.h" #include "clang/Sema/Template.h" @@ -1667,6 +1668,255 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(FL)); } + CXXMethodDecl* ReverseModeVisitor::DifferentiateCallOperatorIfLambda( + const clang::CXXRecordDecl* RD) { + if (RD) { + CXXRecordDecl* constructedType = RD->getDefinition(); + bool isLambda = constructedType->isLambda(); + if (isLambda) { + for (const auto* method : constructedType->methods()) { + if (const auto* cxxMethod = dyn_cast(method)) { + if (cxxMethod->isOverloadedOperator() && + cxxMethod->getOverloadedOperator() == OO_Call) { + + DiffRequest req; + req.Function = cxxMethod; + req.Mode = DiffMode::experimental_pullback; + req.BaseFunctionName = utils::ComputeEffectiveFnName(cxxMethod); + // Silence diag outputs in nested derivation process. + req.VerboseDiags = false; + + return dyn_cast(m_Builder.Derive(req).derivative); + } + } + } + } + } + return nullptr; + } + + CXXRecordDecl* + ReverseModeVisitor::diffLambdaCXXRecordDecl(const CXXRecordDecl* Original) { + // Create a new Lambda CXXRecordDecl that is going to represent a pullback + CXXRecordDecl* Cloned = CXXRecordDecl::CreateLambda( + m_Context, const_cast(Original->getDeclContext()), + Original->getLambdaTypeInfo(), Original->getBeginLoc(), + CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(Original), + Original->isGenericLambda(), Original->getLambdaCaptureDefault()); + + // Copy the fields if any (FieldDecl) + for (auto* Field : Original->fields()) { + FieldDecl* NewField = FieldDecl::Create( + m_Context, // AST context + Cloned, // Owning class (Cloned CXXRecordDecl) + Field->getBeginLoc(), // Start location of field + Field->getLocation(), // End location of field + Field->getIdentifier(), // Field's name + Field->getType(), // Field's type + Field->getTypeSourceInfo(), // Type source info + Field->getBitWidth(), // Bit width (Expr*), nullptr if not a bitfield + Field->isMutable(), // Is the field mutable? + Field->getInClassInitStyle() // In-class initialization style + ); + + NewField->setAccess( + Field + ->getAccess()); // Set access specifier (public/private/protected) + Cloned->addDecl( + NewField); // Add the new field to the cloned CXXRecordDecl + } + + // Create operator() as a pullback + for (auto* Method : Original->methods()) { + if (CXXMethodDecl* OriginalOpCall = dyn_cast(Method)) { + if (OriginalOpCall->getOverloadedOperator() == OO_Call) { + auto* diffedOpCall = DifferentiateCallOperatorIfLambda(Original); + if (diffedOpCall) { + diffedOpCall->setAccess(OriginalOpCall->getAccess()); + // Cloned->addDecl(diffedOpCall); + + CXXMethodDecl* ClonedOpCall = CXXMethodDecl::Create( + m_Context, Cloned, diffedOpCall->getBeginLoc(), + OriginalOpCall->getNameInfo(), + diffedOpCall + ->getType(), // Function type (return type + parameters) + diffedOpCall->getTypeSourceInfo(), + diffedOpCall->getStorageClass() + CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(diffedOpCall), + diffedOpCall->isInlineSpecified(), // Inline specifier + clad_compat::Function_GetConstexprKind( + diffedOpCall), // Constexpr specifier + diffedOpCall->getEndLoc() //, + // diffedOpCall->getTrailingRequiresClause() + ); + + llvm::SmallVector params; + for (unsigned i = 0; i < diffedOpCall->param_size(); ++i) { + ParmVarDecl* p = diffedOpCall->getParamDecl(i); + ParmVarDecl* NewParam = ParmVarDecl::Create( + m_Context, ClonedOpCall, p->getBeginLoc(), p->getLocation(), + p->getIdentifier(), p->getType(), p->getTypeSourceInfo(), + p->getStorageClass(), p->getDefaultArg()); + params.push_back(NewParam); + } + ClonedOpCall->setParams(params); + + // Copy the method body if it exists + if (diffedOpCall->hasBody()) { + Stmt* body = diffedOpCall->getBody(); + Stmt* ClonedBody = Clone(body); + ClonedOpCall->setBody(ClonedBody); + } + + ClonedOpCall->setAccess(OriginalOpCall->getAccess()); + Cloned->addDecl(ClonedOpCall); + + break; // we might get into an infinite loop otherwise + } + } + } + } + + // Step 4: Finish defining the class + Cloned->completeDefinition(); + + return Cloned; + } + + StmtDiff ReverseModeVisitor::VisitLambdaExpr(const clang::LambdaExpr* LE) { + auto children_iterator_range = LE->children(); + + std::vector children_Exp; + std::vector children_Exp_dx; + + for (auto children : children_iterator_range) { + auto children_expr = const_cast(dyn_cast(children)); + if (children_expr) { + children_Exp.push_back(children_expr); + + children_Exp_dx.push_back(children_expr); + + if(isa(children_expr)) { + std::string constructedTypeName = QualType::getAsString(dyn_cast(children_expr)->getType().split(), PrintingPolicy{ {} }); + // if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) { + auto children_exprV = Visit(children_expr); + auto children_expr_copy = dyn_cast(Clone(children_expr)); + children_expr_copy->setArg(0, children_exprV.getExpr_dx()); + children_Exp_dx.push_back(children_expr_copy); + // } + } + else if(isa(children_expr)) { + + } + else { + auto children_exprV = Visit(children_expr); + if (children_exprV.getExpr_dx()) { + children_Exp_dx.push_back(children_exprV.getExpr_dx()); + } + } + } + } + + llvm::ArrayRef childrenRef_Exp = + clad_compat::makeArrayRef(children_Exp.data(), children_Exp.size()); + + llvm::ArrayRef childrenRef_Exp_dx = + clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size()); + + auto forwardLambdaClass = LE->getLambdaClass(); + + auto forwardLE = LambdaExpr::Create(m_Context, + forwardLambdaClass, + LE->getIntroducerRange(), + LE->getCaptureDefault(), + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + childrenRef_Exp, + LE->getEndLoc(), + false); + + clang::LambdaExpr* reverseLE = nullptr; + auto* ClonedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass); + + clang::LambdaIntroducer Intro; + Intro.Default = forwardLambdaClass->getLambdaCaptureDefault(); + Intro.Range.setBegin(LE->getBeginLoc()); + Intro.Range.setEnd(LE->getEndLoc()); + + clang::AttributeFactory AttrFactory; + const clang::DeclSpec DS(AttrFactory); + clang::Declarator D( + DS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam + CLAD_COMPAT_CLANG12_Declarator_LambdaExpr); + clang::sema::LambdaScopeInfo* LSI = m_Sema.PushLambdaScope(); + beginScope(clang::Scope::BlockScope | clang::Scope::FnScope | + clang::Scope::DeclScope); + m_Sema.ActOnStartOfLambdaDefinition( + Intro, D, + clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec( + getCurrentScope(), DS)); + + for (auto* Method : ClonedCXXRec->methods()) { + if (CXXMethodDecl* cpb = dyn_cast(Method)) { + if (cpb->getOverloadedOperator() == OO_Call) + LSI->CallOperator = cpb; + } + } + + std::vector children_LC_Exp_dx; + + for (auto children_expr : children_Exp_dx) { + if(isa(children_expr)) { + + auto tmp = dyn_cast(children_expr)->getArg(0)->IgnoreImpCasts(); + + if (isa(tmp)) { + auto VD = dyn_cast(dyn_cast(tmp)->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + if(isa(tmp)) { + auto PE = dyn_cast(tmp); + auto OCE = dyn_cast(PE->getSubExpr()); + + auto VD = dyn_cast(dyn_cast(OCE->getArg(0))->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + } + if (isa(children_expr)) { + auto VD = dyn_cast(dyn_cast(children_expr)->getDecl()); + children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); + } + } + assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures"); + + llvm::ArrayRef childrenRef_LC_Exp_dx = + clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size()); + + // Initialize and attach LambdaDefinitionData to mark this as a lambda. + ClonedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx); + + m_Sema.buildLambdaScope(LSI, + //bodyV.getStmt_dx(), + LSI->CallOperator, + LE->getIntroducerRange(), + LE->getCaptureDefault(), + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + LE->isMutable()); + + reverseLE = LambdaExpr::Create( + m_Context, ClonedCXXRec, LE->getIntroducerRange(), + LE->getCaptureDefault(), LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), LE->hasExplicitResultType(), + childrenRef_Exp_dx, LE->getEndLoc(), false); + + endScope(); + + return {forwardLE, reverseLE}; + } + StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) { const FunctionDecl* FD = CE->getDirectCallee(); if (!FD) { @@ -2991,6 +3241,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } + // Lambda function declaractions should be of auto type; + bool isLambda = false; + if (const RecordType* RT = VDType->getAs()) { + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + isLambda = RD->isLambda(); + } + if (isLambda) + VDCloneType = VDDerivedType = m_Context.getAutoDeductType(); + // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -3013,7 +3272,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } VDDerived = BuildGlobalVarDecl( VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VarDecl::InitializationStyle::CInit); + m_Context.getTrivialTypeSourceInfo(VDDerivedType), + VarDecl::InitializationStyle::CInit); } else { // If VD is a reference to a local variable, then the initial value is set // to the derived variable of the corresponding local variable. @@ -3023,8 +3283,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // `VDDerivedType` is the corresponding non-reference type and the initial // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. - if (!VDDerivedInit) + if (!(VDDerivedInit || isLambda)) { VDDerivedInit = getZeroInit(VDType); + } else if (isLambda) { + if (const Expr* init = VD->getInit()) { + initDiff = Visit(init); + VDDerivedInit = initDiff.getExpr_dx(); + } /* else ==> invalid lambda */ + } // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -3104,10 +3370,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDDerivedInit = getZeroInit(VDDerivedType); } } - if (initializeDerivedVar) + if (initializeDerivedVar) { VDDerived = BuildGlobalVarDecl( VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, - nullptr, VD->getInitStyle()); + m_Context.getTrivialTypeSourceInfo(VDDerivedType), + VD->getInitStyle()); + } } if (!m_DiffReq.shouldHaveAdjoint((VD))) @@ -3194,9 +3462,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), VD->isDirectInit()); else - VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(), - initDiff.getExpr(), VD->isDirectInit(), - nullptr, VD->getInitStyle()); + VDClone = BuildGlobalVarDecl( + VDCloneType, VD->getNameAsString(), initDiff.getExpr(), + VD->isDirectInit(), m_Context.getTrivialTypeSourceInfo(VDCloneType), + VD->getInitStyle()); if (isPointerType && derivedVDE) { if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, @@ -3316,8 +3585,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; // If the DeclStmt is not empty, check the first declaration in case it is a - // lambda function. This case it is treated separately for now and we don't - // create a variable for its derivative. + // lambda function. This case it is treated differently. bool isLambda = false; const auto* declsBegin = DS->decls().begin(); if (declsBegin != DS->decls().end() && isa(*declsBegin)) { @@ -3327,12 +3595,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, QT = QT->getPointeeType(); auto* typeDecl = QT->getAsCXXRecordDecl(); - // We should also simply copy the original lambda. The differentiation - // of lambdas is happening in the `VisitCallExpr`. For now, only the - // declarations with lambda expressions without captures are supported. isLambda = typeDecl && typeDecl->isLambda(); - if (isLambda || - (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { + if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) { for (auto* D : DS->decls()) if (auto* VD = dyn_cast(D)) decls.push_back(VD); @@ -3350,7 +3614,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (auto* VD = dyn_cast(D)) { DeclDiff VDDiff; - if (!isLambda) VDDiff = DifferentiateVarDecl(VD); // Here, we move the declaration to the function global scope. diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index e8fce3628..326973d9b 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -120,7 +120,16 @@ namespace clad { m_Context, m_Sema.CurContext, m_DiffReq->getLocation(), m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None); + bool isLambda = false; if (Init) { + if (const RecordType* RT = Init->getType()->getAs()) { + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + isLambda = RD->isLambda(); + } + if (isLambda) { + clang::Sema::CodeSynthesisContext csc; + m_Sema.pushCodeSynthesisContext(csc); + } m_Sema.AddInitializerToDecl(VD, Init, DirectInit); VD->setInitStyle(IS); } else { @@ -129,6 +138,8 @@ namespace clad { m_Sema.FinalizeDeclaration(VD); // Add the identifier to the scope and IdResolver m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false); + if (isLambda) + m_Sema.popCodeSynthesisContext(); return VD; }