Skip to content

Commit

Permalink
Fix for Parenthesis issue in gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
toshit3q34 committed Dec 22, 2024
1 parent 3e50707 commit 9461931
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
9 changes: 8 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,14 @@ namespace clad {
/// either LHS or RHS is null.
clang::Expr* BuildOp(clang::BinaryOperatorKind OpCode, clang::Expr* L,
clang::Expr* R, clang::SourceLocation OpLoc = noLoc);

/// Recursive Function to resolve Unary Minus. If the leftmost operand
/// has a Unary Minus then does not add the operator else adds.
/// Also resolves multiple Unary Minus. Example : (-(-(-2))) gives -2.
/// \param[in] E Expression fed to the recursive call.
/// \param[in] OpLoc Location to add Unary Minus if needed.
/// \returns Expression with correct Unary Operator placement.
clang::Expr* RemoveFirstUnaryMinus(clang::Expr* E,
clang::SourceLocation OpLoc);
clang::Expr* BuildParens(clang::Expr* E);
/// Builds variable declaration to be used inside the derivative
/// body.
Expand Down
23 changes: 23 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,31 @@ namespace clad {
// Debug clang requires the location to be valid
if (!OpLoc.isValid())
OpLoc = utils::GetValidSLoc(m_Sema);
// Call function for UnaryMinus
if (OpCode == UO_Minus)
return RemoveFirstUnaryMinus(E->IgnoreCasts(), OpLoc);
return m_Sema.BuildUnaryOp(nullptr, OpLoc, OpCode, E).get();
}
Expr* VisitorBase::RemoveFirstUnaryMinus(Expr* E, SourceLocation OpLoc) {
if (auto* UO = llvm::dyn_cast<UnaryOperator>(E)) {
if (UO->getOpcode() == UO_Minus)
return UO->getSubExpr();

Check warning on line 411 in lib/Differentiator/VisitorBase.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/VisitorBase.cpp#L410-L411

Added lines #L410 - L411 were not covered by tests
}
if (auto* BO = llvm::dyn_cast<BinaryOperator>(E)) {
if (BO->getOpcode() == BO_Mul || BO->getOpcode() == BO_Div) {
Expr* LHS = BO->getLHS();
Expr* RHS = BO->getRHS();
Expr* LHSModified = RemoveFirstUnaryMinus(LHS->IgnoreCasts(), OpLoc);
SourceLocation BO_loc = BO->getExprLoc();
return m_Sema
.BuildBinOp(nullptr, BO_loc, BO->getOpcode(), LHSModified, RHS)
.get();
}
}
if (!OpLoc.isValid())
OpLoc = utils::GetValidSLoc(m_Sema);

Check warning on line 425 in lib/Differentiator/VisitorBase.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/VisitorBase.cpp#L425

Added line #L425 was not covered by tests
return m_Sema.BuildUnaryOp(nullptr, OpLoc, UO_Minus, E).get();
}

Expr* VisitorBase::BuildOp(clang::BinaryOperatorKind OpCode, Expr* L, Expr* R,
SourceLocation OpLoc) {
Expand Down

0 comments on commit 9461931

Please sign in to comment.