From cd177a40ad70f0608e45ddcacac63015ddc118fa Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 17 Jul 2024 12:55:14 +0300 Subject: [PATCH] Add support for non-differentiable attribute in reverse mode fixes #717 --- lib/Differentiator/ReverseModeVisitor.cpp | 73 ++++++++++++++++++----- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ad2a85df2..cc3bc8db7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1432,6 +1432,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + // If the function is non_differentiable, return zero derivative. + if (clad::utils::hasNonDifferentiableAttribute(CE)) { + // Calling the function without computing derivatives + llvm::SmallVector ClonedArgs; + for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) + ClonedArgs.push_back(Clone(CE->getArg(i))); + + SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema); + Expr* Call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, ClonedArgs, validLoc) + .get(); + // Creating a zero derivative + auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, + /*val=*/0); + + // Returning the function call and zero derivative + return StmtDiff(Call, zero); + } + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -2073,6 +2093,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else if (opCode == UnaryOperatorKind::UO_Deref) { diff = Visit(E); Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); + + // If we have a pointer to a member expression, which is + // non-differentiable, we just return a clone of the original expression. + if (auto* ME = dyn_cast(diff.getExpr())) + if (clad::utils::hasNonDifferentiableAttribute(ME->getMemberDecl())) + return {cloneE}; + Expr* diff_dx = diff.getExpr_dx(); bool specialDThisCase = false; Expr* derivedE = nullptr; @@ -2662,9 +2689,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { - Expr* derivedE = BuildDeclRef(VDDerived); - if (isInitializedByNewExpr) - derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); + Expr* derivedE = nullptr; + + if (!clad::utils::hasNonDifferentiableAttribute(VD)) { + derivedE = BuildDeclRef(VDDerived); + if (isInitializedByNewExpr) + derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); + } + if (VD->getInit()) { if (isa(VD->getInit())) initDiff = Visit(VD->getInit()); @@ -2696,6 +2728,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(assignToZero, direction::reverse); } } + VarDecl* VDClone = nullptr; Expr* derivedVDE = nullptr; if (VDDerived) @@ -2822,19 +2855,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (declsBegin != DS->decls().end() && isa(*declsBegin)) { auto* VD = dyn_cast(*declsBegin); QualType QT = VD->getType(); - if (!QT->isPointerType()) { - auto* typeDecl = QT->getAsCXXRecordDecl(); - // We should also simply copy the original lambda. The differentiation - // of lambdas is happening in the `VisitCallExpr`. For now, only the - // declarations with lambda expressions without captures are supported. - isLambda = typeDecl && typeDecl->isLambda(); - if (isLambda) { - for (auto* D : DS->decls()) - if (auto* VD = dyn_cast(D)) - decls.push_back(VD); - Stmt* DSClone = BuildDeclStmt(decls); - return StmtDiff(DSClone, nullptr); - } + if (QT->isPointerType()) + QT = QT->getPointeeType(); + + auto* typeDecl = QT->getAsCXXRecordDecl(); + // We should also simply copy the original lambda. The differentiation + // of lambdas is happening in the `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda || + (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); } } @@ -2846,6 +2881,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { DeclDiff VDDiff; + if (!isLambda) VDDiff = DifferentiateVarDecl(VD); @@ -3021,6 +3057,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "CXXMethodDecl nodes not supported yet!"); MemberExpr* clonedME = utils::BuildMemberExpr( m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + if (clad::utils::hasNonDifferentiableAttribute(ME)) { + auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, + /*val=*/0); + return {clonedME, zero}; + } if (!baseDiff.getExpr_dx()) return {clonedME, nullptr}; MemberExpr* derivedME = utils::BuildMemberExpr(