Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for call expressions in vector forward mode AD #638

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ class BaseForwardModeVisitor
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD,
const DiffRequest& request);

/// Returns the return type for the pushforward function of the function
/// `m_Function`.
/// \note `m_Function` field should be set before using this function.
clang::QualType ComputePushforwardFnReturnType();

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
Expand Down Expand Up @@ -93,6 +103,11 @@ class BaseForwardModeVisitor
VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE);
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
virtual std::string GetPushForwardFunctionSuffix();
virtual DiffMode GetPushForwardMode();

protected:
/// Helper function for differentiating the switch statement body.
///
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ namespace clad {
private:
friend class VisitorBase;
friend class BaseForwardModeVisitor;
friend class ForwardModeVisitor;
friend class PushForwardModeVisitor;
friend class VectorForwardModeVisitor;
friend class VectorPushForwardModeVisitor;
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ enum class DiffMode {
vector_forward_mode,
experimental_pushforward,
experimental_pullback,
experimental_vector_pushforward,
reverse,
hessian,
jacobian,
Expand Down
33 changes: 0 additions & 33 deletions include/clad/Differentiator/ForwardModeVisitor.h

This file was deleted.

25 changes: 25 additions & 0 deletions include/clad/Differentiator/PushForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//--------------------------------------------------------------------*- C++ -*-
// clad - the C++ Clang-based Automatic Differentiator
// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z [email protected] $
// author: Vassil Vassilev <vvasilev-at-cern.ch>
//------------------------------------------------------------------------------

#ifndef CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H

#include "BaseForwardModeVisitor.h"

namespace clad {
/// A visitor for processing the function code in forward mode.
/// Used to compute derivatives by clad::differentiate.
class PushForwardModeVisitor : public BaseForwardModeVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: class 'PushForwardModeVisitor' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator [cppcoreguidelines-special-member-functions]

class PushForwardModeVisitor : public BaseForwardModeVisitor {
      ^


public:
PushForwardModeVisitor(DerivativeBuilder& builder);
~PushForwardModeVisitor() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H
8 changes: 8 additions & 0 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override;

clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType) override;
std::string GetPushForwardFunctionSuffix() override;
DiffMode GetPushForwardMode() override;

// Function for setting the independent variables for vector mode.
void SetIndependentVarsExpr(clang::Expr* IndVarCountExpr);
};
} // end namespace clad

Expand Down
20 changes: 20 additions & 0 deletions include/clad/Differentiator/VectorPushForwardModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H

#include "PushForwardModeVisitor.h"
#include "VectorForwardModeVisitor.h"

namespace clad {
class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: class 'VectorPushForwardModeVisitor' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator [cppcoreguidelines-special-member-functions]

class VectorPushForwardModeVisitor : public VectorForwardModeVisitor {
      ^


public:
VectorPushForwardModeVisitor(DerivativeBuilder& builder);
~VectorPushForwardModeVisitor() override;

void ExecuteInsidePushforwardFunctionBlock() override;

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H
175 changes: 167 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,157 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
/*OverloadFunctionDecl=*/nullptr};
}

clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
assert(m_Mode == GetPushForwardMode());
QualType originalFnRT = m_Function->getReturnType();
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
LookupTemplateDeclInCladNamespace("ValueAndPushforward");
assert(valueAndPushforward &&
"clad::ValueAndPushforward template not found!!");
QualType RT = InstantiateTemplate(
valueAndPushforward,
{originalFnRT, GetPushForwardDerivativeType(originalFnRT)});
return RT;
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
for (Stmt* S : CS->body())
addToCurrentBlock(S);
}

DerivativeAndOverload
BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
const DiffRequest& request) {
m_Function = FD;
m_Functor = request.Functor;
m_DerivativeOrder = request.CurrentDerivativeOrder;
m_Mode = GetPushForwardMode();
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

// If we are differentiating an instance member function then
// create a parameter type for the parameter that will represent the
// derivative of `this` pointer with respect to the independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD);
derivedParamTypes.push_back(thisType);
}
}

for (auto* PVD : m_Function->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType()));
}

paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(),
derivedParamTypes.end());

const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_Function->getType());
QualType returnType = ComputePushforwardFnReturnType();
QualType derivedFnType = m_Context.getFunctionType(
returnType, paramTypes, originalFnType->getExtProtoInfo());
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

  auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
             ^

m_Sema.CurContext = DC;

DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
llvm::SmallVector<ParmVarDecl*, 16> derivedParams;
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

// If we are differentiating an instance member function then
// create a parameter for representing derivative of
// `this` pointer with respect to the independent parameter.
if (const auto* MFD = dyn_cast<CXXMethodDecl>(FD)) {
if (MFD->isInstance()) {
auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD);
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this");
auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext,
derivedPVDII, thisType);
m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(),
/*AddToContext=*/false);
derivedParams.push_back(derivedPVD);
m_ThisExprDerivative = BuildDeclRef(derivedPVD);
}
}

std::size_t numParamsOriginalFn = m_Function->getNumParams();
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
const auto* PVD = m_Function->getParamDecl(i);
// Some of the special member functions created implicitly by compilers
// have missing parameter identifier.
bool identifierMissing = false;
IdentifierInfo* PVDII = PVD->getIdentifier();
if (!PVDII || PVDII->getLength() == 0) {
PVDII = CreateUniqueIdentifier("param");
identifierMissing = true;
}
auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
params.push_back(newPVD);

if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
auto* derivedPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, derivedPVDII,
GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass());
derivedParams.push_back(derivedPVD);
m_Variables[newPVD] = BuildDeclRef(derivedPVD);
}

params.insert(params.end(), derivedParams.begin(), derivedParams.end());
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// execute the functor inside the function body.
ExecuteInsidePushforwardFunctionBlock();

Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;
return DerivativeAndOverload{cloneFunctionResult.first};
}

StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) {
diag(DiagnosticsEngine::Warning, S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
Expand Down Expand Up @@ -913,6 +1064,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
return OverloadedFn;
}

QualType
BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) {
return ParamType;
}

std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() {
return "_pushforward";
}

DiffMode BaseForwardModeVisitor::GetPushForwardMode() {
return DiffMode::experimental_pushforward;
}

StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
const FunctionDecl* FD = CE->getDirectCallee();
if (!FD) {
Expand Down Expand Up @@ -1007,10 +1171,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
CallArgs.push_back(argDiff.getExpr());
if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) {
Expr* dArg = argDiff.getExpr_dx();
QualType CallArgTy = CallArgs.back()->getType();
assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) &&
"Type mismatch, we might fail to instantiate a pullback");
(void)CallArgTy;
// FIXME: What happens when dArg is nullptr?
diffArgs.push_back(dArg);
}
Expand All @@ -1033,14 +1193,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

// Try to find a user-defined overloaded derivative.
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function) &&
m_Mode == DiffMode::experimental_pushforward) {
if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
Expand Down Expand Up @@ -1089,7 +1248,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// derive the called function.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = DiffMode::experimental_pushforward;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName = FD->getNameAsString();
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ add_llvm_library(cladDifferentiator
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
ForwardModeVisitor.cpp
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
PushForwardModeVisitor.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TBRAnalyzer.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
VectorPushForwardModeVisitor.cpp
Version.cpp
VisitorBase.cpp
${version_inc}
Expand Down
Loading
Loading