diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index 8e7780ad7..e75b918ad 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -147,10 +147,13 @@ namespace clad { QT->isSignedIntegerOrEnumerationType()); Result = clad::synthesizeLiteral( dyn_cast(QT)->getDecl()->getIntegerType(), C, APVal); - Expr* cast = ImplicitCastExpr::Create( - C, QT, clang::CastKind::CK_IntegralCast, Result, nullptr, - CLAD_COMPAT_ExprValueKind_R_or_PR_Value - CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO); + SourceLocation noLoc; + Expr* cast = CXXStaticCastExpr::Create( + C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + clang::CastKind::CK_IntegralCast, Result, nullptr, + C.getTrivialTypeSourceInfo(QT, noLoc) + CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO, + noLoc, noLoc, SourceRange()); Result = cast; } else if (QT->isPointerType()) { Result = clad::synthesizeLiteral(QT, C); diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 6e18bc04d..c701c3c00 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -682,6 +682,146 @@ double fn7(double u, double v) { // CHECK-NEXT: } // CHECK-NEXT: } +enum Op { + Add, + Sub, + Mul, + Div +}; + +double fn24(double x, double y, Op op) { + double res = 0; + switch (op) { + case Add: + res = x + y; + break; + case Sub: + res = x - y; + break; + case Mul: + res = x * y; + break; + case Div: + res = x / y; + break; + } + return res; +} + +// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) { +// CHECK-NEXT: Op _d_op = static_cast(0U); +// CHECK-NEXT: Op _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = op; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case Add: +// CHECK-NEXT: res = x + y; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Sub: +// CHECK-NEXT: res = x - y; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Mul: +// CHECK-NEXT: res = x * y; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 3UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Div: +// CHECK-NEXT: res = x / y; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 4UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 5UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 5UL: +// CHECK-NEXT: ; +// CHECK-NEXT: case 4UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d3 / y; +// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y)); +// CHECK-NEXT: _d_y += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Div == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d2 * y; +// CHECK-NEXT: _d_y += x * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (Mul == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d1; +// CHECK-NEXT: _d_y += -_r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (Sub == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: _d_y += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Add == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + #define TEST_2(F, x, y) \ { \ @@ -691,6 +831,14 @@ double fn7(double u, double v) { printf("{%.2f, %.2f}\n", result[0], result[1]); \ } +#define TEST_2_Op(F, x, y, op) \ +{ \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F, "x, y"); \ + d_##F.execute(x, y, op, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ +} + int main() { double result[2] = {}; @@ -705,4 +853,9 @@ int main() { TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00} + + TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00} + TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00} + TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00} + TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12} }