-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for call expressions in vector forward mode AD #638
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
//--------------------------------------------------------------------*- C++ -*- | ||
// clad - the C++ Clang-based Automatic Differentiator | ||
// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z [email protected] $ | ||
// author: Vassil Vassilev <vvasilev-at-cern.ch> | ||
//------------------------------------------------------------------------------ | ||
|
||
#ifndef CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H | ||
#define CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H | ||
|
||
#include "BaseForwardModeVisitor.h" | ||
|
||
namespace clad { | ||
/// A visitor for processing the function code in forward mode. | ||
/// Used to compute derivatives by clad::differentiate. | ||
class PushForwardModeVisitor : public BaseForwardModeVisitor { | ||
|
||
public: | ||
PushForwardModeVisitor(DerivativeBuilder& builder); | ||
~PushForwardModeVisitor() override; | ||
|
||
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; | ||
}; | ||
} // end namespace clad | ||
|
||
#endif // CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H | ||
#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H | ||
|
||
#include "PushForwardModeVisitor.h" | ||
#include "VectorForwardModeVisitor.h" | ||
|
||
namespace clad { | ||
class VectorPushForwardModeVisitor : public VectorForwardModeVisitor { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: class 'VectorPushForwardModeVisitor' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator [cppcoreguidelines-special-member-functions] class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {
^ |
||
|
||
public: | ||
VectorPushForwardModeVisitor(DerivativeBuilder& builder); | ||
~VectorPushForwardModeVisitor() override; | ||
|
||
void ExecuteInsidePushforwardFunctionBlock() override; | ||
|
||
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; | ||
}; | ||
} // end namespace clad | ||
|
||
#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,6 +359,157 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, | |
/*OverloadFunctionDecl=*/nullptr}; | ||
} | ||
|
||
clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() { | ||
assert(m_Mode == GetPushForwardMode()); | ||
QualType originalFnRT = m_Function->getReturnType(); | ||
if (originalFnRT->isVoidType()) | ||
return m_Context.VoidTy; | ||
TemplateDecl* valueAndPushforward = | ||
LookupTemplateDeclInCladNamespace("ValueAndPushforward"); | ||
assert(valueAndPushforward && | ||
"clad::ValueAndPushforward template not found!!"); | ||
QualType RT = InstantiateTemplate( | ||
valueAndPushforward, | ||
{originalFnRT, GetPushForwardDerivativeType(originalFnRT)}); | ||
return RT; | ||
} | ||
|
||
void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { | ||
Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt(); | ||
auto* CS = cast<CompoundStmt>(bodyDiff); | ||
for (Stmt* S : CS->body()) | ||
addToCurrentBlock(S); | ||
} | ||
|
||
DerivativeAndOverload | ||
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, | ||
const DiffRequest& request) { | ||
m_Function = FD; | ||
m_Functor = request.Functor; | ||
m_DerivativeOrder = request.CurrentDerivativeOrder; | ||
m_Mode = GetPushForwardMode(); | ||
assert(!m_DerivativeInFlight && | ||
"Doesn't support recursive diff. Use DiffPlan."); | ||
m_DerivativeInFlight = true; | ||
|
||
auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); | ||
|
||
IdentifierInfo* derivedFnII = &m_Context.Idents.get( | ||
originalFnEffectiveName + GetPushForwardFunctionSuffix()); | ||
DeclarationNameInfo derivedFnName(derivedFnII, noLoc); | ||
llvm::SmallVector<QualType, 16> paramTypes; | ||
llvm::SmallVector<QualType, 16> derivedParamTypes; | ||
|
||
// If we are differentiating an instance member function then | ||
// create a parameter type for the parameter that will represent the | ||
// derivative of `this` pointer with respect to the independent parameter. | ||
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) { | ||
if (MD->isInstance()) { | ||
QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); | ||
derivedParamTypes.push_back(thisType); | ||
} | ||
} | ||
|
||
for (auto* PVD : m_Function->parameters()) { | ||
paramTypes.push_back(PVD->getType()); | ||
|
||
if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) | ||
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); | ||
} | ||
|
||
paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), | ||
derivedParamTypes.end()); | ||
|
||
const auto* originalFnType = | ||
dyn_cast<FunctionProtoType>(m_Function->getType()); | ||
QualType returnType = ComputePushforwardFnReturnType(); | ||
QualType derivedFnType = m_Context.getFunctionType( | ||
returnType, paramTypes, originalFnType->getExtProtoInfo()); | ||
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext); | ||
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope); | ||
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast] auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
^ |
||
m_Sema.CurContext = DC; | ||
|
||
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( | ||
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); | ||
m_Derivative = cloneFunctionResult.first; | ||
|
||
llvm::SmallVector<ParmVarDecl*, 16> params; | ||
llvm::SmallVector<ParmVarDecl*, 16> derivedParams; | ||
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | | ||
Scope::DeclScope); | ||
m_Sema.PushFunctionScope(); | ||
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); | ||
|
||
// If we are differentiating an instance member function then | ||
// create a parameter for representing derivative of | ||
// `this` pointer with respect to the independent parameter. | ||
if (const auto* MFD = dyn_cast<CXXMethodDecl>(FD)) { | ||
if (MFD->isInstance()) { | ||
auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD); | ||
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); | ||
auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, | ||
derivedPVDII, thisType); | ||
m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), | ||
/*AddToContext=*/false); | ||
derivedParams.push_back(derivedPVD); | ||
m_ThisExprDerivative = BuildDeclRef(derivedPVD); | ||
} | ||
} | ||
|
||
std::size_t numParamsOriginalFn = m_Function->getNumParams(); | ||
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { | ||
const auto* PVD = m_Function->getParamDecl(i); | ||
// Some of the special member functions created implicitly by compilers | ||
// have missing parameter identifier. | ||
bool identifierMissing = false; | ||
IdentifierInfo* PVDII = PVD->getIdentifier(); | ||
if (!PVDII || PVDII->getLength() == 0) { | ||
PVDII = CreateUniqueIdentifier("param"); | ||
identifierMissing = true; | ||
} | ||
auto* newPVD = CloneParmVarDecl(PVD, PVDII, | ||
/*pushOnScopeChains=*/true, | ||
/*cloneDefaultArg=*/false); | ||
params.push_back(newPVD); | ||
|
||
if (identifierMissing) | ||
m_DeclReplacements[PVD] = newPVD; | ||
|
||
if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) | ||
continue; | ||
auto derivedPVDName = "_d_" + std::string(PVDII->getName()); | ||
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); | ||
auto* derivedPVD = utils::BuildParmVarDecl( | ||
m_Sema, m_Derivative, derivedPVDII, | ||
GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass()); | ||
derivedParams.push_back(derivedPVD); | ||
m_Variables[newPVD] = BuildDeclRef(derivedPVD); | ||
} | ||
|
||
params.insert(params.end(), derivedParams.begin(), derivedParams.end()); | ||
m_Derivative->setParams(params); | ||
m_Derivative->setBody(nullptr); | ||
|
||
beginScope(Scope::FnScope | Scope::DeclScope); | ||
m_DerivativeFnScope = getCurrentScope(); | ||
beginBlock(); | ||
|
||
// execute the functor inside the function body. | ||
ExecuteInsidePushforwardFunctionBlock(); | ||
|
||
Stmt* derivativeBody = endBlock(); | ||
m_Derivative->setBody(derivativeBody); | ||
|
||
endScope(); // Function body scope | ||
m_Sema.PopFunctionScopeInfo(); | ||
m_Sema.PopDeclContext(); | ||
endScope(); // Function decl scope | ||
|
||
m_DerivativeInFlight = false; | ||
return DerivativeAndOverload{cloneFunctionResult.first}; | ||
} | ||
|
||
StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) { | ||
diag(DiagnosticsEngine::Warning, S->getBeginLoc(), | ||
"attempted to differentiate unsupported statement, no changes applied"); | ||
|
@@ -913,6 +1064,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( | |
return OverloadedFn; | ||
} | ||
|
||
QualType | ||
BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { | ||
return ParamType; | ||
} | ||
|
||
std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() { | ||
return "_pushforward"; | ||
} | ||
|
||
DiffMode BaseForwardModeVisitor::GetPushForwardMode() { | ||
return DiffMode::experimental_pushforward; | ||
} | ||
|
||
StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { | ||
const FunctionDecl* FD = CE->getDirectCallee(); | ||
if (!FD) { | ||
|
@@ -1007,10 +1171,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { | |
CallArgs.push_back(argDiff.getExpr()); | ||
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) { | ||
Expr* dArg = argDiff.getExpr_dx(); | ||
QualType CallArgTy = CallArgs.back()->getType(); | ||
assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) && | ||
"Type mismatch, we might fail to instantiate a pullback"); | ||
(void)CallArgTy; | ||
// FIXME: What happens when dArg is nullptr? | ||
diffArgs.push_back(dArg); | ||
} | ||
|
@@ -1033,14 +1193,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { | |
|
||
// Try to find a user-defined overloaded derivative. | ||
std::string customPushforward = | ||
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward"; | ||
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); | ||
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( | ||
customPushforward, customDerivativeArgs, getCurrentScope(), | ||
const_cast<DeclContext*>(FD->getDeclContext())); | ||
|
||
// Check if it is a recursive call. | ||
if (!callDiff && (FD == m_Function) && | ||
m_Mode == DiffMode::experimental_pushforward) { | ||
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) { | ||
// The differentiated function is called recursively. | ||
Expr* derivativeRef = | ||
m_Sema | ||
|
@@ -1089,7 +1248,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { | |
// derive the called function. | ||
DiffRequest pushforwardFnRequest; | ||
pushforwardFnRequest.Function = FD; | ||
pushforwardFnRequest.Mode = DiffMode::experimental_pushforward; | ||
pushforwardFnRequest.Mode = GetPushForwardMode(); | ||
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString(); | ||
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder; | ||
// Silence diag outputs in nested derivation process. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: class 'PushForwardModeVisitor' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator [cppcoreguidelines-special-member-functions]