Skip to content

Commit

Permalink
Delay the differentiation process until the end of TU.
Browse files Browse the repository at this point in the history
Before this patch, clad attached itself as the first consumer and applied AD
before code generation. However, that is limiting: clang sends every top-level
declaration to codegen, which restricts the flexibility clad has. For
example, we have to instantiate all pending templates at every
HandleTopLevelDecl call; we cannot really differentiate virtual functions
whose classes have sent their key function to CodeGen; and, in general, we
perform actions which semantically belong at the end of the translation unit.

This patch makes clad the single consumer of clang, which dispatches to the
others. That is done by delaying all calls to the consumers until the end of the
TU, where clad can replay the exact sequence of calls to the other consumers as
if they were directly connected to the frontend.

Fixes #248
  • Loading branch information
vgvassilev committed Mar 9, 2024
1 parent d7e5434 commit 0fe47e4
Show file tree
Hide file tree
Showing 6 changed files with 390 additions and 79 deletions.
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef CLAD_DIFFERENTIATOR_SINS_H
#define CLAD_DIFFERENTIATOR_SINS_H

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
///
/// Usage, at namespace scope:
///   ALLOW_ACCESS(Class, Member, MemberType);
/// and then, wherever the member is needed:
///   ACCESS(Object, Member)
/// which yields an lvalue referring to `Object.Member`.
///
/// ALLOW_ACCESS explicitly instantiates a helper template with
/// `&CLASS::MEMBER` as a template argument; the helper defines a friend
/// function `Access` that returns this pointer-to-member. The `Only_<MEMBER>`
/// tag type re-declares that friend so that argument-dependent lookup can find
/// it from ACCESS.
// NOLINTBEGIN(cppcoreguidelines-macro-usage)
#define CONCATE_(X, Y) X##Y
#define CONCATE(X, Y) CONCATE_(X, Y)
#define ALLOW_ACCESS(CLASS, MEMBER, ...)                                       \
  template <typename Only, __VA_ARGS__ CLASS::*Member>                         \
  struct CONCATE(MEMBER, __LINE__) {                                           \
    friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; }                \
  };                                                                           \
  template <typename> struct Only_##MEMBER;                                    \
  template <> struct Only_##MEMBER<CLASS> {                                    \
    friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*);                  \
  };                                                                           \
  template struct CONCATE(MEMBER,                                              \
                          __LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

// The whole expansion is parenthesized: `.*` binds less tightly than postfix
// operators, so without the outer parentheses `ACCESS(o, m).foo()` would apply
// `.foo()` to the result of `Access(...)` instead of to the accessed member.
#define ACCESS(OBJECT, MEMBER)                                                 \
  ((OBJECT).*Access(static_cast<Only_##MEMBER<                                 \
                        std::remove_reference<decltype(OBJECT)>::type>*>(      \
       nullptr)))

// NOLINTEND(cppcoreguidelines-macro-usage)

#endif // CLAD_DIFFERENTIATOR_SINS_H
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope*);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
8 changes: 8 additions & 0 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...);

int f_1_darg0(int x);

double sq_defined_later(double);

// Test driver: requests two derivatives and prints values that the CHECK-EXEC
// directives below compare against.
int main() {
int x = 4;
// The returned object is discarded here; f_1_darg0 is forward-declared above
// and called by name further down (presumably the function this request
// generates - confirm against the rest of the test).
clad::differentiate(f_1, 0);
// sq_defined_later is only declared at this point; its body follows main().
auto df = clad::differentiate(sq_defined_later, "x");
printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6
return 0;
}

// Returns x squared. The definition is placed after main(), where the
// derivative of this function is requested while only its declaration is
// visible.
double sq_defined_later(double x) {
return x * x;
}
8 changes: 8 additions & 0 deletions test/Misc/ClangConsumers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out -Xclang -print-stats 2>&1 | FileCheck %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
// CHECK: HandleTopLevelDecl
// Intentionally empty: per the RUN line, this test only inspects what the
// compiler prints when invoked with -print-stats.
int main() {

}
224 changes: 184 additions & 40 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/EstimationModel.h"

#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/Version.h"

#include "clang/AST/ASTConsumer.h"
Expand Down Expand Up @@ -91,60 +92,52 @@ namespace clad {

CladPlugin::~CladPlugin() {}

ALLOW_ACCESS(MultiplexConsumer, Consumers,
std::vector<std::unique_ptr<ASTConsumer>>);

/// Takes over the consumers registered in clang's MultiplexConsumer.
///
/// Every consumer that precedes clad in the multiplexer is moved into a new
/// MultiplexConsumer owned by clad (m_Multiplexer); afterwards clad is the
/// only consumer left in the original one.
///
/// \param C the AST context passed by the frontend (not used here).
void CladPlugin::Initialize(clang::ASTContext& C) {
  // We know we have a multiplexer. We commit a sin here by stealing it and
  // making the consumer pass-through so that we can delay all operations
  // until clad is happy.

  using namespace clang;

  auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
  auto& RobbedCs = ACCESS(MultiplexC, Consumers);
  assert(!RobbedCs.empty() && "The multiplexer has no consumers");
  assert(RobbedCs.back().get() == this && "Clad is not the last consumer");

  // Everything in front of the last element (clad itself) gets stolen.
  const auto StolenEnd =
      RobbedCs.empty() ? RobbedCs.end() : RobbedCs.end() - 1;

  // First move the consumers out, and only then shrink the vector. Erasing
  // from inside a loop that iterates over the same vector would invalidate
  // the iterators that loop is still using.
  std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;
  for (auto It = RobbedCs.begin(); It != StolenEnd; ++It)
    StolenConsumers.push_back(std::move(*It));

  // NOTE(review): according to the original comment, this call is dispatched
  // from the range-based for loop in MultiplexConsumer::Initialize. Shrinking
  // the vector here relies on clad being the last element, i.e. on that loop
  // finishing right after we return.
  RobbedCs.erase(RobbedCs.begin(), StolenEnd);

  m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers)));
}

// We cannot use HandleTranslationUnit because codegen already emits code on
// HandleTopLevelDecl calls and makes updateCall with no effect.
bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) {
void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) {
if (!CheckBuiltins())
return true;
return;

Sema& S = m_CI.getSema();

if (!m_DerivativeBuilder)
m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this));

// if HandleTopLevelDecl was called through clad we don't need to process
// it for diff requests
if (m_HandleTopLevelDeclInternal)
return true;

DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());

if (requests.empty())
return true;
m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this));

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
if (m_DO.EnableTBRAnalysis)
for (DiffRequest& request : requests)
request.EnableTBRAnalysis = true;

// FIXME: Remove the PerformPendingInstantiations altogether. We should
// somehow make the relevant functions referenced.
// Instantiate all pending for instantiations templates, because we will
// need the full bodies to produce derivatives.
// FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`?
if (!m_PendingInstantiationsInFlight) {
m_PendingInstantiationsInFlight = true;
S.PerformPendingInstantiations();
m_PendingInstantiationsInFlight = false;
}

for (DiffRequest& request : requests)
ProcessDiffRequest(request);
return true; // Happiness
}

void CladPlugin::ProcessTopLevelDecl(Decl* D) {
m_HandleTopLevelDeclInternal = true;
m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D));
m_HandleTopLevelDeclInternal = false;
DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S);
}

FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) {
Sema& S = m_CI.getSema();
// Required due to custom derivatives function templates that might be
// used in the function that we need to derive.
// FIXME: Remove the call to PerformPendingInstantiations().
S.PerformPendingInstantiations();
if (request.Function->getDefinition())
request.Function = request.Function->getDefinition();
Expand Down Expand Up @@ -267,6 +260,8 @@ namespace clad {

// Call CodeGen only if the produced Decl is a top-most
// decl or is contained in a namespace decl.
// FIXME: We could get rid of this by prepending the produced
// derivatives in CladPlugin::HandleTranslationUnitDecl
DeclContext* derivativeDC = DerivativeDecl->getDeclContext();
bool isTUorND =
derivativeDC->isTranslationUnit() || derivativeDC->isNamespace();
Expand Down Expand Up @@ -296,6 +291,70 @@ namespace clad {
return nullptr;
}

void CladPlugin::SendToMultiplexer() {
for (auto DelayedCall : m_DelayedCalls) {
DeclGroupRef& D = DelayedCall.m_DGR;
switch (DelayedCall.m_Kind) {
case CallKind::HandleCXXStaticMemberVarInstantiation:
m_Multiplexer->HandleCXXStaticMemberVarInstantiation(
cast<VarDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTopLevelDecl:
m_Multiplexer->HandleTopLevelDecl(D);
break;
case CallKind::HandleInlineFunctionDefinition:
m_Multiplexer->HandleInlineFunctionDefinition(
cast<FunctionDecl>(D.getSingleDecl()));
break;
case CallKind::HandleInterestingDecl:
m_Multiplexer->HandleInterestingDecl(D);
break;
case CallKind::HandleTagDeclDefinition:
m_Multiplexer->HandleTagDeclDefinition(
cast<TagDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTagDeclRequiredDefinition:
m_Multiplexer->HandleTagDeclRequiredDefinition(
cast<TagDecl>(D.getSingleDecl()));
break;
case CallKind::HandleCXXImplicitFunctionInstantiation:
m_Multiplexer->HandleCXXImplicitFunctionInstantiation(
cast<FunctionDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTopLevelDeclInObjCContainer:
m_Multiplexer->HandleTopLevelDeclInObjCContainer(D);
break;
case CallKind::HandleImplicitImportDecl:
m_Multiplexer->HandleImplicitImportDecl(
cast<ImportDecl>(D.getSingleDecl()));
break;
case CallKind::CompleteTentativeDefinition:
m_Multiplexer->CompleteTentativeDefinition(
cast<VarDecl>(D.getSingleDecl()));
break;
#if CLANG_VERSION_MAJOR > 9
case CallKind::CompleteExternalDeclaration:
m_Multiplexer->CompleteExternalDeclaration(
cast<VarDecl>(D.getSingleDecl()));
break;
#endif
case CallKind::AssignInheritanceModel:
m_Multiplexer->AssignInheritanceModel(
cast<CXXRecordDecl>(D.getSingleDecl()));
break;
case CallKind::HandleVTable:
m_Multiplexer->HandleVTable(cast<CXXRecordDecl>(D.getSingleDecl()));
break;
case CallKind::InitializeSema:
m_Multiplexer->InitializeSema(m_CI.getSema());
break;
case CallKind::ForgetSema:
m_Multiplexer->ForgetSema();
break;
};
}
}

bool CladPlugin::CheckBuiltins() {
// If we have included "clad/Differentiator/Differentiator.h" return.
if (m_HasRuntime)
Expand All @@ -318,6 +377,91 @@ namespace clad {
m_HasRuntime = !R.empty();
return m_HasRuntime;
}

/// End-of-translation-unit hook: processes every request in m_DiffSchedule,
/// performs the pending template instantiations, replays the delayed calls
/// and finally forwards HandleTranslationUnit to m_Multiplexer.
void CladPlugin::HandleTranslationUnit(ASTContext& C) {
  Sema& SemaRef = m_CI.getSema();
  Sema::GlobalEagerInstantiationScope GlobalScope(SemaRef, /*Enabled=*/true);
  Sema::LocalEagerInstantiationScope LocalScope(SemaRef);

  for (DiffRequest& Req : m_DiffSchedule) {
    // FIXME: flags have to be set manually since DiffCollector's
    // constructor does not have access to m_DO.
    Req.EnableTBRAnalysis = m_DO.EnableTBRAnalysis;
    ProcessDiffRequest(Req);
  }

  // Force emission of the produced pending template instantiations.
  LocalScope.perform();
  GlobalScope.perform();

  // The other consumers see the delayed calls first and the end of the
  // translation unit afterwards.
  SendToMultiplexer();
  m_Multiplexer->HandleTranslationUnit(C);
}

/// Writes one line per entry of m_DelayedCalls to llvm::errs() - the kind of
/// the call followed by the address (and the name, for named declarations) of
/// each declaration in its group - and then calls PrintStats on m_Multiplexer.
void CladPlugin::PrintStats() {
  auto& Out = llvm::errs();
  Out << "*** INFORMATION ABOUT THE DELAYED CALLS\n";

  for (const DelayedCallInfo& Info : m_DelayedCalls) {
    // Stays empty if the kind matches none of the cases below.
    const char* KindName = "";
    switch (Info.m_Kind) {
    case CallKind::HandleCXXStaticMemberVarInstantiation:
      KindName = "HandleCXXStaticMemberVarInstantiation";
      break;
    case CallKind::HandleTopLevelDecl:
      KindName = "HandleTopLevelDecl";
      break;
    case CallKind::HandleInlineFunctionDefinition:
      KindName = "HandleInlineFunctionDefinition";
      break;
    case CallKind::HandleInterestingDecl:
      KindName = "HandleInterestingDecl";
      break;
    case CallKind::HandleTagDeclDefinition:
      KindName = "HandleTagDeclDefinition";
      break;
    case CallKind::HandleTagDeclRequiredDefinition:
      KindName = "HandleTagDeclRequiredDefinition";
      break;
    case CallKind::HandleCXXImplicitFunctionInstantiation:
      KindName = "HandleCXXImplicitFunctionInstantiation";
      break;
    case CallKind::HandleTopLevelDeclInObjCContainer:
      KindName = "HandleTopLevelDeclInObjCContainer";
      break;
    case CallKind::HandleImplicitImportDecl:
      KindName = "HandleImplicitImportDecl";
      break;
    case CallKind::CompleteTentativeDefinition:
      KindName = "CompleteTentativeDefinition";
      break;
#if CLANG_VERSION_MAJOR > 9
    case CallKind::CompleteExternalDeclaration:
      KindName = "CompleteExternalDeclaration";
      break;
#endif
    case CallKind::AssignInheritanceModel:
      KindName = "AssignInheritanceModel";
      break;
    case CallKind::HandleVTable:
      KindName = "HandleVTable";
      break;
    case CallKind::InitializeSema:
      KindName = "InitializeSema";
      break;
    case CallKind::ForgetSema:
      KindName = "ForgetSema";
      break;
    }
    Out << " " << KindName;

    for (const clang::Decl* D : Info.m_DGR) {
      Out << " " << D;
      if (const auto* Named = dyn_cast<NamedDecl>(D))
        Out << " " << Named->getNameAsString();
    }
    Out << "\n";
  }

  m_Multiplexer->PrintStats();
}

} // end namespace plugin

clad::CladTimerGroup::CladTimerGroup()
Expand Down
Loading

0 comments on commit 0fe47e4

Please sign in to comment.