From 9461931d68a1024cecdc5f5ab28fe51bf04922fb Mon Sep 17 00:00:00 2001 From: ToshitJain Date: Sun, 22 Dec 2024 22:52:59 +0530 Subject: [PATCH] Fix for Parenthesis issue in gradients --- include/clad/Differentiator/VisitorBase.h | 9 ++++++++- lib/Differentiator/VisitorBase.cpp | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 0abd868ca..ab6ed2b4c 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -275,7 +275,14 @@ namespace clad { /// either LHS or RHS is null. clang::Expr* BuildOp(clang::BinaryOperatorKind OpCode, clang::Expr* L, clang::Expr* R, clang::SourceLocation OpLoc = noLoc); - + /// Recursive Function to resolve Unary Minus. If the leftmost operand + /// has a Unary Minus then does not add the operator else adds. + /// Also resolves multiple Unary Minus. Example : (-(-(-2))) gives -2. + /// \param[in] E Expression fed to the recursive call. + /// \param[in] OpLoc Location to add Unary Minus if needed. + /// \returns Expression with correct Unary Operator placement. + clang::Expr* RemoveFirstUnaryMinus(clang::Expr* E, + clang::SourceLocation OpLoc); clang::Expr* BuildParens(clang::Expr* E); /// Builds variable declaration to be used inside the derivative /// body. diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 408dd9f91..dea78ddf0 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -400,8 +400,31 @@ namespace clad { // Debug clang requires the location to be valid if (!OpLoc.isValid()) OpLoc = utils::GetValidSLoc(m_Sema); + // Call function for UnaryMinus + if (OpCode == UO_Minus) + return RemoveFirstUnaryMinus(E->IgnoreCasts(), OpLoc); return m_Sema.BuildUnaryOp(nullptr, OpLoc, OpCode, E).get(); } + Expr* VisitorBase::RemoveFirstUnaryMinus(Expr* E, SourceLocation OpLoc) { + if (auto* UO = llvm::dyn_cast(E)) { + if (UO->getOpcode() == UO_Minus) + return UO->getSubExpr(); + } + if (auto* BO = llvm::dyn_cast(E)) { + if (BO->getOpcode() == BO_Mul || BO->getOpcode() == BO_Div) { + Expr* LHS = BO->getLHS(); + Expr* RHS = BO->getRHS(); + Expr* LHSModified = RemoveFirstUnaryMinus(LHS->IgnoreCasts(), OpLoc); + SourceLocation BO_loc = BO->getExprLoc(); + return m_Sema + .BuildBinOp(nullptr, BO_loc, BO->getOpcode(), LHSModified, RHS) + .get(); + } + } + if (!OpLoc.isValid()) + OpLoc = utils::GetValidSLoc(m_Sema); + return m_Sema.BuildUnaryOp(nullptr, OpLoc, UO_Minus, E).get(); + } Expr* VisitorBase::BuildOp(clang::BinaryOperatorKind OpCode, Expr* L, Expr* R, SourceLocation OpLoc) {