diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..28983d626 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,29 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS_H +#define CLAD_DIFFERENTIATOR_SINS_H + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 02a815c92..4ff77e806 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...); int f_1_darg0(int x); +double sq_defined_later(double); + int main() { int x = 4; clad::differentiate(f_1, 0); + auto df = clad::differentiate(sq_defined_later, "x"); printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2 + printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6 return 0; } + +double sq_defined_later(double x) { + return x * x; +} diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..cfd977b99 --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,8 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out -Xclang -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +// CHECK: HandleTopLevelDecl +int main() { + +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index d678d12e1..47646f182 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -121,39 +122,8 @@ namespace clad { CladPlugin::~CladPlugin() {} - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - // Tag used to access MultiplexConsumer::Consumers. - using namespace clang; - struct MultiplexConsumer_Consumers - : TagBase< - MultiplexConsumer_Consumers, - std::vector> MultiplexConsumer::*> { - }; - template struct Rob; - } // namespace + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); void CladPlugin::Initialize(clang::ASTContext& C) { // We know we have a multiplexer. We commit a sin here by stealing it and @@ -163,7 +133,7 @@ namespace clad { using namespace clang; auto& MultiplexC = static_cast(m_CI.getASTConsumer()); - auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); std::vector> StolenConsumers; @@ -181,60 +151,23 @@ namespace clad { // We cannot use HandleTranslationUnit because codegen already emits code on // HandleTopLevelDecl calls and makes updateCall with no effect. - bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { - AppendDelayed({CallKind::HandleTopLevelDecl, DGR}); + void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) - return m_Multiplexer->HandleTopLevelDecl(DGR); // true; + return; Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return m_Multiplexer->HandleTopLevelDecl(DGR); // true; - - DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); - - if (requests.empty()) - return m_Multiplexer->HandleTopLevelDecl(DGR); // true; - - // FIXME: flags have to be set manually since DiffCollector's constructor - // does not have access to m_DO. - if (m_DO.EnableTBRAnalysis) - for (DiffRequest& request : requests) - request.EnableTBRAnalysis = true; - - // FIXME: Remove the PerformPendingInstantiations altogether. We should - // somehow make the relevant functions referenced. - // Instantiate all pending for instantiations templates, because we will - // need the full bodies to produce derivatives. - // FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`? - if (!m_PendingInstantiationsInFlight) { - m_PendingInstantiationsInFlight = true; - S.PerformPendingInstantiations(); - m_PendingInstantiationsInFlight = false; - } - - for (DiffRequest& request : requests) - ProcessDiffRequest(request); + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); - return m_Multiplexer->HandleTopLevelDecl(DGR); // Happiness - } - - void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. + // FIXME: Remove the call to PerformPendingInstantiations(). S.PerformPendingInstantiations(); if (request.Function->getDefinition()) request.Function = request.Function->getDefinition(); @@ -345,6 +278,8 @@ namespace clad { // Call CodeGen only if the produced Decl is a top-most // decl or is contained in a namespace decl. + // FIXME: We could get rid of this by prepending the produced + // derivatives in CladPlugin::HandleTranslationUnitDecl DeclContext* derivativeDC = DerivativeDecl->getDeclContext(); bool isTUorND = derivativeDC->isTranslationUnit() || derivativeDC->isNamespace(); @@ -374,6 +309,70 @@ namespace clad { return nullptr; } + void CladPlugin::SendToMultiplexer() { + for (auto I = m_DelayedCalls.begin(); I != m_DelayedCalls.end(); ++I) { + DeclGroupRef& D = I->m_DGR; + switch (I->m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + m_Multiplexer->HandleCXXStaticMemberVarInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDecl: + m_Multiplexer->HandleTopLevelDecl(D); + break; + case CallKind::HandleInlineFunctionDefinition: + m_Multiplexer->HandleInlineFunctionDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleInterestingDecl: + m_Multiplexer->HandleInterestingDecl(D); + break; + case CallKind::HandleTagDeclDefinition: + m_Multiplexer->HandleTagDeclDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTagDeclRequiredDefinition: + m_Multiplexer->HandleTagDeclRequiredDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + m_Multiplexer->HandleCXXImplicitFunctionInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + break; + case CallKind::HandleImplicitImportDecl: + m_Multiplexer->HandleImplicitImportDecl( + cast(D.getSingleDecl())); + break; + case CallKind::CompleteTentativeDefinition: + m_Multiplexer->CompleteTentativeDefinition( + cast(D.getSingleDecl())); + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + m_Multiplexer->CompleteExternalDeclaration( + cast(D.getSingleDecl())); + break; +#endif + case CallKind::AssignInheritanceModel: + m_Multiplexer->AssignInheritanceModel( + cast(D.getSingleDecl())); + break; + case CallKind::HandleVTable: + m_Multiplexer->HandleVTable(cast(D.getSingleDecl())); + break; + case CallKind::InitializeSema: + m_Multiplexer->InitializeSema(m_CI.getSema()); + break; + case CallKind::ForgetSema: + m_Multiplexer->ForgetSema(); + break; + }; + } + } + bool CladPlugin::CheckBuiltins() { // If we have included "clad/Differentiator/Differentiator.h" return. if (m_HasRuntime) @@ -396,6 +395,89 @@ namespace clad { m_HasRuntime = !R.empty(); return m_HasRuntime; } + + void CladPlugin::HandleTranslationUnit(ASTContext& C) { + Sema& S = m_CI.getSema(); + constexpr bool Enabled = true; + Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); + Sema::LocalEagerInstantiationScope LocalInstantiations(S); + + for (DiffRequest& request : m_DiffSchedule) { + // FIXME: flags have to be set manually since DiffCollector's + // constructor does not have access to m_DO. + request.EnableTBRAnalysis = m_DO.EnableTBRAnalysis; + ProcessDiffRequest(request); + } + // Force emission of the produced pending template instantiations. + LocalInstantiations.perform(); + GlobalInstantiations.perform(); + + SendToMultiplexer(); + m_Multiplexer->HandleTranslationUnit(C); + } + + void CladPlugin::PrintStats() { + llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n"; + for (const DelayedCallInfo& DCI : m_DelayedCalls) { + llvm::errs() << " "; + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; + break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; + break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; + break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; + break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; + break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; + break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; + break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; + break; + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; + break; + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; + break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; + break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; + break; + case CallKind::ForgetSema: + llvm::errs() << "ForgetSema"; + break; + }; + for (const clang::Decl* D : DCI.m_DGR) { + llvm::errs() << " " << D; + if (const auto* ND = dyn_cast(D)) + llvm::errs() << " " << ND->getNameAsString(); + } + llvm::errs() << "\n"; + } + + m_Multiplexer->PrintStats(); + } + } // end namespace plugin // Routine to check clang version at runtime against the clang version for diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3e99a93b9..85f7deff2 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -88,9 +88,8 @@ namespace clad { DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; - bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; DerivedFnCollector m_DFC; + DiffSchedule m_DiffSchedule; enum class CallKind { HandleCXXStaticMemberVarInstantiation, HandleTopLevelDecl, @@ -102,7 +101,9 @@ namespace clad { HandleTopLevelDeclInObjCContainer, HandleImplicitImportDecl, CompleteTentativeDefinition, +#if CLANG_VERSION_MAJOR > 9 CompleteExternalDeclaration, +#endif AssignInheritanceModel, HandleVTable, InitializeSema, @@ -115,8 +116,23 @@ namespace clad { : m_Kind(K), m_DGR(DGR) {} DelayedCallInfo(CallKind K, const clang::Decl* D) : m_Kind(K), m_DGR(const_cast(D)) {} + bool operator==(const DelayedCallInfo& other) const { + if (m_Kind != other.m_Kind) + return false; + + auto first1 = m_DGR.begin(); + auto first2 = other.m_DGR.begin(); + auto last1 = m_DGR.end(); + for (; first1 != last1; ++first1, ++first2) + if (!(*first1 == *first2)) + return false; + return true; + } }; + /// The calls to the main action which clad delayed and will dispatch at + /// then end of the translation unit. std::vector m_DelayedCalls; + /// The default clang consumers which are called after clad is done. std::unique_ptr m_Multiplexer; public: @@ -126,71 +142,60 @@ namespace clad { void Initialize(clang::ASTContext& Context) override; void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); - m_Multiplexer->HandleCXXStaticMemberVarInstantiation(D); } - bool HandleTopLevelDecl(clang::DeclGroupRef D) override; /*{ + bool HandleTopLevelDecl(clang::DeclGroupRef D) override { + HandleTopLevelDeclForClad(D); AppendDelayed({CallKind::HandleTopLevelDecl, D}); return true; // happyness, continue parsing - }*/ + } void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); - m_Multiplexer->HandleInlineFunctionDefinition(D); } void HandleInterestingDecl(clang::DeclGroupRef D) override { AppendDelayed({CallKind::HandleInterestingDecl, D}); - m_Multiplexer->HandleInterestingDecl(D); } void HandleTagDeclDefinition(clang::TagDecl* D) override { AppendDelayed({CallKind::HandleTagDeclDefinition, D}); - m_Multiplexer->HandleTagDeclDefinition(D); } void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); - m_Multiplexer->HandleTagDeclRequiredDefinition(D); } void HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); - m_Multiplexer->HandleCXXImplicitFunctionInstantiation(D); } void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); - m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); } void HandleImplicitImportDecl(clang::ImportDecl* D) override { AppendDelayed({CallKind::HandleImplicitImportDecl, D}); - m_Multiplexer->HandleImplicitImportDecl(D); } void CompleteTentativeDefinition(clang::VarDecl* D) override { AppendDelayed({CallKind::CompleteTentativeDefinition, D}); - m_Multiplexer->CompleteTentativeDefinition(D); } #if CLANG_VERSION_MAJOR > 9 void CompleteExternalDeclaration(clang::VarDecl* D) override { AppendDelayed({CallKind::CompleteExternalDeclaration, D}); - m_Multiplexer->CompleteExternalDeclaration(D); } #endif void AssignInheritanceModel(clang::CXXRecordDecl* D) override { AppendDelayed({CallKind::AssignInheritanceModel, D}); - m_Multiplexer->AssignInheritanceModel(D); } void HandleVTable(clang::CXXRecordDecl* D) override { AppendDelayed({CallKind::HandleVTable, D}); - m_Multiplexer->HandleVTable(D); } // Not delayed. - void HandleTranslationUnit(clang::ASTContext& C) override { - m_Multiplexer->HandleTranslationUnit(C); - } + void HandleTranslationUnit(clang::ASTContext& C) override; + // No need to handle the listeners, they will be handled at non-delayed by // the parent multiplexer. // // clang::ASTMutationListener *GetASTMutationListener() override; // clang::ASTDeserializationListener *GetASTDeserializationListener() // override; - void PrintStats() override { m_Multiplexer->PrintStats(); } + void PrintStats() override; + bool shouldSkipFunctionBody(clang::Decl* D) override { return m_Multiplexer->shouldSkipFunctionBody(D); } @@ -198,20 +203,23 @@ namespace clad { // SemaConsumer void InitializeSema(clang::Sema& S) override { AppendDelayed({CallKind::InitializeSema, nullptr}); - m_Multiplexer->InitializeSema(S); } void ForgetSema() override { AppendDelayed({CallKind::ForgetSema, nullptr}); - m_Multiplexer->ForgetSema(); } - // bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); } + void SendToMultiplexer(); bool CheckBuiltins(); - void ProcessTopLevelDecl(clang::Decl* D); + void ProcessTopLevelDecl(clang::Decl* D) { + DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; + assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); + AppendDelayed(DCI); + } + void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR); }; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,