Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiate global variables in the reverse mode #1153

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace clad {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
DiffRequest& request);
void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D);
} // namespace plugin

} // namespace clad
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ namespace clad {
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool keepLocal = false);
clang::Expr* DifferentiateGlobalVarDecl(clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
8 changes: 5 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ namespace clad {
/// \returns The newly built variable declaration.
clang::VarDecl*
BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier,
clang::Scope* scope, clang::Expr* Init = nullptr,
bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
clang::Scope* scope, clang::DeclContext* DeclCtx,
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
Expand Down Expand Up @@ -334,7 +335,8 @@ namespace clad {
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit,
clang::DeclContext* DeclCtx = nullptr);
/// Creates a namespace declaration and enters its context. All subsequent
/// Stmts are built inside that namespace, until
/// m_Sema.PopDeclContextIsUsed.
Expand Down
44 changes: 42 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,8 +1413,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Check DeclRefExpr is a reference to an independent variable.
auto it = m_Variables.find(VD);
if (it == std::end(m_Variables)) {
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
if (VD->isFileVarDecl()) {
Expr* DREDiff = DifferentiateGlobalVarDecl(VD);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can detect if a needed variable is global in the DiffPlanner and schedule it as part of the diff graph. This would enable us to remove the ProcessTopLevelDecl which is a layering violation.

it = m_Variables.emplace(VD, DREDiff).first;
} else
// Is not an independent variable, ignored.
return StmtDiff(clonedDRE);
}
// Create the (_d_param[idx] += dfdx) statement.
if (dfdx()) {
Expand All @@ -1440,6 +1444,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(clonedDRE);
}

Expr* ReverseModeVisitor::DifferentiateGlobalVarDecl(VarDecl* VD) {
assert(VD->isFileVarDecl() && "Must be a global variable");
std::string nameDiff_str = "_d_" + VD->getNameAsString();
DeclarationName nameDiff = &m_Context.Idents.get(nameDiff_str);
DeclContext* DC = VD->getDeclContext();

// Attempt to find the adjoint of VD in case it has already been created.
LookupResult result(m_Sema, nameDiff, noLoc, Sema::LookupOrdinaryName);
m_Sema.LookupQualifiedName(result, DC);
if (!result.empty()) {
// Found, return a reference
Expr* foundExpr = m_Sema
.BuildDeclarationNameExpr(CXXScopeSpec{}, result,
/*ADL=*/false)
.get();
return foundExpr;
}
// Not found, construct the adjoint and register it.
VarDecl* VDDiff =
BuildVarDecl(VD->getType(), CreateUniqueIdentifier(nameDiff_str),
m_DerivativeFnScope->getParent()->getParent(), DC,
getZeroInit(VD->getType()));

DC->addDecl(VDDiff);
DC->makeDeclVisibleInContext(VDDiff);
plugin::ProcessTopLevelDecl(m_CladPlugin, VDDiff);
// diag(DiagnosticsEngine::Warning,
// VD->getLocation(),
// "The gradient utilizes a global variable '%0' and its adjoint
// '%1'"
// ". Please make sure to properly reset '%0' and '%1' before
// re-running the gradient.",
// {VD->getNameAsString(), nameDiff_str});
return BuildDeclRef(VDDiff);
}

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
auto* Constant0 =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down
20 changes: 12 additions & 8 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,19 @@ namespace clad {
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit,
TSI, IS);
return BuildVarDecl(Type, Identifier, getCurrentScope(), m_Sema.CurContext,
Init, DirectInit, TSI, IS);
}
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Scope* Scope, Expr* Init, bool DirectInit,
Scope* Scope, DeclContext* DeclCtx,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto* VD = VarDecl::Create(
m_Context, m_Sema.CurContext, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None);
auto* VD = VarDecl::Create(m_Context, DeclCtx, m_DiffReq->getLocation(),
m_DiffReq->getLocation(), Identifier, Type, TSI,
SC_None);

if (Init) {
m_Sema.AddInitializerToDecl(VD, Init, DirectInit);
Expand Down Expand Up @@ -149,9 +150,12 @@ namespace clad {
VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
llvm::StringRef prefix, Expr* Init,
bool DirectInit, TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS,
DeclContext* DeclCtx) {
DeclCtx = DeclCtx ? DeclCtx : m_Sema.CurContext;
return BuildVarDecl(Type, CreateUniqueIdentifier(prefix),
m_DerivativeFnScope, Init, DirectInit, TSI, IS);
m_DerivativeFnScope, DeclCtx, Init, DirectInit, TSI,
IS);
}

NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
Expand Down
1 change: 1 addition & 0 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ int main() {

// CHECK: inline void operator_call_grad(double ii, double j, double *_d_ii, double *_d_j) const {
// CHECK-NEXT: {
// CHECK-NEXT: _d_x += 1 * j * ii;
// CHECK-NEXT: *_d_ii += x * 1 * j;
// CHECK-NEXT: *_d_j += x * ii * 1;
// CHECK-NEXT: }
Expand Down
62 changes: 61 additions & 1 deletion test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -666,14 +666,16 @@ float running_sum(float* p, int n) {
// CHECK-NEXT: }

double global = 7;
// CHECK: double _d_global = 0.;
// expected-warning {{The gradient utilizes a global variable 'global' and its adjoint '_d_global'. Please make sure to properly reset 'global' and '_d_global' before re-running the gradient.}}

double fn_global_var_use(double i, double j) {
double& ref = global;
return ref * i;
}

// CHECK: void fn_global_var_use_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_ref = 0.;
// CHECK-NEXT: double &_d_ref = _d_global;
// CHECK-NEXT: double &ref = global;
// CHECK-NEXT: {
// CHECK-NEXT: _d_ref += 1 * i;
Expand Down Expand Up @@ -1143,6 +1145,49 @@ double f_ref_in_rhs(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double glob1 = 5;

double g(double a, double b) {
glob1 = b;
return a;
}

//CHECK: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b);

//CHECK: double _d_glob1 = 0.;
// expected-warning {{The gradient utilizes a global variable 'glob1' and its adjoint '_d_glob1'. Please make sure to properly reset 'glob1' and '_d_glob1' before re-running the gradient.}}

double f_reuse_global(double x, double t) {
t = g(t, x);
glob1 *= t;
return -glob1;
} // -x * t

//CHECK: void f_reuse_global_grad(double x, double t, double *_d_x, double *_d_t) {
//CHECK-NEXT: double _t0 = t;
//CHECK-NEXT: t = g(t, x);
//CHECK-NEXT: double _t1 = glob1;
//CHECK-NEXT: glob1 *= t;
//CHECK-NEXT: _d_glob1 += -1;
//CHECK-NEXT: {
//CHECK-NEXT: glob1 = _t1;
//CHECK-NEXT: double _r_d1 = _d_glob1;
//CHECK-NEXT: _d_glob1 = 0.;
//CHECK-NEXT: _d_glob1 += _r_d1 * t;
//CHECK-NEXT: *_d_t += glob1 * _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t = _t0;
//CHECK-NEXT: double _r_d0 = *_d_t;
//CHECK-NEXT: *_d_t = 0.;
//CHECK-NEXT: double _r0 = 0.;
//CHECK-NEXT: double _r1 = 0.;
//CHECK-NEXT: g_pullback(t, x, _r_d0, &_r0, &_r1);
//CHECK-NEXT: *_d_t += _r0;
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -1239,4 +1284,19 @@ int main() {

INIT_GRADIENT(f_ref_in_rhs);
TEST_GRADIENT(f_ref_in_rhs, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {5.00, 13.00}

INIT_GRADIENT(f_reuse_global);
TEST_GRADIENT(f_reuse_global, /*numOfDerivativeArgs=*/2, -3, 4, &d_i, &d_j); // CHECK-EXEC: {-4.00, 3.00}
}

//CHECK-NEXT: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b) {
//CHECK-NEXT: double _t0 = glob1;
//CHECK-NEXT: glob1 = b;
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: {
//CHECK-NEXT: glob1 = _t0;
//CHECK-NEXT: double _r_d0 = _d_glob1;
//CHECK-NEXT: _d_glob1 = 0.;
//CHECK-NEXT: *_d_b += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
5 changes: 4 additions & 1 deletion test/ValidCodeGen/ValidCodeGen.C
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ int main() {
//CHECK-NEXT: }

//CHECK: void fn_grad(double x, double *_d_x) {
//CHECK-NEXT: *_d_x += 1 * TN::coefficient;
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += 1 * TN::coefficient;
//CHECK-NEXT: _d_coefficient += x * 1;
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) {
Expand Down
32 changes: 22 additions & 10 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,24 @@ class CladTimerGroup {
// handling of the differentiation plans.
clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);

void ProcessTopLevelDecl(clang::Decl* D) {
if (llvm::isa<clang::VarDecl>(D) && m_DO.DumpDerivedFn) {
clang::LangOptions LangOpts;
LangOpts.CPlusPlus = true;
clang::PrintingPolicy Policy(LangOpts);
Policy.Bool = true;
D->print(llvm::outs(), Policy);
llvm::outs() << ";\n";
}
DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D};
assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!");
AppendDelayed(DCI);
// We could not delay the process due to some strange way of
// initialization, inform the consumers now.
if (!m_Multiplexer)
m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR);
}

private:
void AppendDelayed(DelayedCallInfo DCI) {
// Incremental processing handles the translation unit in chunks and it is
Expand All @@ -268,16 +286,6 @@ class CladTimerGroup {
void SendToMultiplexer();
bool CheckBuiltins();
void SetRequestOptions(RequestOptions& opts) const;

void ProcessTopLevelDecl(clang::Decl* D) {
DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D};
assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!");
AppendDelayed(DCI);
// We could not delay the process due to some strange way of
// initialization, inform the consumers now.
if (!m_Multiplexer)
m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR);
}
void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR);
};

Expand All @@ -286,6 +294,10 @@ class CladTimerGroup {
return P.ProcessDiffRequest(request);
}

void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function 'ProcessTopLevelDecl' defined in a header file; function definitions in header files can lead to ODR violations [misc-definitions-in-headers]

    void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D) {
         ^
Additional context

tools/ClangPlugin.h:296: make as 'inline'

    void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D) {
         ^

P.ProcessTopLevelDecl(D);
}

template <typename ConsumerType>
class Action : public clang::PluginASTAction {
private:
Expand Down
Loading