Skip to content

Commit

Permalink
Enable recursive differentiation of nested calls
Browse files Browse the repository at this point in the history
Now, if a function `f` that is being differentiated contains a call to a function `g`, clad will proceed to try to differentiate `g` even if it has no
overload in `custom_derivatives`.

Example:
```
double f(double x) {
  return g(double x);
}
```
On `clad::differentiate(f, 0)`, clad will proceed to differentiate `g` and so on, even if it is not in `custom_derivatives` namespace.
  • Loading branch information
efremale committed Apr 2, 2019
1 parent 58d96b0 commit 5aa8536
Show file tree
Hide file tree
Showing 14 changed files with 513 additions and 400 deletions.
40 changes: 26 additions & 14 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,21 @@ namespace clad {
namespace utils {
class StmtClone;
}
class DiffRequest;
namespace plugin {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, DiffRequest& request);
}
}

namespace clad {
/// A pair of FunctionDecl and potential enclosing context, e.g. a function
// in nested namespaces
using DeclWithContext = std::pair<clang::FunctionDecl*, clang::Decl*>;

using DiffParams = llvm::SmallVector<clang::VarDecl*, 16>;
using DiffParams = llvm::SmallVector<const clang::VarDecl*, 16>;

static clang::SourceLocation noLoc{};
class DiffPlan;
/// The main builder class which then uses either ForwardModeVisitor or
/// ReverseModeVisitor based on the required mode.
class DerivativeBuilder {
Expand All @@ -51,6 +55,7 @@ namespace clad {
friend class ReverseModeVisitor;

clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
std::unique_ptr<utils::StmtClone> m_NodeCloner;
clang::NamespaceDecl* m_BuiltinDerivativesNSD;
Expand All @@ -71,7 +76,7 @@ namespace clad {
stream << arg;
}
public:
DerivativeBuilder(clang::Sema& S);
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P);
~DerivativeBuilder();

///\brief Produces the derivative of a given function
Expand All @@ -82,7 +87,8 @@ namespace clad {
///\returns The differentiated function and potentially created enclosing
/// context.
///
DeclWithContext Derive(clang::FunctionDecl* FD, const DiffPlan & plan);
DeclWithContext Derive(const clang::FunctionDecl* FD,
const DiffRequest & request);
};

/// A base class for all common functionality for visitors
Expand All @@ -91,6 +97,7 @@ namespace clad {
VisitorBase(DerivativeBuilder& builder) :
m_Builder(builder),
m_Sema(builder.m_Sema),
m_CladPlugin(builder.m_CladPlugin),
m_Context(builder.m_Context),
m_CurScope(m_Sema.TUScope),
m_DerivativeInFlight(false),
Expand All @@ -101,17 +108,18 @@ namespace clad {

DerivativeBuilder& m_Builder;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
clang::Scope* m_CurScope;
bool m_DerivativeInFlight;
/// The Derivative function that is being generated.
clang::FunctionDecl* m_Derivative;
/// The function that is currently differentiated.
clang::FunctionDecl* m_Function;
const clang::FunctionDecl* m_Function;

/// Map used to keep track of variable declarations and match them
/// with their derivatives.
std::unordered_map<clang::VarDecl*, clang::Expr*> m_Variables;
std::unordered_map<const clang::VarDecl*, clang::Expr*> m_Variables;
/// Map contains variable declarations replacements. If the original
/// function contains a declaration which name collides with something
/// already created inside derivative's body, the declaration is replaced
Expand Down Expand Up @@ -204,7 +212,7 @@ namespace clad {
clang::DeclStmt* BuildDeclStmt(llvm::MutableArrayRef<clang::Decl*> DS);

/// Builds a DeclRefExpr to a given Decl.
clang::DeclRefExpr* BuildDeclRef(clang::VarDecl* D);
clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D);

/// Stores the result of an expression in a temporary variable (of the same
/// type as is the result of the expression) and returns a reference to it.
Expand All @@ -223,14 +231,16 @@ namespace clad {
clang::Expr* StoreAndRef(clang::Expr* E, clang::QualType Type, Stmts& block,
llvm::StringRef prefix = "_t",
bool forceDeclCreation = false);

/// A flag for silencing warnings/errors output by diag function.
bool silenceDiags = false;
/// Shorthand to issues a warning or error.
template <std::size_t N>
void diag(clang::DiagnosticsEngine::Level level, // Warning or Error
clang::SourceLocation loc,
const char (&format)[N],
llvm::ArrayRef<llvm::StringRef> args = {}) {
m_Builder.diag(level, loc, format, args);
if (!silenceDiags)
m_Builder.diag(level, loc, format, args);
}

/// Conuter used to create unique identifiers for temporaries
Expand All @@ -247,7 +257,6 @@ namespace clad {
clang::Stmt* Clone(const clang::Stmt* S);
/// A shorthand to simplify cloning of expressions.
clang::Expr* Clone(const clang::Expr* E);

/// Parses the argument expression for the clad::differentiate/clad::gradient
/// call. The argument is used to specify independent parameter(s) for
/// differentiation. There are three valid options for the argument expression:
Expand All @@ -258,7 +267,8 @@ namespace clad {
/// parameter corresponding to literal's value index.
/// 3) If no argument is provided, a default argument is used. The function
/// will be differentiated w.r.t. to its every parameter.
DiffParams parseDiffArgs(const clang::Expr* diffArgs, clang::FunctionDecl* FD);
DiffParams parseDiffArgs(const clang::Expr* diffArgs,
const clang::FunctionDecl* FD);
};

/// A class that represents the result of Visit of ForwardModeVisitor.
Expand Down Expand Up @@ -313,7 +323,7 @@ namespace clad {
: public clang::ConstStmtVisitor<ForwardModeVisitor, StmtDiff>,
public VisitorBase {
private:
clang::VarDecl* m_IndependentVar = nullptr;
const clang::VarDecl* m_IndependentVar = nullptr;
unsigned m_DerivativeOrder = ~0;
unsigned m_ArgIndex = ~0;

Expand All @@ -328,7 +338,8 @@ namespace clad {
///\returns The differentiated and potentially created enclosing
/// context.
///
DeclWithContext Derive(clang::FunctionDecl* FD, const DiffPlan& plan);
DeclWithContext Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
Expand Down Expand Up @@ -468,7 +479,8 @@ namespace clad {
/// Improved naming scheme is required. Hence, we append the indices to of the
/// requested parameters to 'f_grad', i.e. in the previous example "x, y" will
/// give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DeclWithContext Derive(clang::FunctionDecl* FD, const DiffPlan& plan);
DeclWithContext Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
77 changes: 27 additions & 50 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,49 @@ namespace clad {
reverse
};

///\brief The list of the dependent functions which also need differentiation
/// because they are called by the function we are asked to differentitate.
///
class DiffPlan {
private:
typedef llvm::SmallVector<clang::FunctionDecl*, 16> Functions;
Functions m_Functions;
clang::CallExpr* m_CallToUpdate = nullptr;
unsigned m_RequestedDerivativeOrder = 1;
unsigned m_CurrentDerivativeOrder = 1;
clang::Expr* m_DiffArgs = nullptr;
DiffMode m_Mode = DiffMode::unknown;
public:
typedef Functions::iterator iterator;
typedef Functions::const_iterator const_iterator;
/// A struct containing information about request to differentiate a function.
struct DiffRequest {
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
/// function->getNameAsString() when higher-order derivatives are computed.
std::string BaseFunctionName = {};
/// Current derivative order to be computed.
unsigned CurrentDerivativeOrder = 1;
/// Highest requested derivative order.
unsigned RequestedDerivativeOrder = 1;
/// Context in which the function is being called, or a call to
/// clad::gradient/differentiate, where function is the first arg.
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
/// the call must be updated and the first arg replaced by the derivative.
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;

DiffMode getMode() const {
assert(m_Mode != DiffMode::unknown && "Must be set!");
return m_Mode;
}
void setMode(DiffMode mode) {
m_Mode = mode;
}
unsigned getRequestedDerivativeOrder() const {
return m_RequestedDerivativeOrder;
}
void setCurrentDerivativeOrder(unsigned val) {
m_CurrentDerivativeOrder = val;
}
unsigned getCurrentDerivativeOrder() const {
return m_CurrentDerivativeOrder;
}
void push_back(clang::FunctionDecl* FD) { m_Functions.push_back(FD); }
iterator begin() { return m_Functions.begin(); }
iterator end() { return m_Functions.end(); }
const_iterator begin() const { return m_Functions.begin(); }
const_iterator end() const { return m_Functions.end(); }
size_t size() const { return m_Functions.size(); }
void setCallToUpdate(clang::CallExpr* CE) { m_CallToUpdate = CE; }
void updateCall(clang::FunctionDecl* FD, clang::Sema& SemaRef);
clang::Expr* getArgs() const { return m_DiffArgs; }
LLVM_DUMP_METHOD void dump();

friend class DiffCollector;
};

typedef llvm::SmallVector<DiffPlan, 16> DiffPlans;
using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;

class DiffCollector: public clang::RecursiveASTVisitor<DiffCollector> {
private:
///\brief The diff step-by-step plan for differentiation.
///
DiffPlans& m_DiffPlans;
DiffSchedule& m_DiffPlans;

///\brief If set it means that we need to find the called functions and
/// add them for implicit diff.
///
clang::FunctionDecl* m_TopMostFD;

const clang::FunctionDecl* m_TopMostFD = nullptr;
clang::Sema& m_Sema;

DiffPlan& getCurrentPlan() { return m_DiffPlans.back(); }

public:
DiffCollector(clang::DeclGroupRef DGR, DiffPlans& plans, clang::Sema& S);
void UpdatePlan(clang::FunctionDecl* FD, DiffPlan* plan);
DiffCollector(clang::DeclGroupRef DGR, DiffSchedule& plans, clang::Sema& S);
bool VisitCallExpr(clang::CallExpr* E);
};
}
Loading

0 comments on commit 5aa8536

Please sign in to comment.