Skip to content

Commit

Permalink
Add initial support for pushforward functions to diff fn calls
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and vgvassilev committed Jan 30, 2022
1 parent 8bec3c4 commit 9d74f6d
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 67 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace clad {
enum class DiffMode {
unknown = 0,
forward,
experimental_pushforward,
reverse,
hessian,
jacobian,
Expand Down
5 changes: 4 additions & 1 deletion include/clad/Differentiator/ForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"
#include "clad/Differentiator/DiffMode.h"

#include <array>
#include <stack>
Expand All @@ -28,7 +29,7 @@ namespace clad {
unsigned m_IndependentVarIndex = ~0;
unsigned m_DerivativeOrder = ~0;
unsigned m_ArgIndex = ~0;

DiffMode m_Mode;
public:
ForwardModeVisitor(DerivativeBuilder& builder);
~ForwardModeVisitor();
Expand All @@ -42,6 +43,8 @@ namespace clad {
///
OverloadedDeclWithContext Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
/// Builds the pushforward function of \p FD, named
/// `<name of FD>_pushforward`.
///
/// The generated function takes the parameters of \p FD followed by one
/// extra `_d_<param>` parameter for every parameter of \p FD whose
/// non-reference type is a real type.
///
/// \param[in] FD function to build the pushforward of.
/// \param[in] request differentiation request; its functor and current
/// derivative order are recorded by the visitor.
/// \returns the declaration produced by cloning \p FD, with no overload
/// declaration attached.
OverloadedDeclWithContext DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ namespace clad {
clang::Expr*
BuildArrayRefSliceExpr(clang::Expr* Base,
llvm::MutableArrayRef<clang::Expr*> Args);
/// Creates a new parameter declaration in the current Sema context that
/// reuses the type, type source info and storage class of \p PVD and
/// carries a clone of its default argument, if it has one.
///
/// \param[in] PVD parameter declaration to copy from.
/// \param[in] II identifier given to the new parameter.
/// \param[in] pushOnScopeChains if true and the new parameter has an
/// identifier, the new parameter is pushed on the scope chains of the
/// current scope (without being added to the declaration context).
/// \returns the newly created parameter declaration.
clang::ParmVarDecl* CloneParmVarDecl(const clang::ParmVarDecl* PVD,
clang::IdentifierInfo* II,
bool pushOnScopeChains = false);
/// A function to get the single argument "forward_central_difference"
/// call expression for the given arguments.
///
Expand Down
3 changes: 3 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ namespace clad {
if (request.Mode == DiffMode::forward) {
ForwardModeVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::experimental_pushforward) {
ForwardModeVisitor V(*this);
result = V.DerivePushforward(FD, request);
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this);
result = V.Derive(FD, request);
Expand Down
152 changes: 124 additions & 28 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,107 @@ namespace clad {

ForwardModeVisitor::~ForwardModeVisitor() {}

/// Builds the pushforward function of \p FD.
///
/// The generated function is named `<name of FD>_pushforward`, has the
/// return type of \p FD and takes the parameters of \p FD followed by one
/// `_d_<param>` parameter for every parameter whose non-reference type is a
/// real type. Its body is obtained by visiting the body of \p FD.
///
/// \param[in] FD function to build the pushforward of.
/// \param[in] request differentiation request; its functor and current
/// derivative order are recorded in the visitor.
/// \returns the declaration produced by cloning \p FD, with no overload
/// declaration attached.
OverloadedDeclWithContext
ForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
                                      const DiffRequest& request) {
  m_Function = FD;
  m_Functor = request.Functor;
  m_DerivativeOrder = request.CurrentDerivativeOrder;
  m_Mode = DiffMode::experimental_pushforward;
  assert(!m_DerivativeInFlight &&
         "Doesn't support recursive diff. Use DiffPlan.");
  m_DerivativeInFlight = true;

  IdentifierInfo* derivedFnII =
      &m_Context.Idents.get(m_Function->getNameAsString() + "_pushforward");
  DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
  llvm::SmallVector<QualType, 16> paramTypes, derivedParamTypes;

  // Signature of the pushforward: original parameter types first, then the
  // types of the derivative parameters.
  for (auto* PVD : m_Function->parameters()) {
    paramTypes.push_back(PVD->getType());

    // Pushforward functions currently only support real parameters
    QualType nonRefParamType = PVD->getType().getNonReferenceType();
    if (nonRefParamType->isRealType())
      derivedParamTypes.push_back(PVD->getType());
  }
  paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(),
                    derivedParamTypes.end());

  auto originalFnType = dyn_cast<FunctionProtoType>(m_Function->getType());
  // The prototype info is dereferenced right below; fail with a clear
  // message instead of dereferencing a null pointer.
  assert(originalFnType &&
         "Pushforward generation requires a function with a prototype.");
  QualType derivedFnType =
      m_Context.getFunctionType(m_Function->getReturnType(), paramTypes,
                                originalFnType->getExtProtoInfo());
  // Both the Sema context and the current scope are restored when this
  // function returns.
  llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
  llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
  DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
  m_Sema.CurContext = DC;

  DeclWithContext cloneFunctionResult =
      m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc,
                              derivedFnName, derivedFnType);
  m_Derivative = cloneFunctionResult.first;

  llvm::SmallVector<ParmVarDecl*, 16> params;
  llvm::SmallVector<ParmVarDecl*, 16> derivedParams;
  beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
             Scope::DeclScope);
  m_Sema.PushFunctionScope();
  m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

  std::size_t numParamsOriginalFn = m_Function->getNumParams();
  for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
    auto PVD = m_Function->getParamDecl(i);
    auto newPVD = CloneParmVarDecl(PVD, PVD->getIdentifier(),
                                   /*pushOnScopeChains=*/true);
    params.push_back(newPVD);

    // Only parameters of real type get a derivative parameter.
    QualType nonRefParamType = PVD->getType().getNonReferenceType();
    if (!nonRefParamType->isRealType())
      continue;
    auto derivedPVDName = "_d_" + PVD->getNameAsString();
    auto derivedPVDII = &m_Context.Idents.get(derivedPVDName);
    // TODO: Check for name conflicts.
    auto derivedPVD = CloneParmVarDecl(PVD, derivedPVDII,
                                       /*pushOnScopeChains=*/true);
    derivedParams.push_back(derivedPVD);
    // Associate the cloned parameter with a reference to its derivative
    // parameter.
    m_Variables[newPVD] = BuildDeclRef(derivedPVD);
  }

  // Derivative parameters come after all the cloned original parameters,
  // matching the order used for `paramTypes` above.
  params.insert(params.end(), derivedParams.begin(), derivedParams.end());
  m_Derivative->setParams(params);
  m_Derivative->setBody(nullptr);

  beginScope(Scope::FnScope | Scope::DeclScope);
  m_DerivativeFnScope = getCurrentScope();
  beginBlock();

  // Differentiate the original body and splice its statements into the
  // block opened above.
  Stmt* bodyDiff = Visit(FD->getBody()).getStmt();
  CompoundStmt* CS = cast<CompoundStmt>(bodyDiff);
  for (Stmt* S : CS->body())
    addToCurrentBlock(S);

  Stmt* derivativeBody = endBlock();
  m_Derivative->setBody(derivativeBody);

  endScope(); // Function body scope
  m_Sema.PopFunctionScopeInfo();
  m_Sema.PopDeclContext();
  endScope(); // Function decl scope

  m_DerivativeInFlight = false;
  return OverloadedDeclWithContext{cloneFunctionResult.first,
                                   cloneFunctionResult.second,
                                   /*OverloadFunctionDecl=*/nullptr};
}

OverloadedDeclWithContext
ForwardModeVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
silenceDiags = !request.VerboseDiags;
m_Function = FD;
m_Functor = request.Functor;
m_Mode = DiffMode::forward;
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;
Expand Down Expand Up @@ -837,16 +932,22 @@ namespace clad {

SourceLocation noLoc;
llvm::SmallVector<Expr*, 4> CallArgs{};
llvm::SmallVector<Expr*, 4> diffArgs;
// For f(g(x)) = f'(x) * g'(x)
Expr* Multiplier = nullptr;
for (size_t i = 0, e = CE->getNumArgs(); i < e; ++i) {
StmtDiff argDiff = Visit(CE->getArg(i));
const Expr* arg = CE->getArg(i);
StmtDiff argDiff = Visit(arg);
if (!Multiplier)
Multiplier = argDiff.getExpr_dx();
else {
Multiplier = BuildOp(BO_Add, Multiplier, argDiff.getExpr_dx());
}
CallArgs.push_back(argDiff.getExpr());
// FIXME: Add support for pointer and array arguments in the
// pushforward mode.
if (arg->getType().getNonReferenceType()->isRealType())
diffArgs.push_back(argDiff.getExpr_dx());
}

Expr* call = m_Sema
Expand All @@ -859,46 +960,41 @@ namespace clad {

// Try to find an overloaded derivative in 'custom_derivatives'
Expr* callDiff = m_Builder.findOverloadedDefinition(DNInfo, CallArgs);

// FIXME: add gradient-vector products to fix that.
if (!callDiff)
assert((CE->getNumArgs() <= 1) &&
"forward differentiation of multi-arg calls is currently broken");

if (callDiff && Multiplier)
callDiff = BuildOp(BO_Mul, callDiff, BuildParens(Multiplier));

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function)) {
if (!callDiff && (FD == m_Function) && m_Mode == DiffMode::experimental_pushforward) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
.BuildDeclarationNameExpr(CXXScopeSpec(),
m_Derivative->getNameInfo(),
m_Derivative)
.BuildDeclarationNameExpr(
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();
CallArgs.insert(CallArgs.end(), diffArgs.begin(), diffArgs.end());
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef,
noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs),
noLoc)
derivativeRef, noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.get();
}

if (!callDiff) {
// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
DiffRequest request{};
request.Function = FD;
request.BaseFunctionName = FD->getNameAsString();
request.Mode = DiffMode::forward;
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = DiffMode::experimental_pushforward;
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString();
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
request.VerboseDiags = false;

FunctionDecl* derivedFD =
plugin::ProcessDiffRequest(m_CladPlugin, request);
pushforwardFnRequest.VerboseDiags = false;
FunctionDecl* pushforwardFD =
plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest);
// If clad failed to derive it, try finding its derivative using
// numerical diff.
if (!derivedFD) {
if (!pushforwardFD) {
// FIXME: Extend this for multiarg support
// Check if the function is eligible for numerical differentiation.
if (CE->getNumArgs() == 1) {
Expand All @@ -914,19 +1010,19 @@ namespace clad {
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
return StmtDiff(call, zero);
}
if (Multiplier)
callDiff = BuildOp(BO_Mul, callDiff, BuildParens(Multiplier));
} else {
CallArgs.insert(CallArgs.end(), diffArgs.begin(), diffArgs.end());
callDiff = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(derivedFD),
BuildDeclRef(pushforwardFD),
noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs),
noLoc)
.get();
}
}

if (Multiplier)
callDiff = BuildOp(BO_Mul, callDiff, BuildParens(Multiplier));
return StmtDiff(call, callDiff);
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,4 +659,20 @@ namespace clad {
}
}

/// Creates a copy of the parameter declaration \p PVD, named by \p II, in
/// the current Sema context. The copy reuses the type, type source info and
/// storage class of \p PVD and carries a clone of its default argument when
/// one is present. If \p pushOnScopeChains is set and the copy has an
/// identifier, the copy is also pushed on the scope chains of the current
/// scope without being added to the declaration context.
ParmVarDecl* VisitorBase::CloneParmVarDecl(const ParmVarDecl* PVD,
                                           IdentifierInfo* II,
                                           bool pushOnScopeChains) {
  // A parameter without a default argument gets none in the copy either.
  Expr* clonedDefaultArg =
      PVD->hasDefaultArg() ? Clone(PVD->getDefaultArg()) : nullptr;

  ParmVarDecl* clonedParam = ParmVarDecl::Create(
      m_Context, m_Sema.CurContext, noLoc, noLoc, II, PVD->getType(),
      PVD->getTypeSourceInfo(), PVD->getStorageClass(), clonedDefaultArg);

  // Nothing more to do unless the caller asked for scope registration and
  // the copy actually has an identifier.
  if (!pushOnScopeChains || !clonedParam->getIdentifier())
    return clonedParam;

  m_Sema.PushOnScopeChains(clonedParam, getCurrentScope(),
                           /*AddToContext=*/false);
  return clonedParam;
}
} // end namespace clad
8 changes: 6 additions & 2 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ float test_3() {
// CHECK-NOT: float test_3_darg0() {

float test_4(int x) {
return overloaded(); // expected-warning {{function 'overloaded' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives', and function may not be eligible for numerical differentiation.}}
return overloaded();
}

// CHECK: int overloaded_pushforward() {
// CHECK-NEXT: return 0;
// CHECK-NEXT: }

// CHECK: float test_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: return 0;
// CHECK-NEXT: return overloaded_pushforward();
// CHECK-NEXT: }

float test_5(int x) {
Expand Down
Loading

0 comments on commit 9d74f6d

Please sign in to comment.