Skip to content

Commit

Permalink
Improve types for 0 literal
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jul 9, 2024
1 parent cddb7b5 commit 2ca8db6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
17 changes: 11 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,8 +1008,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
QualType clonedDreTy = clad::utils::GetValueType(clonedDRE->getType());
if (!clonedDreTy->isRealType())
clonedDreTy = m_Context.IntTy; // force int if not real, ex. custom type
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
clonedDreTy, m_Context, /*val=*/0));
}

StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1318,11 +1321,13 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
Expr* dx = diff.getExpr_dx();
if (dx == nullptr)
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
return StmtDiff(op, BuildOp(opKind, dx));
if (Expr* dx = diff.getExpr_dx())
return StmtDiff(op, BuildOp(opKind, dx));
QualType pointeeTy = clad::utils::GetValueType(UnOp->getType());
if (!pointeeTy->isRealType())
pointeeTy = m_Context.IntTy; // force int if not real, ex. custom type
return StmtDiff(
op, ConstantFolder::synthesizeLiteral(pointeeTy, m_Context, /*val=*/0));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ int main() {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_jj = 0;
// CHECK-NEXT: double _t0 = x * i;
// CHECK-NEXT: return (0 * i + x * _d_i) * jj + _t0 * _d_jj;
// CHECK-NEXT: return (0. * i + x * _d_i) * jj + _t0 * _d_jj;
// CHECK-NEXT: }

auto lambdaNNS = outer::inner::lambdaNNS;
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ double fn9(double* params, const double *constants) {
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand Down

0 comments on commit 2ca8db6

Please sign in to comment.