From b0dd5060c2e15dd383f4b4d38d968a89593e6ff4 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 00:35:10 +0200 Subject: [PATCH 01/10] Simplify pullback calls by replacing _grad/_r pairs with single _r variables and not storing pointer args --- include/clad/Differentiator/ErrorEstimator.h | 15 +- .../clad/Differentiator/ExternalRMVSource.h | 4 +- .../MultiplexExternalRMVSource.h | 6 +- include/clad/Differentiator/VisitorBase.h | 10 +- lib/Differentiator/ErrorEstimator.cpp | 10 +- .../MultiplexExternalRMVSource.cpp | 8 +- lib/Differentiator/ReverseModeVisitor.cpp | 254 ++++-------------- lib/Differentiator/VisitorBase.cpp | 32 ++- test/Arrays/ArrayInputsReverseMode.C | 84 ++---- test/CUDA/GradientCuda.cu | 19 +- test/ErrorEstimation/BasicOps.C | 39 ++- test/FirstDerivative/BuiltinDerivatives.C | 56 ++-- test/Gradient/Assignments.C | 11 +- test/Gradient/FunctionCalls.C | 130 +++------ test/Gradient/Functors.C | 35 +-- test/Gradient/Gradients.C | 28 +- test/Gradient/Loops.C | 32 ++- test/Gradient/MemberFunctions.C | 18 +- test/Gradient/UserDefinedTypes.C | 31 +-- test/Hessian/BuiltinDerivatives.C | 136 ++++------ test/Hessian/NestedFunctionCalls.C | 28 +- test/Jacobian/FunctionCalls.C | 16 +- test/Jacobian/Jacobian.C | 33 ++- test/Misc/RunDemos.C | 11 +- test/NestedCalls/NestedCalls.C | 21 +- test/NumericalDiff/GradientMultiArg.C | 10 +- test/NumericalDiff/NoNumDiff.C | 3 +- test/NumericalDiff/NumDiff.C | 13 +- test/NumericalDiff/PrintErrorNumDiff.C | 3 +- test/ROOT/TFormula.C | 6 +- 30 files changed, 392 insertions(+), 710 deletions(-) diff --git a/include/clad/Differentiator/ErrorEstimator.h b/include/clad/Differentiator/ErrorEstimator.h index 9dc2a5320..79a51c4ba 100644 --- a/include/clad/Differentiator/ErrorEstimator.h +++ b/include/clad/Differentiator/ErrorEstimator.h @@ -97,10 +97,11 @@ class ErrorEstimationHandler : public ExternalRMVSource { /// \param[in] CallArgs The orignal call arguments of the function call. /// \param[in] ArgResultDecls The differentiated call arguments. /// \param[in] numArgs The number of call args. - void EmitNestedFunctionParamError( - clang::FunctionDecl* fnDecl, - llvm::SmallVectorImpl& CallArgs, - llvm::SmallVectorImpl& ArgResultDecls, size_t numArgs); + void + EmitNestedFunctionParamError(clang::FunctionDecl* fnDecl, + llvm::SmallVectorImpl& CallArgs, + llvm::SmallVectorImpl& ArgResult, + size_t numArgs); /// Checks if a variable should be considered in error estimation. /// @@ -181,16 +182,14 @@ class ErrorEstimationHandler : public ExternalRMVSource { void ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& fnDecl, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, - bool asGrad) override; + llvm::SmallVectorImpl& ArgResult, bool asGrad) override; void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&, clang::BinaryOperator::Opcode&) override; void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override; void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override; void ActBeforeDifferentiatingCallExpr( llvm::SmallVectorImpl& pullbackArgs, - llvm::SmallVectorImpl& ArgDecls, - bool hasAssignee) override; + llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) override; void ActBeforeFinalizingVisitDeclStmt( llvm::SmallVectorImpl& decls, llvm::SmallVectorImpl& declsDiff) override; diff --git a/include/clad/Differentiator/ExternalRMVSource.h b/include/clad/Differentiator/ExternalRMVSource.h index 4879cc18a..396c6e249 100644 --- a/include/clad/Differentiator/ExternalRMVSource.h +++ b/include/clad/Differentiator/ExternalRMVSource.h @@ -127,7 +127,7 @@ class ExternalRMVSource { virtual void ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, bool asGrad) {} + llvm::SmallVectorImpl& ArgResult, bool asGrad) {} /// This is called just before finalising processing of post and pre /// increment and decrement operations. @@ -157,7 +157,7 @@ class ExternalRMVSource { virtual void ActBeforeDifferentiatingCallExpr( llvm::SmallVectorImpl& pullbackArgs, - llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) {} + llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) {} virtual void ActBeforeFinalizingVisitDeclStmt( llvm::SmallVectorImpl& decls, diff --git a/include/clad/Differentiator/MultiplexExternalRMVSource.h b/include/clad/Differentiator/MultiplexExternalRMVSource.h index 2e1f35f8b..0c0c2a07e 100644 --- a/include/clad/Differentiator/MultiplexExternalRMVSource.h +++ b/include/clad/Differentiator/MultiplexExternalRMVSource.h @@ -48,8 +48,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource { void ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, - bool asGrad) override; + llvm::SmallVectorImpl& ArgResult, bool asGrad) override; void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override; void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&, clang::BinaryOperatorKind& opCode) override; @@ -60,8 +59,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource { void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override; void ActBeforeDifferentiatingCallExpr( llvm::SmallVectorImpl& pullbackArgs, - llvm::SmallVectorImpl& ArgDecls, - bool hasAssignee) override; + llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) override; void ActBeforeFinalizingVisitDeclStmt( llvm::SmallVectorImpl& decls, llvm::SmallVectorImpl& declsDiff) override; diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index eabcac02f..a588b2b97 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -590,16 +590,20 @@ namespace clad { /// /// \param[in] targetFuncCall The function to get the derivative for. /// \param[in] retType The return type of the target call expression. + /// \param[in] dfdx The dfdx corresponding to this call expression. /// \param[in] numArgs The total number of 'args'. - /// \param[in] NumericalDiffMultiArg The built statements to add to block - /// later. + /// \param[in] PreCallStmts The built statements to add to block + /// before the call to the derived function. + /// \param[in] PostCallStmts The built statements to add to block + /// after the call to the derived function. /// \param[in] args All the arguments to the target function. /// \param[in] outputArgs The output gradient arguments. /// /// \returns The derivative function call. clang::Expr* GetMultiArgCentralDiffCall( clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, - llvm::SmallVectorImpl& NumericalDiffMultiArg, + clang::Expr* dfdx, llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, llvm::SmallVectorImpl& args, llvm::SmallVectorImpl& outputArgs); /// Emits diagnostic messages on differentiation (or lack thereof) for diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index 358e36cfb..214a3ec6b 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -98,7 +98,7 @@ void ErrorEstimationHandler::SaveReturnExpr(Expr* retExpr) { void ErrorEstimationHandler::EmitNestedFunctionParamError( FunctionDecl* fnDecl, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, size_t numArgs) { + llvm::SmallVectorImpl& ArgResult, size_t numArgs) { assert(fnDecl && "Must have a value"); for (size_t i = 0; i < numArgs; i++) { if (!fnDecl->getParamDecl(0)->getType()->isLValueReferenceType()) @@ -109,7 +109,7 @@ void ErrorEstimationHandler::EmitNestedFunctionParamError( // if (utils::IsReferenceOrPointerType(fnDecl->getParamDecl(i)->getType())) // continue; Expr* errorExpr = m_EstModel->AssignError( - {derivedCallArgs[i], m_RMV->BuildDeclRef(ArgResultDecls[i])}, + {derivedCallArgs[i], m_RMV->Clone(ArgResult[i])}, fnDecl->getNameInfo().getAsString() + "_param_" + std::to_string(i)); Expr* errorStmt = m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr); m_ReverseErrorStmts.push_back(errorStmt); @@ -372,7 +372,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) { void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, bool asGrad) { + llvm::SmallVectorImpl& ArgResult, bool asGrad) { if (OverloadedDerivedFn && asGrad) { // Derivative was found. FunctionDecl* fnDecl = @@ -382,7 +382,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr( // in the input prameters (if of reference type) to call and save to // emit them later. - EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResultDecls, + EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResult, CE->getNumArgs()); } } @@ -416,7 +416,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleExpr( void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr( llvm::SmallVectorImpl& pullbackArgs, - llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) { + llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) { auto errorRef = m_RMV->BuildVarDecl(m_RMV->m_Context.DoubleTy, "_t", m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy)); diff --git a/lib/Differentiator/MultiplexExternalRMVSource.cpp b/lib/Differentiator/MultiplexExternalRMVSource.cpp index 837990656..f81f41d1b 100644 --- a/lib/Differentiator/MultiplexExternalRMVSource.cpp +++ b/lib/Differentiator/MultiplexExternalRMVSource.cpp @@ -149,10 +149,10 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingVisitReturnStmt( void MultiplexExternalRMVSource::ActBeforeFinalizingVisitCallExpr( const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn, llvm::SmallVectorImpl& derivedCallArgs, - llvm::SmallVectorImpl& ArgResultDecls, bool asGrad) { + llvm::SmallVectorImpl& ArgResult, bool asGrad) { for (auto source : m_Sources) { - source->ActBeforeFinalizingVisitCallExpr(CE, OverloadedDerivedFn, derivedCallArgs, - ArgResultDecls, asGrad); + source->ActBeforeFinalizingVisitCallExpr( + CE, OverloadedDerivedFn, derivedCallArgs, ArgResult, asGrad); } } @@ -199,7 +199,7 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingDifferentiateSingleExpr( void MultiplexExternalRMVSource::ActBeforeDifferentiatingCallExpr( llvm::SmallVectorImpl& pullbackArgs, - llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) { + llvm::SmallVectorImpl& ArgDecls, bool hasAssignee) { for (auto source : m_Sources) source->ActBeforeDifferentiatingCallExpr(pullbackArgs, ArgDecls, hasAssignee); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 2bcd1a7f7..acae93e56 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1362,6 +1362,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Add it to the body statements. addToCurrentBlock(add_assign, direction::reverse); } + return StmtDiff(clonedDRE, it->second, it->second); } else { // Check DeclRefExpr is a reference to an independent variable. auto it = m_Variables.find(VD); @@ -1440,7 +1441,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector DerivedCallArgs{}; // Stores tape decl and pushes for multiarg numerically differentiated // calls. - llvm::SmallVector NumericalDiffMultiArg{}; + llvm::SmallVector PostCallStmts{}; // For calls to C-style memory allocation functions, we do not need to // differentiate the call. We just need to visit the arguments to the @@ -1516,8 +1517,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return call; } - llvm::SmallVector ArgResultDecls{}; - llvm::SmallVector ArgDeclStmts{}; + llvm::SmallVector PreCallStmts{}; // Save current index in the current block, to potentially put some // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); @@ -1534,101 +1534,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const auto* PVD = FD->getParamDecl(i - static_cast(isCXXOperatorCall)); StmtDiff argDiff{}; - bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); - if (passByRef && isa(arg)) { - // If the argument is a temporary variable, this means that param type - // is a reference to a const type and we are passing a temporary - // variable to it. In this case, we should not pass the derivative - // argument by reference. - passByRef = false; - } - QualType argDiffType; // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (passByRef) { + if (utils::IsReferenceOrPointerType(PVD->getType()) && + !isa(arg)) { argDiff = Visit(arg); - Expr* dArg = nullptr; - argDiffType = argDiff.getExpr()->getType(); - QualType argResultValueType = - utils::GetValueType(argDiffType).getNonReferenceType(); - // Create ArgResult variable for each reference argument because it is - // required by error estimator. For automatic differentiation, we do not need - // to create ArgResult variable for arguments passed by reference. - // ``` - // _r0 = _d_a; - // ``` - if (argDiff.getExpr_dx() && utils::IsLiteral(argDiff.getExpr_dx())) { - dArg = StoreAndRef(argDiff.getExpr_dx(), arg->getType(), - direction::reverse, "_r", - /*forceDeclCreation=*/true); - } else if (argDiffType->isArrayType()) { - Expr* init = argDiff.getExpr_dx(); - if (isa(argDiff.getExpr_dx()->getType())) - init = utils::BuildCladArrayInitByConstArray(m_Sema, - argDiff.getExpr_dx()); - - dArg = StoreAndRef(init, GetCladArrayOfType(argResultValueType), - direction::reverse, "_r", - /*forceDeclCreation=*/true, - VarDecl::InitializationStyle::CallInit); - } else if (argDiffType->isPointerType()) { - dArg = StoreAndRef(argDiff.getExpr_dx(), argDiffType, - direction::reverse, "_r", - /*forceDeclCreation=*/true); - } else { - dArg = StoreAndRef(argDiff.getExpr_dx(), argResultValueType, - direction::reverse, "_r", - /*forceDeclCreation=*/true); - } - ArgResultDecls.push_back( - cast(cast(dArg)->getDecl())); + CallArgDx.push_back(argDiff.getExpr_dx()); } else { - assert(!utils::isArrayOrPointerType(arg->getType()) && - "Arguments passed by pointers should be covered in pass by " - "reference calls"); // Create temporary variables corresponding to derivative of each // argument, so that they can be referred to when arguments is visited. // Variables will be initialized later after arguments is visited. This // is done to reduce cloning complexity and only clone once. The type is // same as the call expression as it is the type used to declare the // _gradX array - Expr* dArg = nullptr; - QualType argType = utils::GetValueType(arg->getType()); - dArg = StoreAndRef(/*E=*/nullptr, argType, direction::reverse, "_r", - /*forceDeclCreation=*/true); - ArgResultDecls.push_back( - cast(cast(dArg)->getDecl())); + QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema); + VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); + PreCallStmts.push_back(BuildDeclStmt(dArgDecl)); + CallArgDx.push_back(BuildDeclRef(dArgDecl)); // Visit using uninitialized reference. - argDiff = Visit(arg, dArg); - argDiffType = argDiff.getExpr()->getType(); + argDiff = Visit(arg, BuildDeclRef(dArgDecl)); } - // FIXME: We may use same argDiff.getExpr_dx at two places. This can - // lead to inconsistent pushes and pops. If `isInsideLoop` is true and - // actual argument is something like "a[i]", then argDiff.getExpr() and - // argDiff.getExpr_dx() will respectively be: - // ``` - // a[clad::push(_t0, i)]; - // a[clad::pop(_t0)]; - // ``` - // The expression `a[clad::pop(_t0)]` might already be used in the AST if - // visit was called with a dfdx() present. - // And thus using this expression in the AST explicitly may lead to size - // assertion failed. - // - // We should modify the design so that the behaviour of obtained StmtDiff - // expression is consistent both inside and outside loops. - CallArgDx.push_back(argDiff.getExpr_dx()); // Save cloned arg in a "global" variable, so that it is accessible from // the reverse pass. // FIXME: At this point, we assume all the variables passed by reference // may be changed since we have no way to determine otherwise. // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now // arrays are not stored. + bool passByRef = PVD->getType()->isReferenceType() && + !isa(arg); StmtDiff argDiffStore; - if (passByRef && !argDiffType->isArrayType() && - !argDiff.getExpr()->isEvaluatable(m_Context)) + if (passByRef && !argDiff.getExpr()->isEvaluatable(m_Context)) argDiffStore = GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); else @@ -1659,7 +1596,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // ``` // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now // arrays are not stored. - if (passByRef && !argDiffType->isArrayType()) { + if (passByRef) { if (isInsideLoop) { // Add tape push expression. We need to explicitly add it here because // we cannot add it as call expression argument -- we need to pass the @@ -1715,10 +1652,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallArgs.push_back(argDiffStore.getExpr_dx()); } - VarDecl* gradVarDecl = nullptr; - Expr* gradVarExpr = nullptr; - Expr* gradArgExpr = nullptr; - IdentifierInfo* gradVarII = nullptr; Expr* OverloadedDerivedFn = nullptr; // If the function has a single arg and does not returns a reference or take // arg by reference, we look for a derivative w.r.t. to this arg using the @@ -1745,10 +1678,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // requires (slightly) different arguments. llvm::SmallVector pullbackCallArgs{}; - // Stores a list of arg result variable declaration (_r0) with the - // corresponding grad variable expression (_grad0). - llvm::SmallVector, 4> argResultsAndGrads; - // Stores differentiation result of implicit `this` object, if any. StmtDiff baseDiff; // If it has more args or f_darg0 was not found, we look for its pullback @@ -1792,66 +1721,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } for (auto* argDerivative : CallArgDx) { - gradVarDecl = nullptr; - gradVarExpr = nullptr; - gradArgExpr = nullptr; - gradVarII = CreateUniqueIdentifier(funcPostfix()); - + Expr* gradArgExpr = nullptr; + const Expr* arg = CE->getArg(idx); const auto* PVD = FD->getParamDecl(idx); - bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); - if (passByRef && isa(CE->getArg(idx))) { - // If the argument is a temporary variable, this means that param type - // is a reference to a const type and we are passing a temporary - // variable to it. In this case, we should not pass the derivative - // argument by reference. - passByRef = false; - } - if (passByRef) { - // If derivative type is constant array type instead of - // `clad::array_ref` or `clad::array` type, then create an - // `clad::array_ref` variable that references this constant array. It - // is required because the pullback function expects `clad::array_ref` - // type for representing array derivatives. Currently, only constant - // array data members have derivatives of constant array types. - if ((argDerivative != nullptr) && - isa(argDerivative->getType())) { - Expr* init = - utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative); - auto* derivativeArrayRefVD = BuildVarDecl( - GetCladArrayRefOfType(argDerivative->getType() - ->getPointeeOrArrayElementType() - ->getCanonicalTypeInternal()), - "_t", init); - ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD)); - argDerivative = BuildDeclRef(derivativeArrayRefVD); + if (utils::IsReferenceOrPointerType(PVD->getType()) && + !isa(arg)) { + if (argDerivative) { + if (utils::isArrayOrPointerType(argDerivative->getType()) || + isCladArrayType(argDerivative->getType()) || + !argDerivative->isLValue()) + gradArgExpr = argDerivative; + else + gradArgExpr = + BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); } - if ((argDerivative != nullptr) && - (isCladArrayType(argDerivative->getType()) || - argDerivative->getType()->isPointerType() || - !argDerivative->isLValue())) - gradArgExpr = argDerivative; - else - gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); } else { - // Declare: diffArgType _grad; - Expr* initVal = nullptr; - QualType gradVarType = utils::GetValueType(PVD->getType()); - if (!gradVarType->isRecordType()) { - // If the argument is not a class type, then initialize the grad - // variable with 0. - initVal = - ConstantFolder::synthesizeLiteral(gradVarType, m_Context, 0); - } - gradVarDecl = BuildVarDecl(gradVarType, gradVarII, initVal); - // Pass the address of the declared variable - gradVarExpr = BuildDeclRef(gradVarDecl); + Expr* gradVarExpr = CallArgDx[idx]; gradArgExpr = BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation()); - argResultsAndGrads.push_back({ArgResultDecls[idx], gradVarExpr}); } DerivedCallOutputArgs.push_back(gradArgExpr); - if (gradVarDecl) - ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); idx++; } Expr* pullback = dfdx(); @@ -1897,9 +1786,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackCallArgs.erase(pullbackCallArgs.begin()); } - // should be true if we are using numerical differentiation to differentiate - // the callee function. - bool usingNumericalDiff = false; // Derivative was not found, check if it is a recursive call if (!OverloadedDerivedFn) { if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { @@ -1919,7 +1805,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( - pullbackCallArgs, ArgDeclStmts, dfdx()); + pullbackCallArgs, PreCallStmts, dfdx()); // Overloaded derivative was not found, request the CladPlugin to // derive the called function. DiffRequest pullbackRequest{}; @@ -1938,39 +1824,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && !isa(FD)) { // Try numerically deriving it. - // Build a clone call expression so that we can correctly - // scope the function to be differentiated. - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - llvm::MutableArrayRef(CallArgs), - noLoc) - .get(); - Expr* fnCallee = cast(call)->getCallee(); if (NArgs == 1) { - OverloadedDerivedFn = GetSingleArgCentralDiffCall(fnCallee, - DerivedCallArgs - [0], - /*targetPos=*/0, - /*numArgs=*/1, - DerivedCallArgs); + OverloadedDerivedFn = GetSingleArgCentralDiffCall( + Clone(CE->getCallee()), DerivedCallArgs[0], + /*targetPos=*/0, + /*numArgs=*/1, DerivedCallArgs); asGrad = !OverloadedDerivedFn; } else { auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema); OverloadedDerivedFn = GetMultiArgCentralDiffCall( - fnCallee, CEType.getCanonicalType(), CE->getNumArgs(), - NumericalDiffMultiArg, DerivedCallArgs, DerivedCallOutputArgs); + Clone(CE->getCallee()), CEType.getCanonicalType(), + CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts, + DerivedCallArgs, CallArgDx); } CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), OverloadedDerivedFn); if (!OverloadedDerivedFn) { auto& block = getCurrentBlock(direction::reverse); - block.insert(block.begin(), ArgDeclStmts.begin(), - ArgDeclStmts.end()); + block.insert(block.begin(), PreCallStmts.begin(), + PreCallStmts.end()); return StmtDiff(Clone(CE)); } - usingNumericalDiff = true; } else if (pullbackFD) { if (baseDiff.getExpr()) { Expr* baseE = baseDiff.getExpr(); @@ -1991,55 +1865,33 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Derivative was found. FunctionDecl* fnDecl = dyn_cast(OverloadedDerivedFn) ->getDirectCallee(); + // Put Result array declaration in the function body. + // Call the gradient, passing Result as the last Arg. + Stmts& block = getCurrentBlock(direction::reverse); + Stmt** it = std::begin(block) + insertionPoint; + // Insert PreCallStmts + it = block.insert(it, PreCallStmts.begin(), PreCallStmts.end()); + it += PreCallStmts.size(); if (!asGrad) { if (utils::IsCladValueAndPushforwardType(fnDecl->getReturnType())) OverloadedDerivedFn = utils::BuildMemberExpr( m_Sema, getCurrentScope(), OverloadedDerivedFn, "pushforward"); // If the derivative is called through _darg0 instead of _grad. Expr* d = BuildOp(BO_Mul, dfdx(), OverloadedDerivedFn); - - PerformImplicitConversionAndAssign(ArgResultDecls[0], d); + Expr* addGrad = BuildOp(BO_AddAssign, Clone(CallArgDx[0]), d); + it = block.insert(it, addGrad); + it++; } else { - // Put Result array declaration in the function body. - // Call the gradient, passing Result as the last Arg. - auto& block = getCurrentBlock(direction::reverse); - auto* it = std::begin(block) + insertionPoint; - - // Insert the _gradX declaration statements - it = block.insert(it, ArgDeclStmts.begin(), ArgDeclStmts.end()); - it += ArgDeclStmts.size(); - it = block.insert(it, NumericalDiffMultiArg.begin(), - NumericalDiffMultiArg.end()); - it += NumericalDiffMultiArg.size(); // Insert the CallExpr to the derived function - block.insert(it, OverloadedDerivedFn); - - if (usingNumericalDiff) { - for (auto resAndGrad : argResultsAndGrads) { - VarDecl* argRes = resAndGrad.first; - Expr* grad = resAndGrad.second; - if (isCladArrayType(grad->getType())) { - Expr* E = BuildOp(BO_MulAssign, grad, dfdx()); - // Visit each arg with df/dargi = df/dxi * Result. - PerformImplicitConversionAndAssign(argRes, E); - } else { - // Visit each arg with df/dargi = df/dxi * Result. - PerformImplicitConversionAndAssign(argRes, - BuildOp(BO_Mul, dfdx(), grad)); - } - } - } else { - for (auto resAndGrad : argResultsAndGrads) { - VarDecl* argRes = resAndGrad.first; - Expr* grad = resAndGrad.second; - PerformImplicitConversionAndAssign(argRes, grad); - } - } + it = block.insert(it, OverloadedDerivedFn); + it++; } + // Insert PostCallStmts + it = block.insert(it, PostCallStmts.begin(), PostCallStmts.end()); } if (m_ExternalSource) m_ExternalSource->ActBeforeFinalizingVisitCallExpr( - CE, OverloadedDerivedFn, DerivedCallArgs, ArgResultDecls, asGrad); + CE, OverloadedDerivedFn, DerivedCallArgs, CallArgDx, asGrad); Expr* call = nullptr; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32a5e99a2 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -772,8 +772,9 @@ namespace clad { } Expr* VisitorBase::GetMultiArgCentralDiffCall( - Expr* targetFuncCall, QualType retType, unsigned numArgs, - llvm::SmallVectorImpl& NumericalDiffMultiArg, + Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx, + llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, llvm::SmallVectorImpl& args, llvm::SmallVectorImpl& outputArgs) { int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); @@ -785,7 +786,7 @@ namespace clad { auto VD = BuildVarDecl( TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); - NumericalDiffMultiArg.push_back(BuildDeclStmt(VD)); + PreCallStmts.push_back(BuildDeclStmt(VD)); Expr* TapeRef = BuildDeclRef(VD); NumDiffArgs.push_back(TapeRef); NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(m_Context.IntTy, @@ -798,16 +799,23 @@ namespace clad { CXXScopeSpec CSS; CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); LookupResult& Push = GetCladTapePush(); - auto PushDRE = m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false) - .get(); - Expr* PushExpr; + auto PushDRE = + m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); for (unsigned i = 0, e = numArgs; i < e; i++) { - Expr* callArgs[] = {TapeRef, outputArgs[i]}; - PushExpr = m_Sema - .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, - noLoc) - .get(); - NumericalDiffMultiArg.push_back(PushExpr); + QualType argTy = args[i]->getType(); + VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); + PreCallStmts.push_back(BuildDeclStmt(gradVar)); + Expr* PushExpr = BuildDeclRef(gradVar); + if (!isCladArrayType(argTy)) + PushExpr = BuildOp(UO_AddrOf, PushExpr); + Expr* callArgs[] = {TapeRef, PushExpr}; + Stmt* PushStmt = + m_Sema + .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) + .get(); + PreCallStmts.push_back(PushStmt); + Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); + PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); NumDiffArgs.push_back(args[i]); } std::string Name = "central_difference"; diff --git a/test/Arrays/ArrayInputsReverseMode.C b/test/Arrays/ArrayInputsReverseMode.C index d4bad52cc..50ca4a889 100644 --- a/test/Arrays/ArrayInputsReverseMode.C +++ b/test/Arrays/ArrayInputsReverseMode.C @@ -47,16 +47,11 @@ double f(double *arr) { } //CHECK: void f_grad(double *arr, clad::array_ref _d_arr) { -//CHECK-NEXT: double *_t0; -//CHECK-NEXT: _t0 = arr; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: arr = _t0; -//CHECK-NEXT: int _grad1 = 0; -//CHECK-NEXT: addArr_pullback(_t0, 3, 1, _d_arr, &_grad1); -//CHECK-NEXT: double *_r0 = _d_arr; -//CHECK-NEXT: int _r1 = _grad1; +//CHECK-NEXT: int _r0 = 0; +//CHECK-NEXT: addArr_pullback(arr, 3, 1, _d_arr, &_r0); //CHECK-NEXT: } //CHECK-NEXT: } @@ -142,9 +137,8 @@ float func2(float* a) { //CHECK-NEXT: i--; //CHECK-NEXT: sum = clad::pop(_t1); //CHECK-NEXT: float _r_d0 = _d_sum; -//CHECK-NEXT: float _grad0 = 0.F; -//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_grad0); -//CHECK-NEXT: float _r0 = _grad0; +//CHECK-NEXT: float _r0 = 0; +//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_r0); //CHECK-NEXT: _d_a[i] += _r0; //CHECK-NEXT: } //CHECK-NEXT: } @@ -217,10 +211,8 @@ double func4(double x) { //CHECK-NEXT: { //CHECK-NEXT: sum = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_sum; -//CHECK-NEXT: int _grad1 = 0; -//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1); -//CHECK-NEXT: clad::array _r0(_d_arr); -//CHECK-NEXT: int _r1 = _grad1; +//CHECK-NEXT: int _r0 = 0; +//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_r0); //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: { @@ -279,11 +271,9 @@ double func5(int k) { //CHECK-NEXT: { //CHECK-NEXT: sum = clad::pop(_t3); //CHECK-NEXT: double _r_d1 = _d_sum; -//CHECK-NEXT: int _grad1 = 0; -//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_grad1); -//CHECK-NEXT: clad::array _r0(_d_arr); -//CHECK-NEXT: int _r1 = _grad1; -//CHECK-NEXT: _d_n += _r1; +//CHECK-NEXT: int _r0 = 0; +//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_r0); +//CHECK-NEXT: _d_n += _r0; //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: for (; _t0; _t0--) { @@ -316,14 +306,12 @@ double func6(double seed) { //CHECK-NEXT: clad::array _d_arr(3UL); //CHECK-NEXT: clad::array arr(3UL); //CHECK-NEXT: clad::tape _t2 = {}; -//CHECK-NEXT: clad::tape > _t3 = {}; //CHECK-NEXT: double sum = 0; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < 3; i++) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, arr) , arr = {seed, seed * i, seed + i}; //CHECK-NEXT: clad::push(_t2, sum); -//CHECK-NEXT: clad::push(_t3, arr); //CHECK-NEXT: sum += addArr(arr, 3); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -334,12 +322,8 @@ double func6(double seed) { //CHECK-NEXT: { //CHECK-NEXT: sum = clad::pop(_t2); //CHECK-NEXT: double _r_d0 = _d_sum; -//CHECK-NEXT: clad::array _r1 = clad::pop(_t3); -//CHECK-NEXT: arr = _r1; -//CHECK-NEXT: int _grad1 = 0; -//CHECK-NEXT: addArr_pullback(_r1, 3, _r_d0, _d_arr, &_grad1); -//CHECK-NEXT: clad::array _r0 = _d_arr; -//CHECK-NEXT: int _r2 = _grad1; +//CHECK-NEXT: int _r0 = 0; +//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_r0); //CHECK-NEXT: } //CHECK-NEXT: { //CHECK-NEXT: * _d_seed += _d_arr[0]; @@ -387,14 +371,12 @@ double func7(double *params) { //CHECK-NEXT: clad::array _d_paramsPrime(1UL); //CHECK-NEXT: clad::array paramsPrime(1UL); //CHECK-NEXT: clad::tape _t2 = {}; -//CHECK-NEXT: clad::tape > _t3 = {}; //CHECK-NEXT: double out = 0.; //CHECK-NEXT: _t0 = 0; //CHECK-NEXT: for (i = 0; i < 1; ++i) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, paramsPrime) , paramsPrime = {params[0]}; //CHECK-NEXT: clad::push(_t2, out); -//CHECK-NEXT: clad::push(_t3, paramsPrime); //CHECK-NEXT: out = out + inv_square(paramsPrime); //CHECK-NEXT: } //CHECK-NEXT: goto _label0; @@ -407,10 +389,7 @@ double func7(double *params) { //CHECK-NEXT: double _r_d0 = _d_out; //CHECK-NEXT: _d_out -= _r_d0; //CHECK-NEXT: _d_out += _r_d0; -//CHECK-NEXT: clad::array _r1 = clad::pop(_t3); -//CHECK-NEXT: paramsPrime = _r1; -//CHECK-NEXT: inv_square_pullback(_r1, _r_d0, _d_paramsPrime); -//CHECK-NEXT: clad::array _r0 = _d_paramsPrime; +//CHECK-NEXT: inv_square_pullback(paramsPrime, _r_d0, _d_paramsPrime); //CHECK-NEXT: } //CHECK-NEXT: { //CHECK-NEXT: _d_params[0] += _d_paramsPrime[0]; @@ -445,21 +424,19 @@ double func8(double i, double *arr, int n) { //CHECK-NEXT: double _d_res = 0; //CHECK-NEXT: double _t0; //CHECK-NEXT: double _t1; -//CHECK-NEXT: double *_t2; -//CHECK-NEXT: double _t3; +//CHECK-NEXT: double _t2; //CHECK-NEXT: double res = 0; //CHECK-NEXT: _t0 = arr[0]; //CHECK-NEXT: arr[0] = 1; //CHECK-NEXT: _t1 = res; -//CHECK-NEXT: _t2 = arr; //CHECK-NEXT: res = helper2(i, arr, n); -//CHECK-NEXT: _t3 = arr[0]; +//CHECK-NEXT: _t2 = arr[0]; //CHECK-NEXT: arr[0] = 5; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: _d_res += 1; //CHECK-NEXT: { -//CHECK-NEXT: arr[0] = _t3; +//CHECK-NEXT: arr[0] = _t2; //CHECK-NEXT: double _r_d2 = _d_arr[0]; //CHECK-NEXT: _d_arr[0] -= _r_d2; //CHECK-NEXT: } @@ -467,15 +444,11 @@ double func8(double i, double *arr, int n) { //CHECK-NEXT: res = _t1; //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: _d_res -= _r_d1; -//CHECK-NEXT: arr = _t2; -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: int _grad2 = 0; -//CHECK-NEXT: helper2_pullback(i, _t2, n, _r_d1, &_grad0, _d_arr, &_grad2); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: int _r1 = 0; +//CHECK-NEXT: helper2_pullback(i, arr, n, _r_d1, &_r0, _d_arr, &_r1); //CHECK-NEXT: * _d_i += _r0; -//CHECK-NEXT: double *_r1 = _d_arr; -//CHECK-NEXT: int _r2 = _grad2; -//CHECK-NEXT: * _d_n += _r2; +//CHECK-NEXT: * _d_n += _r1; //CHECK-NEXT: } //CHECK-NEXT: { //CHECK-NEXT: arr[0] = _t0; @@ -534,13 +507,11 @@ double func9(double i, double j) { //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: --idx; //CHECK-NEXT: { -//CHECK-NEXT: double _r1 = clad::pop(_t1); -//CHECK-NEXT: arr[idx] = _r1; -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: modify_pullback(_r1, i, &_d_arr[idx], &_grad1); -//CHECK-NEXT: double _r0 = _d_arr[idx]; -//CHECK-NEXT: double _r2 = _grad1; -//CHECK-NEXT: * _d_i += _r2; +//CHECK-NEXT: double _r0 = clad::pop(_t1); +//CHECK-NEXT: arr[idx] = _r0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: modify_pullback(_r0, i, &_d_arr[idx], &_r1); +//CHECK-NEXT: * _d_i += _r1; //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } @@ -598,10 +569,9 @@ double func10(double *arr, int n) { //CHECK-NEXT: { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; -//CHECK-NEXT: double _r1 = clad::pop(_t2); -//CHECK-NEXT: arr[i] = _r1; -//CHECK-NEXT: sq_pullback(_r1, _r_d0, &_d_arr[i]); -//CHECK-NEXT: double _r0 = _d_arr[i]; +//CHECK-NEXT: double _r0 = clad::pop(_t2); +//CHECK-NEXT: arr[i] = _r0; +//CHECK-NEXT: sq_pullback(_r0, _r_d0, &_d_arr[i]); //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/CUDA/GradientCuda.cu b/test/CUDA/GradientCuda.cu index f351c38ed..ac0a40765 100644 --- a/test/CUDA/GradientCuda.cu +++ b/test/CUDA/GradientCuda.cu @@ -57,19 +57,16 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) { //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / 2., 1 * _t4 * _t5, &_grad0, &_grad1); -//CHECK-NEXT: double _r1 = _grad0; -//CHECK-NEXT: double _r2 = _grad1; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / 2., 1 * _t4 * _t5, &_r1, &_r2); //CHECK-NEXT: _d_dim += -_r2 / 2.; -//CHECK-NEXT: double _grad2 = 0.; -//CHECK-NEXT: double _grad3 = 0.; -//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _t6 * 1 * _t4, &_grad2, &_grad3); -//CHECK-NEXT: double _r3 = _grad2; +//CHECK-NEXT: double _r3 = 0; +//CHECK-NEXT: double _r4 = 0; +//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _t6 * 1 * _t4, &_r3, &_r4); //CHECK-NEXT: _d_sigma += _r3; -//CHECK-NEXT: double _r4 = _grad3; -//CHECK-NEXT: double _r5 = _t6 * _t5 * 1 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward; +//CHECK-NEXT: double _r5 = 0; +//CHECK-NEXT: _r5 += _t6 * _t5 * 1 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward; //CHECK-NEXT: _d_t += _r5; //CHECK-NEXT: } //CHECK-NEXT: { diff --git a/test/ErrorEstimation/BasicOps.C b/test/ErrorEstimation/BasicOps.C index 75c4e1f96..12f0abb86 100644 --- a/test/ErrorEstimation/BasicOps.C +++ b/test/ErrorEstimation/BasicOps.C @@ -150,12 +150,10 @@ float func4(float x, float y) { return std::pow(x, y); } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: float _grad0 = 0.F; -//CHECK-NEXT: float _grad1 = 0.F; -//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_grad0, &_grad1); -//CHECK-NEXT: float _r0 = _grad0; +//CHECK-NEXT: float _r0 = 0; +//CHECK-NEXT: float _r1 = 0; +//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1); //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: float _r1 = _grad1; //CHECK-NEXT: * _d_y += _r1; //CHECK-NEXT: } //CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); @@ -186,7 +184,8 @@ float func5(float x, float y) { //CHECK-NEXT: y = _t0; //CHECK-NEXT: float _r_d0 = * _d_y; //CHECK-NEXT: * _d_y -= _r_d0; -//CHECK-NEXT: float _r0 = _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; +//CHECK-NEXT: float _r0 = 0; +//CHECK-NEXT: _r0 += _r_d0 * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); @@ -228,13 +227,11 @@ float func6(float x, float y) { //CHECK-NEXT: _d_z += z * 1; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; //CHECK-NEXT: double _t0 = 0; -//CHECK-NEXT: helper_pullback(x, y, _d_z, &_grad0, &_grad1, _t0); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: helper_pullback(x, y, _d_z, &_r0, &_r1, _t0); //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: double _r1 = _grad1; //CHECK-NEXT: * _d_y += _r1; //CHECK-NEXT: _final_error += _t0; //CHECK-NEXT: } @@ -301,9 +298,8 @@ float func8(float x, float y) { //CHECK-NEXT: x = _t1; //CHECK-NEXT: double _t2 = 0; //CHECK-NEXT: helper2_pullback(_t1, _r_d0, &* _d_x, _t2); -//CHECK-NEXT: float _r0 = * _d_x; //CHECK-NEXT: _final_error += _t2; -//CHECK-NEXT: _final_error += std::abs(_r0 * _t1 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * _t1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); @@ -340,30 +336,25 @@ float func9(float x, float y) { //CHECK-NEXT: x = _t5; //CHECK-NEXT: double _t6 = 0; //CHECK-NEXT: helper2_pullback(_t5, _r_d0 * _t4, &* _d_x, _t6); -//CHECK-NEXT: float _r3 = * _d_x; //CHECK-NEXT: y = _t8; //CHECK-NEXT: double _t9 = 0; //CHECK-NEXT: helper2_pullback(_t8, _t7 * _r_d0, &* _d_y, _t9); -//CHECK-NEXT: float _r4 = * _d_y; //CHECK-NEXT: _final_error += _t6 + _t9; -//CHECK-NEXT: _final_error += std::abs(_r4 * _t8 * {{.+}}); -//CHECK-NEXT: _final_error += std::abs(_r3 * _t5 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_y * _t8 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * _t5 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; //CHECK-NEXT: double _t0 = 0; -//CHECK-NEXT: helper_pullback(x, y, _d_z, &_grad0, &_grad1, _t0); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: helper_pullback(x, y, _d_z, &_r0, &_r1, _t0); //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: double _r1 = _grad1; //CHECK-NEXT: * _d_y += _r1; //CHECK-NEXT: x = _t1; //CHECK-NEXT: double _t2 = 0; //CHECK-NEXT: helper2_pullback(_t1, _d_z, &* _d_x, _t2); -//CHECK-NEXT: float _r2 = * _d_x; //CHECK-NEXT: _final_error += _t0 + _t2; -//CHECK-NEXT: _final_error += std::abs(_r2 * _t1 * {{.+}}); +//CHECK-NEXT: _final_error += std::abs(* _d_x * _t1 * {{.+}}); //CHECK-NEXT: } //CHECK-NEXT: _final_error += std::abs(* _d_x * x * {{.+}}); //CHECK-NEXT: _final_error += std::abs(* _d_y * y * {{.+}}); diff --git a/test/FirstDerivative/BuiltinDerivatives.C b/test/FirstDerivative/BuiltinDerivatives.C index 0144478af..0c8356389 100644 --- a/test/FirstDerivative/BuiltinDerivatives.C +++ b/test/FirstDerivative/BuiltinDerivatives.C @@ -105,12 +105,10 @@ void f7_grad(float x, clad::array_ref _d_x); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2., 1, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2., 1, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -130,12 +128,10 @@ void f8_grad(float x, clad::array_ref _d_x); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: int _grad1 = 0; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, 1, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: int _r1 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, 1, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: int _r1 = _grad1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -156,12 +152,10 @@ void f9_grad(float x, float y, clad::array_ref _d_x, clad::array_ref _d_x, clad::array_ref // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: int _grad1 = 0; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: int _r1 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, y, 1, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: int _r1 = _grad1; // CHECK-NEXT: * _d_y += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -203,24 +195,18 @@ double f11(double x, double y) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: int _grad1 = 0; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback((1. - x), 2, 1, &_grad0, &_grad1); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: int _r1 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback((1. - x), 2, 1, &_r0, &_r1); // CHECK-NEXT: * _d_x += -_r0; -// CHECK-NEXT: int _r1 = _grad1; -// CHECK-NEXT: double _grad4 = 0.; -// CHECK-NEXT: int _grad5 = 0; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(y - std::pow(x, 2), 2, 100. * 1, &_grad4, &_grad5); -// CHECK-NEXT: double _r2 = _grad4; +// CHECK-NEXT: double _r2 = 0; +// CHECK-NEXT: int _r5 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(y - std::pow(x, 2), 2, 100. * 1, &_r2, &_r5); // CHECK-NEXT: * _d_y += _r2; -// CHECK-NEXT: double _grad2 = 0.; -// CHECK-NEXT: int _grad3 = 0; -// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, -_r2, &_grad2, &_grad3); -// CHECK-NEXT: double _r3 = _grad2; +// CHECK-NEXT: double _r3 = 0; +// CHECK-NEXT: int _r4 = 0; +// CHECK-NEXT: {{(clad::)?}}custom_derivatives{{(::std)?}}::pow_pullback(x, 2, -_r2, &_r3, &_r4); // CHECK-NEXT: * _d_x += _r3; -// CHECK-NEXT: int _r4 = _grad3; -// CHECK-NEXT: int _r5 = _grad5; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index f209372ab..22d12f64a 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -792,15 +792,12 @@ double f19(double a, double b) { //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: double _grad2 = 0.; -//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_grad0, &_grad1, &_grad2); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: clad::custom_derivatives::fma_pullback(a, b, b, 1, &_r0, &_r1, &_r2); //CHECK-NEXT: * _d_a += _r0; -//CHECK-NEXT: double _r1 = _grad1; //CHECK-NEXT: * _d_b += _r1; -//CHECK-NEXT: double _r2 = _grad2; //CHECK-NEXT: * _d_b += _r2; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 94ca8615c..0af0c6ccd 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -40,9 +40,8 @@ double fn1(float i) { // CHECK-NEXT: * _d_i += res * _d_a; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: constantFn_pullback(i, _d_res, &_grad0); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: constantFn_pullback(i, _d_res, &_r0); // CHECK-NEXT: * _d_i += _r0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -116,8 +115,6 @@ double fn2(double i, double j) { // CHECK-NEXT: i = _t4; // CHECK-NEXT: j = _t5; // CHECK-NEXT: modify1_pullback(_t4, _t5, _r_d1, &* _d_i, &* _d_j); -// CHECK-NEXT: double _r2 = * _d_i; -// CHECK-NEXT: double _r3 = * _d_j; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: temp = _t0; @@ -126,8 +123,6 @@ double fn2(double i, double j) { // CHECK-NEXT: i = _t1; // CHECK-NEXT: j = _t2; // CHECK-NEXT: modify1_pullback(_t1, _t2, _r_d0, &* _d_i, &* _d_j); -// CHECK-NEXT: double _r0 = * _d_i; -// CHECK-NEXT: double _r1 = * _d_j; // CHECK-NEXT: } // CHECK-NEXT: } @@ -179,15 +174,11 @@ double fn3(double i, double j) { // CHECK-NEXT: i = _t2; // CHECK-NEXT: j = _t3; // CHECK-NEXT: update1_pullback(_t2, _t3, &* _d_i, &* _d_j); -// CHECK-NEXT: double _r2 = * _d_i; -// CHECK-NEXT: double _r3 = * _d_j; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: i = _t0; // CHECK-NEXT: j = _t1; // CHECK-NEXT: update1_pullback(_t0, _t1, &* _d_i, &* _d_j); -// CHECK-NEXT: double _r0 = * _d_i; -// CHECK-NEXT: double _r1 = * _d_j; // CHECK-NEXT: } // CHECK-NEXT: } @@ -260,50 +251,44 @@ double fn4(double* arr, int n) { // CHECK: void fn4_grad(double *arr, int n, clad::array_ref _d_arr, clad::array_ref _d_n) { // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: double _t0; -// CHECK-NEXT: double *_t1; -// CHECK-NEXT: unsigned long _t2; +// CHECK-NEXT: unsigned long _t1; // CHECK-NEXT: int _d_i = 0; // CHECK-NEXT: int i = 0; +// CHECK-NEXT: clad::tape _t2 = {}; // CHECK-NEXT: clad::tape _t3 = {}; -// CHECK-NEXT: clad::tape _t4 = {}; // CHECK-NEXT: double res = 0; // CHECK-NEXT: _t0 = res; -// CHECK-NEXT: _t1 = arr; // CHECK-NEXT: res += sum(arr, n); -// CHECK-NEXT: _t2 = 0; +// CHECK-NEXT: _t1 = 0; // CHECK-NEXT: for (i = 0; i < n; ++i) { -// CHECK-NEXT: _t2++; -// CHECK-NEXT: clad::push(_t3, arr[i]); +// CHECK-NEXT: _t1++; +// CHECK-NEXT: clad::push(_t2, arr[i]); // CHECK-NEXT: twice(arr[i]); -// CHECK-NEXT: clad::push(_t4, res); +// CHECK-NEXT: clad::push(_t3, res); // CHECK-NEXT: res += arr[i]; // CHECK-NEXT: } // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (; _t2; _t2--) { +// CHECK-NEXT: for (; _t1; _t1--) { // CHECK-NEXT: --i; // CHECK-NEXT: { -// CHECK-NEXT: res = clad::pop(_t4); +// CHECK-NEXT: res = clad::pop(_t3); // CHECK-NEXT: double _r_d1 = _d_res; // CHECK-NEXT: _d_arr[i] += _r_d1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r3 = clad::pop(_t3); -// CHECK-NEXT: arr[i] = _r3; -// CHECK-NEXT: twice_pullback(_r3, &_d_arr[i]); -// CHECK-NEXT: double _r2 = _d_arr[i]; +// CHECK-NEXT: double _r1 = clad::pop(_t2); +// CHECK-NEXT: arr[i] = _r1; +// CHECK-NEXT: twice_pullback(_r1, &_d_arr[i]); // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: res = _t0; // CHECK-NEXT: double _r_d0 = _d_res; -// CHECK-NEXT: arr = _t1; -// CHECK-NEXT: int _grad1 = 0; -// CHECK-NEXT: sum_pullback(_t1, n, _r_d0, _d_arr, &_grad1); -// CHECK-NEXT: double *_r0 = _d_arr; -// CHECK-NEXT: int _r1 = _grad1; -// CHECK-NEXT: * _d_n += _r1; +// CHECK-NEXT: int _r0 = 0; +// CHECK-NEXT: sum_pullback(arr, n, _r_d0, _d_arr, &_r0); +// CHECK-NEXT: * _d_n += _r0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -334,18 +319,12 @@ double fn5(double* arr, int n) { } // CHECK: void fn5_grad(double *arr, int n, clad::array_ref _d_arr, clad::array_ref _d_n) { -// CHECK-NEXT: double *_t0; // CHECK-NEXT: double _d_temp = 0; -// CHECK-NEXT: _t0 = arr; // CHECK-NEXT: double temp = modify2(arr); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: _d_arr[0] += 1; -// CHECK-NEXT: { -// CHECK-NEXT: arr = _t0; -// CHECK-NEXT: modify2_pullback(_t0, _d_temp, _d_arr); -// CHECK-NEXT: double *_r0 = _d_arr; -// CHECK-NEXT: } +// CHECK-NEXT: modify2_pullback(arr, _d_temp, _d_arr); // CHECK-NEXT: } double fn6(double i=0, double j=0) { @@ -445,12 +424,10 @@ double fn7(double i, double j) { // CHECK-NEXT: { // CHECK-NEXT: j = _t2; // CHECK-NEXT: identity_pullback(_t2, 0, &* _d_j); -// CHECK-NEXT: double _r1 = * _d_j; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: i = _t0; // CHECK-NEXT: identity_pullback(_t0, 0, &* _d_i); -// CHECK-NEXT: double _r0 = * _d_i; // CHECK-NEXT: } // CHECK-NEXT: } @@ -479,22 +456,18 @@ double fn8(double x, double y) { // CHECK: void fn8_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; -// CHECK-NEXT: double _t3; -// CHECK-NEXT: _t3 = check_and_return(x, 'a', "aa"); +// CHECK-NEXT: double _t2; +// CHECK-NEXT: _t2 = check_and_return(x, 'a', "aa"); // CHECK-NEXT: _t1 = std::tanh(1.); // CHECK-NEXT: _t0 = std::max(1., 2.); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: char _grad1 = 0i8; -// CHECK-NEXT: clad::array_ref _t2 = {"", 3UL}; -// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_grad0, &_grad1, _t2); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: char _r1 = 0; +// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_r0, &_r1, ""); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: char _r1 = _grad1; -// CHECK-NEXT: const char *_r2 = ""; -// CHECK-NEXT: * _d_y += _t3 * 1 * _t0 * _t1; +// CHECK-NEXT: * _d_y += _t2 * 1 * _t0 * _t1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -524,12 +497,10 @@ double fn9(double x, double y) { // CHECK-NEXT: _label0: // CHECK-NEXT: { // CHECK-NEXT: y = _t0; -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: custom_max_pullback(x * y, _t0, 1, &_grad0, &* _d_y); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: custom_max_pullback(x * y, _t0, 1, &_r0, &* _d_y); // CHECK-NEXT: * _d_x += _r0 * y; // CHECK-NEXT: * _d_y += x * _r0; -// CHECK-NEXT: double _r1 = * _d_y; // CHECK-NEXT: } // CHECK-NEXT: } @@ -570,32 +541,25 @@ double fn10(double x, double y) { // CHECK-NEXT: double _r_d2 = _d_out; // CHECK-NEXT: _d_out -= _r_d2; // CHECK-NEXT: out = _t5; -// CHECK-NEXT: double _grad5 = 0.; -// CHECK-NEXT: double _grad6 = 0.; -// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(_t5, 3., 7., _r_d2, &_d_out, &_grad5, &_grad6); -// CHECK-NEXT: double _r4 = _d_out; -// CHECK-NEXT: double _r5 = _grad5; -// CHECK-NEXT: double _r6 = _grad6; +// CHECK-NEXT: double _r2 = 0; +// CHECK-NEXT: double _r3 = 0; +// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(_t5, 3., 7., _r_d2, &_d_out, &_r2, &_r3); // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: out = _t2; // CHECK-NEXT: double _r_d1 = _d_out; // CHECK-NEXT: _d_out -= _r_d1; // CHECK-NEXT: out = _t3; -// CHECK-NEXT: double _grad3 = 0.; -// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(_t3, 10., _r_d1, &_d_out, &_grad3); -// CHECK-NEXT: double _r2 = _d_out; -// CHECK-NEXT: double _r3 = _grad3; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(_t3, 10., _r_d1, &_d_out, &_r1); // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: out = _t0; // CHECK-NEXT: double _r_d0 = _d_out; // CHECK-NEXT: _d_out -= _r_d0; // CHECK-NEXT: out = _t1; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(_t1, 0., _r_d0, &_d_out, &_grad1); -// CHECK-NEXT: double _r0 = _d_out; -// CHECK-NEXT: double _r1 = _grad1; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(_t1, 0., _r_d0, &_d_out, &_r0); // CHECK-NEXT: } // CHECK-NEXT: * _d_x += _d_out; // CHECK-NEXT: } @@ -636,8 +600,6 @@ double fn11(double x, double y) { // CHECK-NEXT: x = _t0; // CHECK-NEXT: y = _t1; // CHECK-NEXT: clad::custom_derivatives::n1::sum_pullback(_t0, _t1, 1, &* _d_x, &* _d_y); -// CHECK-NEXT: double _r0 = * _d_x; -// CHECK-NEXT: double _r1 = * _d_y; // CHECK-NEXT: } // CHECK-NEXT: } @@ -656,16 +618,9 @@ double fn12(double x, double y) { } // CHECK: void fn12_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { -// CHECK-NEXT: double *_t0; -// CHECK-NEXT: _t0 = &x; // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: -// CHECK-NEXT: { -// CHECK-NEXT: do_nothing_pullback(_t0, nullptr, 0, 1, &* _d_x, nullptr, 0); -// CHECK-NEXT: double *_r0 = &* _d_x; -// CHECK-NEXT: {{(std::)?}}nullptr_t _r1 = nullptr; -// CHECK-NEXT: double *_r2 = 0; -// CHECK-NEXT: } +// CHECK-NEXT: do_nothing_pullback(&x, nullptr, 0, 1, &* _d_x, nullptr, 0); // CHECK-NEXT: } double multiply(double* a, double* b) { @@ -693,8 +648,6 @@ double fn13(double* x, const double* w) { // CHECK-NEXT: std::size_t _d_i = 0; // CHECK-NEXT: std::size_t i = 0; // CHECK-NEXT: clad::tape _t1 = {}; -// CHECK-NEXT: double *_t2; -// CHECK-NEXT: double *_t3; // CHECK-NEXT: double wCopy[2]; // CHECK-NEXT: _t0 = 0; // CHECK-NEXT: for (i = 0; i < 2; ++i) { @@ -702,16 +655,9 @@ double fn13(double* x, const double* w) { // CHECK-NEXT: clad::push(_t1, wCopy[i]); // CHECK-NEXT: wCopy[i] = w[i]; // CHECK-NEXT: } -// CHECK-NEXT: _t2 = x; -// CHECK-NEXT: _t3 = wCopy + 1; // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: -// CHECK-NEXT: { -// CHECK-NEXT: x = _t2; -// CHECK-NEXT: multiply_pullback(_t2, _t3, 1, _d_x, _d_wCopy.ptr_ref() + 1); -// CHECK-NEXT: double *_r0 = _d_x; -// CHECK-NEXT: double *_r1 = _d_wCopy.ptr_ref() + 1; -// CHECK-NEXT: } +// CHECK-NEXT: multiply_pullback(x, wCopy + 1, 1, _d_x, _d_wCopy.ptr_ref() + 1); // CHECK-NEXT: for (; _t0; _t0--) { // CHECK-NEXT: --i; // CHECK-NEXT: { @@ -744,11 +690,9 @@ double fn14(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: x = _t0; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: emptyFn_pullback(_t0, y, &* _d_x, &_grad1); -// CHECK-NEXT: double _r0 = * _d_x; -// CHECK-NEXT: double _r1 = _grad1; -// CHECK-NEXT: * _d_y += _r1; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: emptyFn_pullback(_t0, y, &* _d_x, &_r0); +// CHECK-NEXT: * _d_y += _r0; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Gradient/Functors.C b/test/Gradient/Functors.C index d1979b2b9..3ce9aa6bb 100644 --- a/test/Gradient/Functors.C +++ b/test/Gradient/Functors.C @@ -246,12 +246,10 @@ int main() { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { - // CHECK-NEXT: double _grad0 = 0.; - // CHECK-NEXT: double _grad1 = 0.; - // CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &_d_E, &_grad0, &_grad1); - // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &_d_E, &_r0, &_r1); // CHECK-NEXT: * _d_i += _r0; - // CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_j += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -268,12 +266,10 @@ int main() { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { - // CHECK-NEXT: double _grad0 = 0.; - // CHECK-NEXT: double _grad1 = 0.; - // CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &(* _d_fn), &_grad0, &_grad1); - // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0.operator_call_pullback(i, j, 1, &(* _d_fn), &_r0, &_r1); // CHECK-NEXT: * _d_i += _r0; - // CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_j += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -291,12 +287,10 @@ int main() { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { - // CHECK-NEXT: double _grad0 = 0.; - // CHECK-NEXT: double _grad1 = 0.; - // CHECK-NEXT: _t0.operator_call_pullback(i, j, _d_y, &(* _d_fn), &_grad0, &_grad1); - // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0.operator_call_pullback(i, j, _d_y, &(* _d_fn), &_r0, &_r1); // CHECK-NEXT: * _d_i += _r0; - // CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_j += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -307,14 +301,11 @@ int main() { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { - // CHECK-NEXT: Experiment _grad0; - // CHECK-NEXT: double _grad1 = 0.; - // CHECK-NEXT: double _grad2 = 0.; - // CHECK-NEXT: FunctorAsArg_pullback(E, i, j, 1, &_grad0, &_grad1, &_grad2); - // CHECK-NEXT: Experiment _r0(_grad0); - // CHECK-NEXT: double _r1 = _grad1; + // CHECK-NEXT: Experiment _r0 = {}; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: double _r2 = 0; + // CHECK-NEXT: FunctorAsArg_pullback(E, i, j, 1, &_r0, &_r1, &_r2); // CHECK-NEXT: * _d_i += _r1; - // CHECK-NEXT: double _r2 = _grad2; // CHECK-NEXT: * _d_j += _r2; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index ac30263e1..588b0df78 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -444,24 +444,18 @@ void f_norm_grad(double x, //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad4 = 0.; -//CHECK-NEXT: double _grad5 = 0.; -//CHECK-NEXT: clad::custom_derivatives::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_grad4, &_grad5); -//CHECK-NEXT: double _r0 = _grad4; -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: double _grad2 = 0.; -//CHECK-NEXT: double _grad3 = 0.; -//CHECK-NEXT: clad::custom_derivatives::sum_of_powers_pullback(x, y, z, d, _r0, &_grad0, &_grad1, &_grad2, &_grad3); -//CHECK-NEXT: double _r1 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r5 = 0; +//CHECK-NEXT: clad::custom_derivatives::pow_pullback(sum_of_powers(x, y, z, d), 1 / d, 1, &_r0, &_r5); +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: double _r3 = 0; +//CHECK-NEXT: double _r4 = 0; +//CHECK-NEXT: clad::custom_derivatives::sum_of_powers_pullback(x, y, z, d, _r0, &_r1, &_r2, &_r3, &_r4); //CHECK-NEXT: * _d_x += _r1; -//CHECK-NEXT: double _r2 = _grad1; //CHECK-NEXT: * _d_y += _r2; -//CHECK-NEXT: double _r3 = _grad2; //CHECK-NEXT: * _d_z += _r3; -//CHECK-NEXT: double _r4 = _grad3; //CHECK-NEXT: * _d_d += _r4; -//CHECK-NEXT: double _r5 = _grad5; //CHECK-NEXT: double _r6 = _r5 * -1 / (d * d); //CHECK-NEXT: * _d_d += _r6; //CHECK-NEXT: } @@ -478,9 +472,11 @@ void f_sin_grad(double x, double y, clad::array_ref _d_x, clad::array_re //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _r0 = 1 * (x + y) * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: double _r1 = 1 * (x + y) * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: _r1 += 1 * (x + y) * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward; //CHECK-NEXT: * _d_y += _r1; //CHECK-NEXT: * _d_x += _t0 * 1; //CHECK-NEXT: * _d_y += _t0 * 1; diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 82de66836..130fa803d 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -315,9 +315,8 @@ double f_sum_squares(double *p, int n) { //CHECK-NEXT: i--; //CHECK-NEXT: s = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_s; -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: sq_pullback(p[i], _r_d0, &_grad0); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: sq_pullback(p[i], _r_d0, &_r0); //CHECK-NEXT: _d_p[i] += _r0; //CHECK-NEXT: } //CHECK-NEXT: } @@ -364,20 +363,21 @@ double f_log_gaus(double* x, double* p /*means*/, double n, double sigma) { //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _r8 = 1 * clad::custom_derivatives::log_pushforward(gaus, 1.).pushforward; +//CHECK-NEXT: double _r8 = 0; +//CHECK-NEXT: _r8 += 1 * clad::custom_derivatives::log_pushforward(gaus, 1.).pushforward; //CHECK-NEXT: _d_gaus += _r8; //CHECK-NEXT: } //CHECK-NEXT: { //CHECK-NEXT: double _r3 = _d_gaus * _t5 * -1. / (_t6 * _t6); -//CHECK-NEXT: double _r4 = _r3 * clad::custom_derivatives::sqrt_pushforward(_t7 * sigma, 1.).pushforward; -//CHECK-NEXT: double _grad2 = 0.; -//CHECK-NEXT: double _grad3 = 0.; -//CHECK-NEXT: clad::custom_derivatives::pow_pullback(2 * 3.1415926535897931, n, _r4 * sigma, &_grad2, &_grad3); -//CHECK-NEXT: double _r5 = _grad2; -//CHECK-NEXT: double _r6 = _grad3; +//CHECK-NEXT: double _r4 = 0; +//CHECK-NEXT: _r4 += _r3 * clad::custom_derivatives::sqrt_pushforward(_t7 * sigma, 1.).pushforward; +//CHECK-NEXT: double _r5 = 0; +//CHECK-NEXT: double _r6 = 0; +//CHECK-NEXT: clad::custom_derivatives::pow_pullback(2 * 3.1415926535897931, n, _r4 * sigma, &_r5, &_r6); //CHECK-NEXT: _d_n += _r6; //CHECK-NEXT: _d_sigma += _t7 * _r4; -//CHECK-NEXT: double _r7 = 1. / _t6 * _d_gaus * clad::custom_derivatives::exp_pushforward(power, 1.).pushforward; +//CHECK-NEXT: double _r7 = 0; +//CHECK-NEXT: _r7 += 1. / _t6 * _d_gaus * clad::custom_derivatives::exp_pushforward(power, 1.).pushforward; //CHECK-NEXT: _d_power += _r7; //CHECK-NEXT: } //CHECK-NEXT: { @@ -386,18 +386,16 @@ double f_log_gaus(double* x, double* p /*means*/, double n, double sigma) { //CHECK-NEXT: _d_power -= _r_d1; //CHECK-NEXT: _d_power += -_r_d1 / _t3; //CHECK-NEXT: double _r1 = _r_d1 * --power / (_t3 * _t3); -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: sq_pullback(sigma, 2 * _r1, &_grad1); -//CHECK-NEXT: double _r2 = _grad1; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: sq_pullback(sigma, 2 * _r1, &_r2); //CHECK-NEXT: _d_sigma += _r2; //CHECK-NEXT: } //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: i--; //CHECK-NEXT: power = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_power; -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: sq_pullback(x[i] - p[i], _r_d0, &_grad0); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: sq_pullback(x[i] - p[i], _r_d0, &_r0); //CHECK-NEXT: _d_p[i] += -_r0; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/Gradient/MemberFunctions.C b/test/Gradient/MemberFunctions.C index c647ae4ff..159974336 100644 --- a/test/Gradient/MemberFunctions.C +++ b/test/Gradient/MemberFunctions.C @@ -537,9 +537,8 @@ double fn2(SimpleFunctions& sf, double i) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: _t0.ref_mem_fn_pullback(i, 1, &(* _d_sf), &_grad0); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _t0.ref_mem_fn_pullback(i, 1, &(* _d_sf), &_r0); // CHECK-NEXT: * _d_i += _r0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -583,9 +582,8 @@ double fn5(SimpleFunctions& v, double value) { // CHECK-NEXT: _label0: // CHECK-NEXT: (* _d_v).x += 1; // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: _t0.operator_plus_equal_pullback(value, {}, &(* _d_v), &_grad0); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _t0.operator_plus_equal_pullback(value, {}, &(* _d_v), &_r0); // CHECK-NEXT: * _d_value += _r0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -723,12 +721,10 @@ int main() { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: _t0.mem_fn_pullback(i, j, 1, &_d_sf, &_grad0, &_grad1); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: _t0.mem_fn_pullback(i, j, 1, &_d_sf, &_r0, &_r1); // CHECK-NEXT: * _d_i += _r0; -// CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_j += _r1; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 4484df6ae..075ba5c03 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -131,16 +131,13 @@ double fn2(Tangent t, double i) { // CHECK-NEXT: { // CHECK-NEXT: res = _t1; // CHECK-NEXT: double _r_d0 = _d_res; -// CHECK-NEXT: clad::array_ref _t2 = {(* _d_t).data, 5UL}; -// CHECK-NEXT: sum_pullback(t.data, _r_d0, _t2); -// CHECK-NEXT: clad::array _r1({(* _d_t).data, 5UL}); +// CHECK-NEXT: sum_pullback(t.data, _r_d0, (* _d_t).data); // CHECK-NEXT: * _d_i += _r_d0; // CHECK-NEXT: (* _d_t).data[0] += 2 * _r_d0; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: t = _t0; // CHECK-NEXT: sum_pullback(_t0, _d_res, &(* _d_t)); -// CHECK-NEXT: Tangent _r0 = (* _d_t); // CHECK-NEXT: } // CHECK-NEXT: } @@ -167,7 +164,6 @@ double fn3(double i, double j) { // CHECK-NEXT: { // CHECK-NEXT: t = _t2; // CHECK-NEXT: sum_pullback(_t2, 1, &_d_t); -// CHECK-NEXT: Tangent _r0 = _d_t; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: t.data[1] = _t1; @@ -245,12 +241,10 @@ double fn5(const Tangent& t, double i) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: _t0.someMemFn2_pullback(i, i, 1, &(* _d_t), &_grad0, &_grad1); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: _t0.someMemFn2_pullback(i, i, 1, &(* _d_t), &_r0, &_r1); // CHECK-NEXT: * _d_i += _r0; -// CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_i += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -320,9 +314,8 @@ double fn6(dcomplex c, double i) { // CHECK-NEXT: * _d_i += 6 * _d_res; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: {{.*}} _grad0 = 0.; -// CHECK-NEXT: _t0.real_pullback(5 * i, &(* _d_c), &_grad0); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _t0.real_pullback(5 * i, &(* _d_c), &_r0); // CHECK-NEXT: * _d_i += 5 * _r0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -355,9 +348,8 @@ double fn7(dcomplex c1, dcomplex c2) { // CHECK-NEXT: _t6.imag_pullback(3 * 1, &(* _d_c1)); // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: {{.*}} _grad0 = 0.; -// CHECK-NEXT: _t3.real_pullback(c2.imag() + 5 * _t1, &(* _d_c1), &_grad0); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _t3.real_pullback(c2.imag() + 5 * _t1, &(* _d_c1), &_r0); // CHECK-NEXT: _t0.imag_pullback(_r0, &(* _d_c2)); // CHECK-NEXT: _t2.real_pullback(5 * _r0, &(* _d_c2)); // CHECK-NEXT: } @@ -401,12 +393,10 @@ double fn8(Tangent t, dcomplex c) { // CHECK-NEXT: { // CHECK-NEXT: t = _t2; // CHECK-NEXT: sum_pullback(_t2, 1, &(* _d_t)); -// CHECK-NEXT: Tangent _r1 = (* _d_t); // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: _t1.updateTo_pullback(c.real(), &(* _d_t), &_grad0); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _t1.updateTo_pullback(c.real(), &(* _d_t), &_r0); // CHECK-NEXT: _t0.real_pullback(_r0, &(* _d_c)); // CHECK-NEXT: } // CHECK-NEXT: } @@ -451,7 +441,6 @@ double fn9(Tangent t, dcomplex c) { // CHECK-NEXT: double _r_d1 = _d_res; // CHECK-NEXT: t = _t6; // CHECK-NEXT: sum_pullback(_t6, _r_d1, &(* _d_t)); -// CHECK-NEXT: Tangent _r2 = (* _d_t); // CHECK-NEXT: } // CHECK-NEXT: for (; _t0; _t0--) { // CHECK-NEXT: --i; diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 7587a86c1..dc48fcc21 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -27,9 +27,11 @@ float f1(float x) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _r0 = _d_y.value * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += _d_y.value * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: _r1 += _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r1; // CHECK-NEXT: * _d_d_x += _t0 * _d_y.pushforward; // CHECK-NEXT: } @@ -41,9 +43,11 @@ float f1(float x) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _r0 = _d_y.value * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += _d_y.value * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = -1 * _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: _r1 += -1 * _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r1; // CHECK-NEXT: * _d_d_x += -1 * _t0 * _d_y.pushforward; // CHECK-NEXT: } @@ -63,21 +67,17 @@ float f1(float x) { // CHECK-NEXT: _d__t1.pushforward += 1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: cos_pushforward_pullback(x, _d_x0, _d__t1, &_grad2, &_grad3); -// CHECK-NEXT: float _r2 = _grad2; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: cos_pushforward_pullback(x, _d_x0, _d__t1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: _d__d_x += _r3; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: sin_pushforward_pullback(x, _d_x0, _d__t0, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: sin_pushforward_pullback(x, _d_x0, _d__t0, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: _d__d_x += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -102,9 +102,11 @@ float f2(float x) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _r0 = _d_y.value * clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += _d_y.value * clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: _r1 += _d_y.pushforward * d_x * clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r1; // CHECK-NEXT: * _d_d_x += _t0 * _d_y.pushforward; // CHECK-NEXT: } @@ -119,12 +121,10 @@ float f2(float x) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: exp_pushforward_pullback(x, _d_x0, _d__t0, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: exp_pushforward_pullback(x, _d_x0, _d__t0, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: _d__d_x += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -148,7 +148,8 @@ float f3(float x) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: float _r0 = _d_y.value * clad::custom_derivatives{{(::std)?}}::log_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: _r0 += _d_y.value * clad::custom_derivatives{{(::std)?}}::log_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r0; // CHECK-NEXT: double _r1 = _d_y.pushforward * d_x * -1. / (x * x); // CHECK-NEXT: * _d_x += _r1; @@ -165,12 +166,10 @@ float f3(float x) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: log_pushforward_pullback(x, _d_x0, _d__t0, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: log_pushforward_pullback(x, _d_x0, _d__t0, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: _d__d_x += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -217,35 +216,30 @@ float f4(float x) { // CHECK-NEXT: if (_cond0) { // CHECK-NEXT: derivative = _t1; // CHECK-NEXT: float _r_d0 = _d_derivative; -// CHECK-NEXT: float _grad4 = 0.F; -// CHECK-NEXT: float _grad5 = 0.F; -// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent, _r_d0 * d_exponent * _t2, &_grad4, &_grad5); -// CHECK-NEXT: float _r4 = _grad4; +// CHECK-NEXT: float _r4 = 0; +// CHECK-NEXT: float _r5 = 0; +// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent, _r_d0 * d_exponent * _t2, &_r4, &_r5); // CHECK-NEXT: * _d_x += _r4; -// CHECK-NEXT: float _r5 = _grad5; // CHECK-NEXT: * _d_exponent += _r5; -// CHECK-NEXT: float _r6 = _t3 * _r_d0 * d_exponent * clad::custom_derivatives{{(::std)?}}::log_pushforward(x, 1.F).pushforward; +// CHECK-NEXT: float _r6 = 0; +// CHECK-NEXT: _r6 += _t3 * _r_d0 * d_exponent * clad::custom_derivatives{{(::std)?}}::log_pushforward(x, 1.F).pushforward; // CHECK-NEXT: * _d_x += _r6; // CHECK-NEXT: * _d_d_exponent += (_t3 * _t2) * _r_d0; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: * _d_exponent += _d_derivative * d_x * _t0; -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent - 1, exponent * _d_derivative * d_x, &_grad2, &_grad3); -// CHECK-NEXT: float _r2 = _grad2; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent - 1, exponent * _d_derivative * d_x, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: * _d_exponent += _r3; // CHECK-NEXT: * _d_d_x += (exponent * _t0) * _d_derivative; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent, _d_val, &_grad0, &_grad1); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(x, exponent, _d_val, &_r0, &_r1); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: * _d_exponent += _r1; // CHECK-NEXT: } // CHECK-NEXT: } @@ -259,17 +253,13 @@ float f4(float x) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: pow_pushforward_pullback(x, 4.F, _d_x0, 0.F, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: pow_pushforward_pullback(x, 4.F, _d_x0, 0.F, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; -// CHECK-NEXT: float _r2 = _grad2; // CHECK-NEXT: _d__d_x += _r2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: } // CHECK-NEXT: } @@ -297,16 +287,12 @@ float f5(float x) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: pow_pushforward_pullback(2.F, x, 0.F, _d_x0, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: float _r0 = _grad0; -// CHECK-NEXT: float _r1 = _grad1; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: pow_pushforward_pullback(2.F, x, 0.F, _d_x0, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r1; -// CHECK-NEXT: float _r2 = _grad2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: _d__d_x += _r3; // CHECK-NEXT: } // CHECK-NEXT: } @@ -338,18 +324,14 @@ float f6(float x, float y) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: pow_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: pow_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: * _d_y += _r1; -// CHECK-NEXT: float _r2 = _grad2; // CHECK-NEXT: _d__d_x += _r2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: _d__d_y += _r3; // CHECK-NEXT: } // CHECK-NEXT: } @@ -372,18 +354,14 @@ float f6(float x, float y) { // CHECK-NEXT: _label0: // CHECK-NEXT: _d__t0.pushforward += 1; // CHECK-NEXT: { -// CHECK-NEXT: float _grad0 = 0.F; -// CHECK-NEXT: float _grad1 = 0.F; -// CHECK-NEXT: float _grad2 = 0.F; -// CHECK-NEXT: float _grad3 = 0.F; -// CHECK-NEXT: pow_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: float _r0 = _grad0; +// CHECK-NEXT: float _r0 = 0; +// CHECK-NEXT: float _r1 = 0; +// CHECK-NEXT: float _r2 = 0; +// CHECK-NEXT: float _r3 = 0; +// CHECK-NEXT: pow_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: float _r1 = _grad1; // CHECK-NEXT: * _d_y += _r1; -// CHECK-NEXT: float _r2 = _grad2; // CHECK-NEXT: _d__d_x += _r2; -// CHECK-NEXT: float _r3 = _grad3; // CHECK-NEXT: _d__d_y += _r3; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Hessian/NestedFunctionCalls.C b/test/Hessian/NestedFunctionCalls.C index dea80f1d0..189a50ad1 100644 --- a/test/Hessian/NestedFunctionCalls.C +++ b/test/Hessian/NestedFunctionCalls.C @@ -67,18 +67,14 @@ double f2(double x, double y){ // CHECK-NEXT: _d__t0.value += _d_ans0; // CHECK-NEXT: _d__t0.pushforward += _d__d_ans; // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: double _grad2 = 0.; -// CHECK-NEXT: double _grad3 = 0.; -// CHECK-NEXT: f_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: double _r2 = 0; +// CHECK-NEXT: double _r3 = 0; +// CHECK-NEXT: f_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_y += _r1; -// CHECK-NEXT: double _r2 = _grad2; // CHECK-NEXT: _d__d_x += _r2; -// CHECK-NEXT: double _r3 = _grad3; // CHECK-NEXT: _d__d_y += _r3; // CHECK-NEXT: } // CHECK-NEXT: } @@ -109,18 +105,14 @@ double f2(double x, double y){ // CHECK-NEXT: _d__t0.value += _d_ans0; // CHECK-NEXT: _d__t0.pushforward += _d__d_ans; // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: double _grad1 = 0.; -// CHECK-NEXT: double _grad2 = 0.; -// CHECK-NEXT: double _grad3 = 0.; -// CHECK-NEXT: f_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_grad0, &_grad1, &_grad2, &_grad3); -// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: double _r2 = 0; +// CHECK-NEXT: double _r3 = 0; +// CHECK-NEXT: f_pushforward_pullback(x, y, _d_x0, _d_y0, _d__t0, &_r0, &_r1, &_r2, &_r3); // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: double _r1 = _grad1; // CHECK-NEXT: * _d_y += _r1; -// CHECK-NEXT: double _r2 = _grad2; // CHECK-NEXT: _d__d_x += _r2; -// CHECK-NEXT: double _r3 = _grad3; // CHECK-NEXT: _d__d_y += _r3; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Jacobian/FunctionCalls.C b/test/Jacobian/FunctionCalls.C index e4d714728..ea3966847 100644 --- a/test/Jacobian/FunctionCalls.C +++ b/test/Jacobian/FunctionCalls.C @@ -20,21 +20,17 @@ void fn1(double i, double j, double* output) { // CHECK-NEXT: output[0] = std::pow(i, j); // CHECK-NEXT: output[1] = std::pow(j, i); // CHECK-NEXT: { -// CHECK-NEXT: double _jac2 = 0.; -// CHECK-NEXT: double _jac3 = 0.; -// CHECK-NEXT: clad::custom_derivatives::pow_pullback(j, i, 1, &_jac2, &_jac3); -// CHECK-NEXT: double _r2 = _jac2; +// CHECK-NEXT: double _r2 = 0; +// CHECK-NEXT: double _r3 = 0; +// CHECK-NEXT: clad::custom_derivatives::pow_pullback(j, i, 1, &_r2, &_r3); // CHECK-NEXT: jacobianMatrix[3UL] += _r2; -// CHECK-NEXT: double _r3 = _jac3; // CHECK-NEXT: jacobianMatrix[2UL] += _r3; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _jac0 = 0.; -// CHECK-NEXT: double _jac1 = 0.; -// CHECK-NEXT: clad::custom_derivatives::pow_pullback(i, j, 1, &_jac0, &_jac1); -// CHECK-NEXT: double _r0 = _jac0; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: clad::custom_derivatives::pow_pullback(i, j, 1, &_r0, &_r1); // CHECK-NEXT: jacobianMatrix[0UL] += _r0; -// CHECK-NEXT: double _r1 = _jac1; // CHECK-NEXT: jacobianMatrix[1UL] += _r1; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Jacobian/Jacobian.C b/test/Jacobian/Jacobian.C index f7862e56c..a3902a88b 100644 --- a/test/Jacobian/Jacobian.C +++ b/test/Jacobian/Jacobian.C @@ -64,15 +64,18 @@ void f_3_jac(double x, double y, double z, double *_result, double *jacobianMatr //CHECK-NEXT: _t2 = sin(z); //CHECK-NEXT: _result[2] = sin(z) * constant; //CHECK-NEXT: { -//CHECK-NEXT: double _r2 = 1 * constant * clad::custom_derivatives::sin_pushforward(z, 1.).pushforward; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: _r2 += 1 * constant * clad::custom_derivatives::sin_pushforward(z, 1.).pushforward; //CHECK-NEXT: jacobianMatrix[8UL] += _r2; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _r1 = 1 * constant * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: _r1 += 1 * constant * clad::custom_derivatives::sin_pushforward(y, 1.).pushforward; //CHECK-NEXT: jacobianMatrix[4UL] += _r1; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _r0 = 1 * constant * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += 1 * constant * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; //CHECK-NEXT: jacobianMatrix[0UL] += _r0; //CHECK-NEXT: } //CHECK-NEXT:} @@ -109,30 +112,24 @@ void f_4_jac(double x, double y, double z, double *_result, double *jacobianMatr //CHECK-NEXT: _t2 = multiply(z, x); //CHECK-NEXT: _result[2] = multiply(z, x) * constant; //CHECK-NEXT: { -//CHECK-NEXT: double _jac4 = 0.; -//CHECK-NEXT: double _jac5 = 0.; -//CHECK-NEXT: multiply_pullback(z, x, 1 * constant, &_jac4, &_jac5); -//CHECK-NEXT: double _r4 = _jac4; +//CHECK-NEXT: double _r4 = 0; +//CHECK-NEXT: double _r5 = 0; +//CHECK-NEXT: multiply_pullback(z, x, 1 * constant, &_r4, &_r5); //CHECK-NEXT: jacobianMatrix[8UL] += _r4; -//CHECK-NEXT: double _r5 = _jac5; //CHECK-NEXT: jacobianMatrix[6UL] += _r5; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _jac2 = 0.; -//CHECK-NEXT: double _jac3 = 0.; -//CHECK-NEXT: multiply_pullback(y, z, 1 * constant, &_jac2, &_jac3); -//CHECK-NEXT: double _r2 = _jac2; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: double _r3 = 0; +//CHECK-NEXT: multiply_pullback(y, z, 1 * constant, &_r2, &_r3); //CHECK-NEXT: jacobianMatrix[4UL] += _r2; -//CHECK-NEXT: double _r3 = _jac3; //CHECK-NEXT: jacobianMatrix[5UL] += _r3; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _jac0 = 0.; -//CHECK-NEXT: double _jac1 = 0.; -//CHECK-NEXT: multiply_pullback(x, y, 1 * constant, &_jac0, &_jac1); -//CHECK-NEXT: double _r0 = _jac0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: multiply_pullback(x, y, 1 * constant, &_r0, &_r1); //CHECK-NEXT: jacobianMatrix[0UL] += _r0; -//CHECK-NEXT: double _r1 = _jac1; //CHECK-NEXT: jacobianMatrix[1UL] += _r1; //CHECK-NEXT: } //CHECK-NEXT:} diff --git a/test/Misc/RunDemos.C b/test/Misc/RunDemos.C index 24dbf467c..f9a774e65 100644 --- a/test/Misc/RunDemos.C +++ b/test/Misc/RunDemos.C @@ -236,15 +236,12 @@ //CHECK_GRADIENT_DESCENT-NEXT: * _d_y += -(f_x - y) * 1; //CHECK_GRADIENT_DESCENT-NEXT: } //CHECK_GRADIENT_DESCENT-NEXT: { -//CHECK_GRADIENT_DESCENT-NEXT: double _grad0 = 0.; -//CHECK_GRADIENT_DESCENT-NEXT: double _grad1 = 0.; -//CHECK_GRADIENT_DESCENT-NEXT: double _grad2 = 0.; -//CHECK_GRADIENT_DESCENT-NEXT: f_pullback(theta_0, theta_1, x, _d_f_x, &_grad0, &_grad1, &_grad2); -//CHECK_GRADIENT_DESCENT-NEXT: double _r0 = _grad0; +//CHECK_GRADIENT_DESCENT-NEXT: double _r0 = 0; +//CHECK_GRADIENT_DESCENT-NEXT: double _r1 = 0; +//CHECK_GRADIENT_DESCENT-NEXT: double _r2 = 0; +//CHECK_GRADIENT_DESCENT-NEXT: f_pullback(theta_0, theta_1, x, _d_f_x, &_r0, &_r1, &_r2); //CHECK_GRADIENT_DESCENT-NEXT: * _d_theta_0 += _r0; -//CHECK_GRADIENT_DESCENT-NEXT: double _r1 = _grad1; //CHECK_GRADIENT_DESCENT-NEXT: * _d_theta_1 += _r1; -//CHECK_GRADIENT_DESCENT-NEXT: double _r2 = _grad2; //CHECK_GRADIENT_DESCENT-NEXT: * _d_x += _r2; //CHECK_GRADIENT_DESCENT-NEXT: } //CHECK_GRADIENT_DESCENT-NEXT: } diff --git a/test/NestedCalls/NestedCalls.C b/test/NestedCalls/NestedCalls.C index 9bbf831b0..b0b90c01b 100644 --- a/test/NestedCalls/NestedCalls.C +++ b/test/NestedCalls/NestedCalls.C @@ -52,15 +52,15 @@ double f(double x, double y) { //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: sq_pullback(std::sin(x), _d_y, &_grad0); -//CHECK-NEXT: double _r0 = _grad0; -//CHECK-NEXT: double _r1 = _r0 * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: sq_pullback(std::sin(x), _d_y, &_r0); +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: _r1 += _r0 * clad::custom_derivatives::sin_pushforward(x, 1.).pushforward; //CHECK-NEXT: * _d_x += _r1; -//CHECK-NEXT: double _grad1 = 0.; -//CHECK-NEXT: sq_pullback(std::cos(x), _d_y, &_grad1); -//CHECK-NEXT: double _r2 = _grad1; -//CHECK-NEXT: double _r3 = _r2 * clad::custom_derivatives::cos_pushforward(x, 1.).pushforward; +//CHECK-NEXT: double _r2 = 0; +//CHECK-NEXT: sq_pullback(std::cos(x), _d_y, &_r2); +//CHECK-NEXT: double _r3 = 0; +//CHECK-NEXT: _r3 += _r2 * clad::custom_derivatives::cos_pushforward(x, 1.).pushforward; //CHECK-NEXT: * _d_x += _r3; //CHECK-NEXT: } //CHECK-NEXT: } @@ -75,9 +75,8 @@ double f(double x, double y) { //CHECK-NEXT: * _d_y += t * 1; //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: one_pullback(x, _d_t, &_grad0); -//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: one_pullback(x, _d_t, &_r0); //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/NumericalDiff/GradientMultiArg.C b/test/NumericalDiff/GradientMultiArg.C index 5270965d0..a7a831bee 100644 --- a/test/NumericalDiff/GradientMultiArg.C +++ b/test/NumericalDiff/GradientMultiArg.C @@ -18,15 +18,17 @@ double test_1(double x, double y){ // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: double _grad0 = 0.; -// CHECK-NEXT: double _grad1 = 0.; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; // CHECK-NEXT: clad::tape > _t0 = {}; +// CHECK-NEXT: double _grad0 = 0; // CHECK-NEXT: clad::push(_t0, &_grad0); +// CHECK-NEXT: double _grad1 = 0; // CHECK-NEXT: clad::push(_t0, &_grad1); // CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, y); -// CHECK-NEXT: double _r0 = 1 * _grad0; +// CHECK-NEXT: _r0 += 1 * _grad0; +// CHECK-NEXT: _r1 += 1 * _grad1; // CHECK-NEXT: * _d_x += _r0; -// CHECK-NEXT: double _r1 = 1 * _grad1; // CHECK-NEXT: * _d_y += _r1; // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/NumericalDiff/NoNumDiff.C b/test/NumericalDiff/NoNumDiff.C index 76749de86..10007647a 100644 --- a/test/NumericalDiff/NoNumDiff.C +++ b/test/NumericalDiff/NoNumDiff.C @@ -19,8 +19,7 @@ double func(double x) { return std::tanh(x); } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _r0; +//CHECK-NEXT: double _r0 = 0; //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/NumericalDiff/NumDiff.C b/test/NumericalDiff/NumDiff.C index 2860ab4ab..2542cb23b 100644 --- a/test/NumericalDiff/NumDiff.C +++ b/test/NumericalDiff/NumDiff.C @@ -16,7 +16,8 @@ double test_1(double x){ //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _r0 = 1 * numerical_diff::forward_central_difference(tanh, x, 0, 0, x); +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += 1 * numerical_diff::forward_central_difference(tanh, x, 0, 0, x); //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: } @@ -53,15 +54,17 @@ double test_3(double x) { //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _grad0 = 0.; -//CHECK-NEXT: double _grad1 = 0.; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; //CHECK-NEXT: clad::tape > _t0 = {}; +//CHECK-NEXT: double _grad0 = 0; //CHECK-NEXT: clad::push(_t0, &_grad0); +//CHECK-NEXT: double _grad1 = 0; //CHECK-NEXT: clad::push(_t0, &_grad1); //CHECK-NEXT: numerical_diff::central_difference(std::hypot, _t0, 0, x, constant); -//CHECK-NEXT: double _r0 = 1 * _grad0; +//CHECK-NEXT: _r0 += 1 * _grad0; +//CHECK-NEXT: _r1 += 1 * _grad1; //CHECK-NEXT: * _d_x += _r0; -//CHECK-NEXT: double _r1 = 1 * _grad1; //CHECK-NEXT: _d_constant += _r1; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/NumericalDiff/PrintErrorNumDiff.C b/test/NumericalDiff/PrintErrorNumDiff.C index 4d1cf069e..b1d74f574 100644 --- a/test/NumericalDiff/PrintErrorNumDiff.C +++ b/test/NumericalDiff/PrintErrorNumDiff.C @@ -20,7 +20,8 @@ double test_1(double x){ //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: { -//CHECK-NEXT: double _r0 = 1 * numerical_diff::forward_central_difference(tanh, x, 0, 1, x); +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += 1 * numerical_diff::forward_central_difference(tanh, x, 0, 1, x); //CHECK-NEXT: * _d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: } diff --git a/test/ROOT/TFormula.C b/test/ROOT/TFormula.C index e9ff33663..ffd17bade 100644 --- a/test/ROOT/TFormula.C +++ b/test/ROOT/TFormula.C @@ -47,9 +47,11 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p); //CHECK-NEXT: _d_p[0] += x[0] * 1; //CHECK-NEXT: _d_p[1] += x[0] * 1; //CHECK-NEXT: _d_p[2] += x[0] * 1; -//CHECK-NEXT: Double_t _r0 = 1 * clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], 1.).pushforward; +//CHECK-NEXT: Double_t _r0 = 0; +//CHECK-NEXT: _r0 += 1 * clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], 1.).pushforward; //CHECK-NEXT: _d_p[0] += -_r0; -//CHECK-NEXT: Double_t _r1 = 1 * clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.).pushforward; +//CHECK-NEXT: Double_t _r1 = 0; +//CHECK-NEXT: _r1 += 1 * clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.).pushforward; //CHECK-NEXT: _d_p[1] += _r1; //CHECK-NEXT: } //CHECK-NEXT: } From 76246b7360e09327d6cd3e08a65b547f2e61eb4f Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 12:51:23 +0200 Subject: [PATCH 02/10] Expect a SourceLocation assertion in ArrayInputsReverseMode.C --- test/Arrays/ArrayInputsReverseMode.C | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Arrays/ArrayInputsReverseMode.C b/test/Arrays/ArrayInputsReverseMode.C index 50ca4a889..bd7728641 100644 --- a/test/Arrays/ArrayInputsReverseMode.C +++ b/test/Arrays/ArrayInputsReverseMode.C @@ -4,6 +4,7 @@ // RUN: ./ArrayInputsReverseMode.out | FileCheck -check-prefix=CHECK-EXEC %s //CHECK-NOT: {{.*error|warning|note:.*}} +// XFAIL: asserts // XFAIL: target={{i586.*}} #include "clad/Differentiator/Differentiator.h" From f510256d4c9ecfaef932a411a16941dcfa3e4ab8 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 13:20:10 +0200 Subject: [PATCH 03/10] Don't expect a SourceLocation assertion in UserDefinedTypes.C --- test/Gradient/UserDefinedTypes.C | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 075ba5c03..40244ee73 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -2,7 +2,6 @@ // RUN: ./UserDefinedTypes.out | FileCheck -check-prefix=CHECK-EXEC %s // RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out // RUN: ./UserDefinedTypes.out | FileCheck -check-prefix=CHECK-EXEC %s -// XFAIL: asserts // CHECK-NOT: {{.*error|warning|note:.*}} // XFAIL: target={{i586.*}} #include "clad/Differentiator/Differentiator.h" From bed0fbfe83bf1f2c5d42b6335d92db508e4b1af7 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 13:33:20 +0200 Subject: [PATCH 04/10] Remove pointer arithmetic from VisitCallExpr by replacing pointers with iterators --- lib/Differentiator/ReverseModeVisitor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index acae93e56..9471e898b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1868,7 +1868,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Put Result array declaration in the function body. // Call the gradient, passing Result as the last Arg. Stmts& block = getCurrentBlock(direction::reverse); - Stmt** it = std::begin(block) + insertionPoint; + Stmts::iterator it = std::begin(block) + insertionPoint; // Insert PreCallStmts it = block.insert(it, PreCallStmts.begin(), PreCallStmts.end()); it += PreCallStmts.size(); @@ -1887,7 +1887,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, it++; } // Insert PostCallStmts - it = block.insert(it, PostCallStmts.begin(), PostCallStmts.end()); + block.insert(it, PostCallStmts.begin(), PostCallStmts.end()); } if (m_ExternalSource) m_ExternalSource->ActBeforeFinalizingVisitCallExpr( From 496e5e6f317f2d89faadbce5b3118224cf869725 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 23:29:20 +0200 Subject: [PATCH 05/10] Address clang-tidy complaints --- lib/Differentiator/VisitorBase.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 32a5e99a2..86ea4e9cc 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -799,7 +799,7 @@ namespace clad { CXXScopeSpec CSS; CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); LookupResult& Push = GetCladTapePush(); - auto PushDRE = + Expr* PushDRE = m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); for (unsigned i = 0, e = numArgs; i < e; i++) { QualType argTy = args[i]->getType(); @@ -808,7 +808,7 @@ namespace clad { Expr* PushExpr = BuildDeclRef(gradVar); if (!isCladArrayType(argTy)) PushExpr = BuildOp(UO_AddrOf, PushExpr); - Expr* callArgs[] = {TapeRef, PushExpr}; + std::array callArgs = {TapeRef, PushExpr}; Stmt* PushStmt = m_Sema .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) From 9648f438ba32deda79871e147dc63e9a33285db8 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Mar 2024 23:54:16 +0200 Subject: [PATCH 06/10] Improve test coverage of VisitCallExpr --- test/Gradient/FunctionCalls.C | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 0af0c6ccd..013dbc7b7 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -696,6 +696,25 @@ double fn14(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn15(double x, double y) { + A::constantFn(y += x); + return y; +} + +//CHECK: void fn15_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +//CHECK-NEXT: double _t0; +//CHECK-NEXT: _t0 = y; +//CHECK-NEXT: A::constantFn(y += x); +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: * _d_y += 1; +//CHECK-NEXT: { +//CHECK-NEXT: y = _t0; +//CHECK-NEXT: double _r_d0 = * _d_y; +//CHECK-NEXT: * _d_x += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i Date: Thu, 14 Mar 2024 01:30:18 +0200 Subject: [PATCH 07/10] Fix recursive call differentiation and add a test --- lib/Differentiator/ReverseModeVisitor.cpp | 12 +++--- test/Gradient/FunctionCalls.C | 46 +++++++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9471e898b..166b90f5f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1790,18 +1790,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!OverloadedDerivedFn) { if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { // Recursive call. - auto* selfRef = + Expr* selfRef = m_Sema .BuildDeclarationNameExpr( CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) .get(); - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, noLoc, - llvm::MutableArrayRef(DerivedCallArgs), - noLoc) - .get(); + OverloadedDerivedFn = m_Sema + .ActOnCallExpr(getCurrentScope(), selfRef, + noLoc, pullbackCallArgs, noLoc) + .get(); } else { if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 013dbc7b7..03fb4045b 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -715,6 +715,50 @@ double fn15(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double recFun (double x, double y) { + if (x > y) + return recFun(x-1, y); + return x * y; +} + +//CHECK: void recFun_pullback(double x, double y, double _d_y0, clad::array_ref _d_x, clad::array_ref _d_y) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: _cond0 = x > y; +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: goto _label1; +//CHECK-NEXT: _label1: +//CHECK-NEXT: { +//CHECK-NEXT: * _d_x += _d_y0 * y; +//CHECK-NEXT: * _d_y += x * _d_y0; +//CHECK-NEXT: } +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: recFun_pullback(x - 1, y, _d_y0, &_r0, &_r1); +//CHECK-NEXT: * _d_x += _r0; +//CHECK-NEXT: * _d_y += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn16(double x, double y) { + return recFun(x, y); +} + +//CHECK: void fn16_grad(double x, double y, clad::array_ref _d_x, clad::array_ref _d_y) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: recFun_pullback(x, y, 1, &_r0, &_r1); +//CHECK-NEXT: * _d_x += _r0; +//CHECK-NEXT: * _d_y += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i Date: Thu, 14 Mar 2024 12:16:30 +0200 Subject: [PATCH 08/10] Move the checks of call arguments for being temporary exprs to a routine --- include/clad/Differentiator/CladUtils.h | 9 +++++---- lib/Differentiator/CladUtils.cpp | 8 ++++++-- lib/Differentiator/ReverseModeVisitor.cpp | 7 ++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 56519593d..ffc90be12 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -174,11 +174,12 @@ namespace clad { /// otherwise returns false. bool HasAnyReferenceOrPointerArgument(const clang::FunctionDecl* FD); - /// Returns true if `T` is a reference, pointer or array type. + /// Returns true if `arg` is an argument passed by reference or is of + /// pointer/array type. /// - /// \note Please note that this function returns true for array types as - /// well. - bool IsReferenceOrPointerType(clang::QualType T); + /// \note Please note that this function returns false for temporary + /// expressions. + bool IsReferenceOrPointerArg(const clang::Expr* arg); /// Returns true if `T1` and `T2` have same cononical type; otherwise /// returns false. diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index d7d599ce5..6144ae23a 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -301,8 +301,12 @@ namespace clad { return false; } - bool IsReferenceOrPointerType(QualType T) { - return T->isReferenceType() || isArrayOrPointerType(T); + bool IsReferenceOrPointerArg(const Expr* arg) { + // The argument is passed by reference if it's passed as an L-value. + // However, if arg is a MaterializeTemporaryExpr, then arg is a + // temporary variable passed as a const reference. + bool isRefType = arg->isLValue() && !isa(arg); + return isRefType || isArrayOrPointerType(arg->getType()); } bool SameCanonicalType(clang::QualType T1, clang::QualType T2) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 166b90f5f..1962971e1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1537,8 +1537,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerType(PVD->getType()) && - !isa(arg)) { + if (utils::IsReferenceOrPointerArg(arg)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1723,9 +1722,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; const Expr* arg = CE->getArg(idx); - const auto* PVD = FD->getParamDecl(idx); - if (utils::IsReferenceOrPointerType(PVD->getType()) && - !isa(arg)) { + if (utils::IsReferenceOrPointerArg(arg)) { if (argDerivative) { if (utils::isArrayOrPointerType(argDerivative->getType()) || isCladArrayType(argDerivative->getType()) || From 453043e91fafa1c7befdb2ee5277505b22e8595b Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 14 Mar 2024 13:08:44 +0200 Subject: [PATCH 09/10] Remove redundant checks from VisitCallExpr --- lib/Differentiator/ReverseModeVisitor.cpp | 26 +++++------------------ 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 1962971e1..0038a6bbd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1722,21 +1722,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; const Expr* arg = CE->getArg(idx); - if (utils::IsReferenceOrPointerArg(arg)) { - if (argDerivative) { - if (utils::isArrayOrPointerType(argDerivative->getType()) || - isCladArrayType(argDerivative->getType()) || - !argDerivative->isLValue()) - gradArgExpr = argDerivative; - else - gradArgExpr = - BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); - } - } else { - Expr* gradVarExpr = CallArgDx[idx]; + if (utils::isArrayOrPointerType(arg->getType()) || + isCladArrayType(argDerivative->getType())) + gradArgExpr = argDerivative; + else gradArgExpr = - BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation()); - } + BuildOp(UO_AddrOf, argDerivative, m_Function->getLocation()); DerivedCallOutputArgs.push_back(gradArgExpr); idx++; } @@ -1745,13 +1736,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); - // FIXME: Remove this restriction. - if (!FD->getReturnType()->isVoidType()) { - assert((pullback && !FD->getReturnType()->isVoidType()) && - "Call to function returning non-void type with no dfdx() is not " - "supported!"); - } - if (FD->getReturnType()->isVoidType()) { assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any " From b340d05302d270d1b5afc410e7310fb97f7c5d08 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 18 Mar 2024 09:33:19 +0200 Subject: [PATCH 10/10] Move GetMultiArgCentralDiffCall to ReverseModeVisitor --- .../clad/Differentiator/ReverseModeVisitor.h | 22 ++++++++ include/clad/Differentiator/VisitorBase.h | 21 -------- lib/Differentiator/ReverseModeVisitor.cpp | 53 ++++++++++++++++++ lib/Differentiator/VisitorBase.cpp | 54 ------------------- 4 files changed, 75 insertions(+), 75 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 10564dbbf..6c7d1fe71 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -315,6 +315,28 @@ namespace clad { CladTapeResult MakeCladTapeFor(clang::Expr* E, llvm::StringRef prefix = "_t"); + /// A function to get the multi-argument "central_difference" + /// call expression for the given arguments. + /// + /// \param[in] targetFuncCall The function to get the derivative for. + /// \param[in] retType The return type of the target call expression. + /// \param[in] dfdx The dfdx corresponding to this call expression. + /// \param[in] numArgs The total number of 'args'. + /// \param[in] PreCallStmts The built statements to add to block + /// before the call to the derived function. + /// \param[in] PostCallStmts The built statements to add to block + /// after the call to the derived function. + /// \param[in] args All the arguments to the target function. + /// \param[in] outputArgs The output gradient arguments. + /// + /// \returns The derivative function call. + clang::Expr* GetMultiArgCentralDiffCall( + clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, + clang::Expr* dfdx, llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, + llvm::SmallVectorImpl& args, + llvm::SmallVectorImpl& outputArgs); + public: ReverseModeVisitor(DerivativeBuilder& builder); virtual ~ReverseModeVisitor(); diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index a588b2b97..9feeb6624 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -585,27 +585,6 @@ namespace clad { clang::Expr* GetSingleArgCentralDiffCall( clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos, unsigned numArgs, llvm::SmallVectorImpl& args); - /// A function to get the multi-argument "central_difference" - /// call expression for the given arguments. - /// - /// \param[in] targetFuncCall The function to get the derivative for. - /// \param[in] retType The return type of the target call expression. - /// \param[in] dfdx The dfdx corresponding to this call expression. - /// \param[in] numArgs The total number of 'args'. - /// \param[in] PreCallStmts The built statements to add to block - /// before the call to the derived function. - /// \param[in] PostCallStmts The built statements to add to block - /// after the call to the derived function. - /// \param[in] args All the arguments to the target function. - /// \param[in] outputArgs The output gradient arguments. - /// - /// \returns The derivative function call. - clang::Expr* GetMultiArgCentralDiffCall( - clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, - clang::Expr* dfdx, llvm::SmallVectorImpl& PreCallStmts, - llvm::SmallVectorImpl& PostCallStmts, - llvm::SmallVectorImpl& args, - llvm::SmallVectorImpl& outputArgs); /// Emits diagnostic messages on differentiation (or lack thereof) for /// call expressions. /// diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0038a6bbd..ed8683483 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1957,6 +1957,59 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {}; } + Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall( + Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx, + llvm::SmallVectorImpl& PreCallStmts, + llvm::SmallVectorImpl& PostCallStmts, + llvm::SmallVectorImpl& args, + llvm::SmallVectorImpl& outputArgs) { + int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); + llvm::SmallVector NumDiffArgs = {}; + NumDiffArgs.push_back(targetFuncCall); + // build the clad::tape> = {}; + QualType RefType = GetCladArrayRefOfType(retType); + QualType TapeType = GetCladTapeOfType(RefType); + auto* VD = BuildVarDecl( + TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, + /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); + PreCallStmts.push_back(BuildDeclStmt(VD)); + Expr* TapeRef = BuildDeclRef(VD); + NumDiffArgs.push_back(TapeRef); + NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, printErrorInf)); + + // Build the tape push expressions. + VD->setLocation(m_Function->getLocation()); + m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false); + CXXScopeSpec CSS; + CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); + LookupResult& Push = GetCladTapePush(); + Expr* PushDRE = + m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); + for (unsigned i = 0, e = numArgs; i < e; i++) { + QualType argTy = args[i]->getType(); + VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); + PreCallStmts.push_back(BuildDeclStmt(gradVar)); + Expr* PushExpr = BuildDeclRef(gradVar); + if (!isCladArrayType(argTy)) + PushExpr = BuildOp(UO_AddrOf, PushExpr); + std::array callArgs = {TapeRef, PushExpr}; + Stmt* PushStmt = + m_Sema + .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) + .get(); + PreCallStmts.push_back(PushStmt); + Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); + PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); + NumDiffArgs.push_back(args[i]); + } + std::string Name = "central_difference"; + return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, + /*forCustomDerv=*/false, + /*namespaceShouldExist=*/false); + } + StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { auto opCode = UnOp->getOpcode(); Expr* valueForRevPass = nullptr; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 86ea4e9cc..c242bad2d 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -771,60 +771,6 @@ namespace clad { /*namespaceShouldExist=*/false); } - Expr* VisitorBase::GetMultiArgCentralDiffCall( - Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx, - llvm::SmallVectorImpl& PreCallStmts, - llvm::SmallVectorImpl& PostCallStmts, - llvm::SmallVectorImpl& args, - llvm::SmallVectorImpl& outputArgs) { - int printErrorInf = m_Builder.shouldPrintNumDiffErrs(); - llvm::SmallVector NumDiffArgs = {}; - NumDiffArgs.push_back(targetFuncCall); - // build the clad::tape> = {}; - QualType RefType = GetCladArrayRefOfType(retType); - QualType TapeType = GetCladTapeOfType(RefType); - auto VD = BuildVarDecl( - TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false, - /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); - PreCallStmts.push_back(BuildDeclStmt(VD)); - Expr* TapeRef = BuildDeclRef(VD); - NumDiffArgs.push_back(TapeRef); - NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(m_Context.IntTy, - m_Context, - printErrorInf)); - - // Build the tape push expressions. - VD->setLocation(m_Function->getLocation()); - m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false); - CXXScopeSpec CSS; - CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); - LookupResult& Push = GetCladTapePush(); - Expr* PushDRE = - m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get(); - for (unsigned i = 0, e = numArgs; i < e; i++) { - QualType argTy = args[i]->getType(); - VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy)); - PreCallStmts.push_back(BuildDeclStmt(gradVar)); - Expr* PushExpr = BuildDeclRef(gradVar); - if (!isCladArrayType(argTy)) - PushExpr = BuildOp(UO_AddrOf, PushExpr); - std::array callArgs = {TapeRef, PushExpr}; - Stmt* PushStmt = - m_Sema - .ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc) - .get(); - PreCallStmts.push_back(PushStmt); - Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar)); - PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr)); - NumDiffArgs.push_back(args[i]); - } - std::string Name = "central_difference"; - return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr, - /*forCustomDerv=*/false, - /*namespaceShouldExist=*/false); - } - void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName, SourceLocation srcLoc, bool isDerived){ if (!isDerived) {