diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 9004f9e4f..f8231362e 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -392,6 +392,15 @@ namespace clad { } m_IndependentVar = args.back(); + // If param is not real (i.e. floating point or integral), we cannot + // differentiate it. + // FIXME: we should support custom numeric types in the future. + if (!m_IndependentVar->getType()->isRealType()) { + diag(DiagnosticsEngine::Error, m_IndependentVar->getLocEnd(), + "attempted differentiation w.r.t. a parameter ('%0') which is not " + "of a real type", { m_IndependentVar->getNameAsString() }); + return {}; + } m_DerivativeOrder = plan.getCurrentDerivativeOrder(); std::string s = std::to_string(m_DerivativeOrder); std::string derivativeBaseName; @@ -486,17 +495,8 @@ namespace clad { beginBlock(); // For each function parameter variable, store its derivative value. for (auto param : params) { - // If param is not real (i.e. floating point or integral), we cannot - // differentiate it. - // FIXME: we should support custom numeric types in the future. - if (!param->getType()->isRealType()) { - if (param != m_IndependentVar) - continue; - diag(DiagnosticsEngine::Error, PVD->getLocEnd(), - "attempted differentiation w.r.t. a parameter ('%0') which is not " - "of a real type", { m_IndependentVar->getNameAsString() }); - return {}; - } + if (!param->getType()->isRealType()) + continue; // If param is independent variable, its derivative is 1, otherwise 0. int dValue = (param == m_IndependentVar); auto dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, @@ -843,7 +843,8 @@ namespace clad { auto referencedDecl = cast(clonedDRE->getDecl()); clonedDRE = cast(BuildDeclRef(referencedDecl)); } - } + } else + clonedDRE = cast(Clone(DRE)); if (auto VD = dyn_cast(clonedDRE->getDecl())) { // If DRE references a variable, try to find if we know something about @@ -1674,7 +1675,8 @@ namespace clad { auto referencedDecl = cast(clonedDRE->getDecl()); clonedDRE = cast(BuildDeclRef(referencedDecl)); } - } + } else + clonedDRE = cast(Clone(DRE)); if (auto decl = dyn_cast(clonedDRE->getDecl())) { // Check DeclRefExpr is a reference to an independent variable.