From b38d8cfb65ecdaf8521d2225ba7b42e21920226d Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Tue, 4 Jun 2024 14:59:02 +0200 Subject: [PATCH] Improve derived variable init for const pointers --- lib/Differentiator/ReverseModeVisitor.cpp | 73 ++++++++++++++--------- test/Gradient/FunctionCalls.C | 34 +++++++++++ 2 files changed, 79 insertions(+), 28 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 162ca0369..195ccbf3e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2432,6 +2432,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDecl* VDDerived = nullptr; bool isPointerType = VD->getType()->isPointerType(); bool isInitializedByNewExpr = false; + bool initializeDerivedVar = true; // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -2506,22 +2507,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // if VD is a pointer type, then the initial value is set to the derived // expression of the corresponding pointer type. - else if (isPointerType && VD->getInit()) { - VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); - // If it's a pointer to a constant type, then remove the constness. - if (VD->getType()->getPointeeType().isConstQualified()) { - // first extract the pointee type - auto pointeeType = VD->getType()->getPointeeType(); - // then remove the constness - pointeeType.removeLocalConst(); - // then create a new pointer type with the new pointee type - VDDerivedType = m_Context.getPointerType(pointeeType); + else if (isPointerType) { + if (!isInitializedByNewExpr) + initDiff = Visit(VD->getInit()); + + // If the pointer is const and derived expression is not available, then + // we should not create a derived variable for it. This will be useful + // for reducing number of differentiation variables in pullbacks. + bool constPointer = VD->getType()->getPointeeType().isConstQualified(); + if (constPointer && !isInitializedByNewExpr && !initDiff.getExpr_dx()) + initializeDerivedVar = false; + else { + VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); + // If it's a pointer to a constant type, then remove the constness. + if (constPointer) { + // first extract the pointee type + auto pointeeType = VD->getType()->getPointeeType(); + // then remove the constness + pointeeType.removeLocalConst(); + // then create a new pointer type with the new pointee type + VDDerivedType = m_Context.getPointerType(pointeeType); + } + VDDerivedInit = getZeroInit(VDDerivedType); } - VDDerivedInit = getZeroInit(VDDerivedType); } - VDDerived = - BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), - VDDerivedInit, false, nullptr, VD->getInitStyle()); + if (initializeDerivedVar) + VDDerived = BuildGlobalVarDecl( + VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, + nullptr, VD->getInitStyle()); } // If `VD` is a reference to a local variable, then it is already @@ -2562,11 +2575,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assignToZero = GetCladZeroInit(declRef); addToCurrentBlock(assignToZero, direction::reverse); } - } else if (isPointerType && VD->getInit()) { - initDiff = Visit(VD->getInit()); } VarDecl* VDClone = nullptr; - Expr* derivedVDE = BuildDeclRef(VDDerived); + Expr* derivedVDE = nullptr; + if (VDDerived) + derivedVDE = BuildDeclRef(VDDerived); // FIXME: Add extra parantheses if derived variable pointer is pointing to a // class type object. @@ -2601,7 +2614,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit(), nullptr, VD->getInitStyle()); - if (isPointerType) { + if (isPointerType && derivedVDE) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); addToCurrentBlock(assignDerivativeE, direction::forward); @@ -2615,7 +2628,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE); } } - m_Variables.emplace(VDClone, derivedVDE); + if (derivedVDE) + m_Variables.emplace(VDClone, derivedVDE); return DeclDiff(VDClone, VDDerived); } @@ -2642,11 +2656,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (auto* FD = dyn_cast(callExpr->getCalleeDecl())) if (utils::IsMemoryFunction(FD)) dxInForward = true; - - if (dxInForward) - addToCurrentBlock(stmtDx, direction::forward); - else - addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); + if (stmtDx) { + if (dxInForward) + addToCurrentBlock(stmtDx, direction::forward); + else + addToCurrentBlock(stmtDx, direction::reverse); + } CompoundStmt* RCS = endBlock(direction::reverse); std::reverse(RCS->body_begin(), RCS->body_end()); Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); @@ -2747,10 +2762,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } decls.push_back(VDDiff.getDecl()); - if (isa(VD->getType())) - localDeclsDiff.push_back(VDDiff.getDecl_dx()); - else - declsDiff.push_back(VDDiff.getDecl_dx()); + if (VDDiff.getDecl_dx()) { + if (isa(VD->getType())) + localDeclsDiff.push_back(VDDiff.getDecl_dx()); + else + declsDiff.push_back(VDDiff.getDecl_dx()); + } } else if (auto* SAD = dyn_cast(D)) { DeclDiff SADDiff = DifferentiateStaticAssertDecl(SAD); if (SADDiff.getDecl()) diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index d1ae1c1e0..b76c46dc3 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -670,6 +670,25 @@ double fn19(double x) { // CHECK-NEXT: } // CHECK-NEXT: } +double weighted_sum(double* x, const double* w) { + return w[0] * x[0] + w[1] * x[1]; +} + +// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x); + +double fn20(double* x, const double* w) { + const double* auxW = w + 1; + return weighted_sum(x, auxW); +} + +// CHECK: void fn20_grad_0(double *x, const double *w, double *_d_x) { +// CHECK-NEXT: const double *auxW = w + 1; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: weighted_sum_pullback(x, auxW, 1, _d_x); +// CHECK-NEXT: } + + template void reset(T* arr, int n) { for (int i=0; i