From 1c2dc1e2fdebaf0d6f2aca6a41ea2041a2cb1fb4 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 22 Apr 2024 23:12:08 +0300 Subject: [PATCH] Use one IsDifferentiableType for all modes and move it to VisitorBase --- .../Differentiator/BaseForwardModeVisitor.h | 2 -- include/clad/Differentiator/CladUtils.h | 2 -- include/clad/Differentiator/VisitorBase.h | 2 ++ lib/Differentiator/BaseForwardModeVisitor.cpp | 23 ++++--------------- lib/Differentiator/CladUtils.cpp | 5 ---- lib/Differentiator/ReverseModeVisitor.cpp | 8 +++---- lib/Differentiator/VisitorBase.cpp | 19 ++++++++++++++- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 9e0394c8a..8e12af090 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -47,8 +47,6 @@ class BaseForwardModeVisitor virtual void ExecuteInsidePushforwardFunctionBlock(); - static bool IsDifferentiableType(clang::QualType T); - virtual StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 6e3d88a8a..d51a5c253 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -334,8 +334,6 @@ namespace clad { bool IsMemoryFunction(const clang::FunctionDecl* FD); bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD); - bool IsDifferentiableType(clang::QualType QT); - /// Removes the local const qualifiers from a QualType and returns a new /// type. clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C, diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 0aa613f38..3d4e18576 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -208,6 +208,8 @@ namespace clad { return QT->isArrayType() || QT->isPointerType(); } + static bool IsDifferentiableType(clang::QualType T); + clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts); /// Get the latest block of code (i.e. place for statements output). diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index d0e53581a..6f2dd6962 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -38,21 +38,6 @@ BaseForwardModeVisitor::BaseForwardModeVisitor(DerivativeBuilder& builder) BaseForwardModeVisitor::~BaseForwardModeVisitor() {} -bool BaseForwardModeVisitor::IsDifferentiableType(QualType T) { - QualType origType = T; - // FIXME: arbitrary dimension array type as well. - while (utils::isArrayOrPointerType(T)) - T = utils::GetValueType(T); - T = T.getNonReferenceType(); - if (T->isEnumeralType()) - return false; - if (T->isRealType() || T->isStructureOrClassType()) - return true; - if (origType->isPointerType() && T->isVoidType()) - return true; - return false; -} - bool IsRealNonReferenceType(QualType T) { return T.getNonReferenceType()->isRealType(); } @@ -224,7 +209,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, // non-reference type for creating the derivatives. QualType dParamType = param->getType().getNonReferenceType(); // We do not create derived variable for array/pointer parameters. - if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || + if (!IsDifferentiableType(dParamType) || utils::isArrayOrPointerType(dParamType)) continue; Expr* dParam = nullptr; @@ -420,7 +405,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, for (auto* PVD : m_Function->parameters()) { paramTypes.push_back(PVD->getType()); - if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + if (IsDifferentiableType(PVD->getType())) derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); } @@ -485,7 +470,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, if (identifierMissing) m_DeclReplacements[PVD] = newPVD; - if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + if (!IsDifferentiableType(PVD->getType())) continue; auto derivedPVDName = "_d_" + std::string(PVDII->getName()); IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); @@ -1069,7 +1054,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { } } CallArgs.push_back(argDiff.getExpr()); - if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) { + if (IsDifferentiableType(arg->getType())) { Expr* dArg = argDiff.getExpr_dx(); // FIXME: What happens when dArg is nullptr? diffArgs.push_back(dArg); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 7899e6d49..791e6dc62 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -685,11 +685,6 @@ namespace clad { #endif } - bool IsDifferentiableType(clang::QualType QT) { - // FIXME: consider analysing object types with this - return !utils::GetValueType(QT)->isIntegerType(); - } - clang::QualType getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) { clang::Qualifiers quals(T.getQualifiers()); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 1c85d73fe..6679e6f0d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -272,7 +272,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: we have to create adjoints for all parameters when any // external sources are enabled because gradient overloads don't support // additional parameters. - if (utils::IsDifferentiableType(dParam.param->getType()) || m_ExternalSource) + if (IsDifferentiableType(dParam.param->getType()) || m_ExternalSource) args.push_back(dParam.param); } else @@ -604,7 +604,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: we have to create adjoints for all parameters when any // external sources are enabled because gradient overloads don't support // additional parameters. - if (!utils::IsDifferentiableType(param->getType()) && !m_ExternalSource) + if (!IsDifferentiableType(param->getType()) && !m_ExternalSource) continue; // derived variables are already created for independent variables. if (m_Variables.count(param)) @@ -1548,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. // Also, no need to create adjoint variables for non-differentiable types. - if (utils::IsReferenceOrPointerArg(arg) || !utils::IsDifferentiableType(arg->getType())) { + if (utils::IsReferenceOrPointerArg(arg) || !IsDifferentiableType(arg->getType())) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -2585,7 +2585,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Integer types are not differentiable, // no need to construct an adjoint. - if (!utils::IsDifferentiableType(VD->getType())) { + if (!IsDifferentiableType(VD->getType())) { Expr* init = nullptr; if (VD->getInit()) init = Visit(VD->getInit()).getExpr(); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 5b1f5d625..0279ece62 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -906,7 +906,7 @@ namespace clad { m_Mode != DiffMode::experimental_pushforward) for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) { QualType paramTy = originalFD->getParamDecl(i)->getType(); - if (!utils::IsDifferentiableType(paramTy)) { + if (!IsDifferentiableType(paramTy)) { QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema); VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy)); Expr* arg = BuildDeclRef(argDecl); @@ -965,4 +965,21 @@ namespace clad { } return false; } + + + +bool VisitorBase::IsDifferentiableType(QualType T) { + QualType origType = T; + // FIXME: arbitrary dimension array type as well. + while (utils::isArrayOrPointerType(T)) + T = utils::GetValueType(T); + T = T.getNonReferenceType(); + if (T->isEnumeralType()) + return false; + if (T->isFloatingType() || T->isStructureOrClassType()) + return true; + if (origType->isPointerType() && T->isVoidType()) + return true; + return false; +} } // end namespace clad