Skip to content

Commit

Permalink
Move GetMultiArgCentralDiffCall to ReverseModeVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Mar 18, 2024
1 parent 453043e commit b340d05
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 75 deletions.
22 changes: 22 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,28 @@ namespace clad {
CladTapeResult MakeCladTapeFor(clang::Expr* E,
llvm::StringRef prefix = "_t");

/// A function to get the multi-argument "central_difference"
/// call expression for the given arguments.
///
/// \param[in] targetFuncCall The function to get the derivative for.
/// \param[in] retType The return type of the target call expression.
/// \param[in] dfdx The dfdx corresponding to this call expression.
/// \param[in] numArgs The total number of 'args'.
/// \param[in] PreCallStmts The built statements to add to block
/// before the call to the derived function.
/// \param[in] PostCallStmts The built statements to add to block
/// after the call to the derived function.
/// \param[in] args All the arguments to the target function.
/// \param[in] outputArgs The output gradient arguments.
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);

public:
ReverseModeVisitor(DerivativeBuilder& builder);
virtual ~ReverseModeVisitor();
Expand Down
21 changes: 0 additions & 21 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,27 +585,6 @@ namespace clad {
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);
/// A function to get the multi-argument "central_difference"
/// call expression for the given arguments.
///
/// \param[in] targetFuncCall The function to get the derivative for.
/// \param[in] retType The return type of the target call expression.
/// \param[in] dfdx The dfdx corresponding to this call expression.
/// \param[in] numArgs The total number of 'args'.
/// \param[in] PreCallStmts The built statements to add to block
/// before the call to the derived function.
/// \param[in] PostCallStmts The built statements to add to block
/// after the call to the derived function.
/// \param[in] args All the arguments to the target function.
/// \param[in] outputArgs The output gradient arguments.
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
///
Expand Down
53 changes: 53 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,59 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {};
}

Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall(
Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx,
llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<Expr*>& args,
llvm::SmallVectorImpl<Expr*>& outputArgs) {
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
// build the clad::tape<clad::array_ref>> = {};
QualType RefType = GetCladArrayRefOfType(retType);
QualType TapeType = GetCladTapeOfType(RefType);
auto* VD = BuildVarDecl(
TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false,
/*TSI=*/nullptr, VarDecl::InitializationStyle::CInit);
PreCallStmts.push_back(BuildDeclStmt(VD));
Expr* TapeRef = BuildDeclRef(VD);
NumDiffArgs.push_back(TapeRef);
NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, printErrorInf));

// Build the tape push expressions.
VD->setLocation(m_Function->getLocation());
m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false);
CXXScopeSpec CSS;
CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc);
LookupResult& Push = GetCladTapePush();
Expr* PushDRE =
m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get();
for (unsigned i = 0, e = numArgs; i < e; i++) {
QualType argTy = args[i]->getType();
VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy));
PreCallStmts.push_back(BuildDeclStmt(gradVar));
Expr* PushExpr = BuildDeclRef(gradVar);
if (!isCladArrayType(argTy))
PushExpr = BuildOp(UO_AddrOf, PushExpr);
std::array<Expr*, 2> callArgs = {TapeRef, PushExpr};
Stmt* PushStmt =
m_Sema
.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc)
.get();
PreCallStmts.push_back(PushStmt);
Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar));
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}

StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
auto opCode = UnOp->getOpcode();
Expr* valueForRevPass = nullptr;
Expand Down
54 changes: 0 additions & 54 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,60 +771,6 @@ namespace clad {
/*namespaceShouldExist=*/false);
}

Expr* VisitorBase::GetMultiArgCentralDiffCall(
Expr* targetFuncCall, QualType retType, unsigned numArgs, Expr* dfdx,
llvm::SmallVectorImpl<Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<Expr*>& args,
llvm::SmallVectorImpl<Expr*>& outputArgs) {
int printErrorInf = m_Builder.shouldPrintNumDiffErrs();
llvm::SmallVector<Expr*, 16U> NumDiffArgs = {};
NumDiffArgs.push_back(targetFuncCall);
// build the clad::tape<clad::array_ref>> = {};
QualType RefType = GetCladArrayRefOfType(retType);
QualType TapeType = GetCladTapeOfType(RefType);
auto VD = BuildVarDecl(
TapeType, "_t", getZeroInit(TapeType), /*DirectInit=*/false,
/*TSI=*/nullptr, VarDecl::InitializationStyle::CInit);
PreCallStmts.push_back(BuildDeclStmt(VD));
Expr* TapeRef = BuildDeclRef(VD);
NumDiffArgs.push_back(TapeRef);
NumDiffArgs.push_back(ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context,
printErrorInf));

// Build the tape push expressions.
VD->setLocation(m_Function->getLocation());
m_Sema.AddInitializerToDecl(VD, getZeroInit(TapeType), false);
CXXScopeSpec CSS;
CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc);
LookupResult& Push = GetCladTapePush();
Expr* PushDRE =
m_Sema.BuildDeclarationNameExpr(CSS, Push, /*ADL*/ false).get();
for (unsigned i = 0, e = numArgs; i < e; i++) {
QualType argTy = args[i]->getType();
VarDecl* gradVar = BuildVarDecl(argTy, "_grad", getZeroInit(argTy));
PreCallStmts.push_back(BuildDeclStmt(gradVar));
Expr* PushExpr = BuildDeclRef(gradVar);
if (!isCladArrayType(argTy))
PushExpr = BuildOp(UO_AddrOf, PushExpr);
std::array<Expr*, 2> callArgs = {TapeRef, PushExpr};
Stmt* PushStmt =
m_Sema
.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, callArgs, noLoc)
.get();
PreCallStmts.push_back(PushStmt);
Expr* gradExpr = BuildOp(BO_Mul, dfdx, BuildDeclRef(gradVar));
PostCallStmts.push_back(BuildOp(BO_AddAssign, outputArgs[i], gradExpr));
NumDiffArgs.push_back(args[i]);
}
std::string Name = "central_difference";
return m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
Name, NumDiffArgs, getCurrentScope(), /*OriginalFnDC=*/nullptr,
/*forCustomDerv=*/false,
/*namespaceShouldExist=*/false);
}

void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName,
SourceLocation srcLoc, bool isDerived){
if (!isDerived) {
Expand Down

0 comments on commit b340d05

Please sign in to comment.