From 2ca8db6db5ae213f79447c89f4d4c07cab0ca960 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Tue, 9 Jul 2024 19:40:02 +0200 Subject: [PATCH] Improve types for 0 literal --- lib/Differentiator/BaseForwardModeVisitor.cpp | 17 +++++++++++------ test/ForwardMode/Functors.C | 2 +- test/ForwardMode/Pointer.C | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index e9614c8a8..898b0c4b3 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1008,8 +1008,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // If DRE is of type pointer, then the derivative is a null pointer. if (clonedDRE->getType()->isPointerType()) return StmtDiff(clonedDRE, nullptr); + QualType clonedDreTy = clad::utils::GetValueType(clonedDRE->getType()); + if (!clonedDreTy->isRealType()) + clonedDreTy = m_Context.IntTy; // force int if not real, ex. custom type return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, /*val=*/0)); + clonedDreTy, m_Context, /*val=*/0)); } StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { @@ -1318,11 +1321,13 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { opKind == UnaryOperatorKind::UO_Imag) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_Deref) { - Expr* dx = diff.getExpr_dx(); - if (dx == nullptr) - return StmtDiff(op, ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, /*val=*/0)); - return StmtDiff(op, BuildOp(opKind, dx)); + if (Expr* dx = diff.getExpr_dx()) + return StmtDiff(op, BuildOp(opKind, dx)); + QualType pointeeTy = clad::utils::GetValueType(UnOp->getType()); + if (!pointeeTy->isRealType()) + pointeeTy = m_Context.IntTy; // force int if not real, ex. custom type + return StmtDiff( + op, ConstantFolder::synthesizeLiteral(pointeeTy, m_Context, /*val=*/0)); } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { diff --git a/test/ForwardMode/Functors.C b/test/ForwardMode/Functors.C index fdfcab92b..85026942c 100644 --- a/test/ForwardMode/Functors.C +++ b/test/ForwardMode/Functors.C @@ -430,7 +430,7 @@ int main() { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_jj = 0; // CHECK-NEXT: double _t0 = x * i; - // CHECK-NEXT: return (0 * i + x * _d_i) * jj + _t0 * _d_jj; + // CHECK-NEXT: return (0. * i + x * _d_i) * jj + _t0 * _d_jj; // CHECK-NEXT: } auto lambdaNNS = outer::inner::lambdaNNS; diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 69f050516..50cb0ee87 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -198,7 +198,7 @@ double fn9(double* params, const double *constants) { } // CHECK: double fn9_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: double _d_c0 = 0.; // CHECK-NEXT: double c0 = *constants; // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: }