Skip to content

Commit

Permalink
Enable differentiation wrt data members of parameters in forw mode
Browse files Browse the repository at this point in the history
This PR adds initial support for differentiation with respect to data
members in Forward Mode AD. Functions can be differentiated with
respect to a data member by specifying the data member in the
differentiation args. For example:

```cpp
double fn(std::pair<double, double> u, std::pair<double, double> v) {
  return u.first + v.second;
}

auto d_fn = clad::differentiate(fn, "u.first");
```

This PR does not add support for differentiating with respect to
array data members.
  • Loading branch information
parth-07 authored and vgvassilev committed May 20, 2022
1 parent e6ee2ae commit a2ed7f2
Show file tree
Hide file tree
Showing 12 changed files with 396 additions and 77 deletions.
42 changes: 42 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,48 @@ namespace clad {
/// If `S` is `null`, then nothing happens.
void AppendIndividualStmts(llvm::SmallVectorImpl<clang::Stmt*>& block,
clang::Stmt* S);
/// Builds a nested member expression that consist of base expression
/// specified by `base` argument and data members specified in `fields`
/// argument in the original sequence.
///
/// For example, if `base` represents `b` -- an expression of a record type,
/// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function
/// builds and returns the following expression:
/// ```
/// b.mem1.mem2.mem3
/// ```
clang::MemberExpr*
BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base,
llvm::ArrayRef<llvm::StringRef> fields);

/// Returns true if member expression path specified by `fields` is correct;
/// otherwise returns false.
///
/// For example, if `base` represents `b` -- an expression of a record type,
/// and `fields` is the sequence {'mem1', 'mem2', 'mem3'}, then the function
/// returns true if `b.mem1.mem2.mem3` is a valid data member reference
/// expression, otherwise returns false.
///
/// \note This function returns true if `fields` is an empty sequence.
bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD,
llvm::ArrayRef<llvm::StringRef> fields);

/// Perform lookup for data member with name `name`. If lookup finds a
/// declaration, then return the field declaration; otherwise returns
/// `nullptr`.
clang::FieldDecl* LookupDataMember(clang::Sema& semaRef,
clang::RecordDecl* RD,
llvm::StringRef name);

/// Computes the type of a data member of the record specified by `RD`
/// and nested fields specified in `fields` argument.
/// For example, if `RD` represents `std::pair<std::pair<std::complex,
/// double>, std::pair<double, double>`, and `fields` is the sequence
/// {'first', 'first'}, then the corresponding data member is
// of type `std::complex`.
clang::QualType
ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD,
llvm::ArrayRef<llvm::StringRef> fields);
} // namespace utils
}

Expand Down
8 changes: 6 additions & 2 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ namespace clad {
/// differentiated, for example, when we are computing higher
/// order derivatives.
const clang::CXXRecordDecl* Functor = nullptr;
DiffParamsWithIndices DiffParamsInfo;

/// Recomputes `DiffParamsInfo` using the current values of data members.
/// Stores differentiation parameters information. Stored information
/// includes info on indices range for array parameters, and nested data
/// member information for record (class) type parameters.
DiffInputVarsInfo DVI;

/// Recomputes `DiffInputVarsInfo` using the current values of data members.
///
/// Differentiation parameters info is computed by parsing the argument
/// expression for the clad differentiation function calls. The argument is
Expand Down
47 changes: 45 additions & 2 deletions include/clad/Differentiator/ParseDiffArgsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#define CLAD_PARSE_DIFF_ARGS_TYPES_H

#include "clang/AST/Decl.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <cstddef>
#include <utility>
Expand Down Expand Up @@ -34,10 +36,51 @@ namespace clad {
return Start == rhs.Start && Finish == rhs.Finish;
}
};

using IndexIntervalTable = llvm::SmallVector<IndexInterval, 16>;

/// `DiffInputVarInfo` is designed to store all the essential information about a
/// differentiation input variable. Please note that here input variable corresponds
/// to mathematical variable, not a programming one.
// FIXME: 'DiffInputVarInfo' name is probably not accurate, since we can have multiple
// differentiation input variables for same parameter as well. 'DiffInputVarInfo'
// name implicitly guides that there would be at most one `DiffInputVarInfo` object for
// one parameter, but that is not strictly true.
struct DiffInputVarInfo {
/// Source string specified by user that defines differentiation
/// specification for the input variable.
/// For example, if complete input string specified by user is:
/// 'u, v.first, arr[3]'
/// then `source` data member value for 2nd input variable should be
/// 'v.first'
std::string source;
/// Parameter associated with the input variable.
const clang::ValueDecl* param = nullptr;
/// array index range associated with the parameter.
IndexInterval paramIndexInterval;
/// Nested field information.
llvm::SmallVector<std::string, 4> fields;
// FIXME: Add support for differentiating with respect to array fields.
// llvm::SmallVector<IndexInterval> fieldIndexIntervals;

DiffInputVarInfo(const clang::ValueDecl* pParam = nullptr,
IndexInterval pParamIndexInterval = {},
llvm::SmallVector<std::string, 4> pFields = {})
: param(pParam), paramIndexInterval(pParamIndexInterval),
fields(pFields) {}

// FIXME: Move function definitions to ParseDiffArgTypes.cpp
bool operator==(const DiffInputVarInfo& rhs) const {
return param == rhs.param &&
paramIndexInterval == rhs.paramIndexInterval &&
fields == rhs.fields;
}
};

using DiffInputVarsInfo = llvm::SmallVector<DiffInputVarInfo, 16>;

using DiffParams = llvm::SmallVector<const clang::ValueDecl*, 16>;
using IndexIntervalTable = llvm::SmallVector<IndexInterval, 16>;
using DiffParamsWithIndices = std::pair<DiffParams, IndexIntervalTable>;
} // namespace clad
} // namespace clad

#endif
60 changes: 60 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,5 +420,65 @@ namespace clad {
else if (S)
block.push_back(S);
}

MemberExpr*
BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base,
llvm::ArrayRef<clang::StringRef> fields) {
MemberExpr* ME = nullptr;
for (auto field : fields) {
ME = BuildMemberExpr(semaRef, S, base, field);
base = ME;
}
return ME;
}

clang::FieldDecl* LookupDataMember(clang::Sema& semaRef, clang::RecordDecl* RD,
llvm::StringRef name) {
LookupResult R(semaRef, BuildDeclarationNameInfo(semaRef, name),
Sema::LookupNameKind::LookupMemberName);
CXXScopeSpec CSS;
semaRef.LookupQualifiedName(R, RD, CSS);
if (R.empty())
return nullptr;
assert(R.isSingleResult() && "Lookup in valid classes should always "
"return a single data member result.");
auto D = R.getFoundDecl();
// We are looking data members only!
if (auto FD = dyn_cast<FieldDecl>(D))
return FD;
return nullptr;
}

bool IsValidMemExprPath(clang::Sema& semaRef, clang::RecordDecl* RD,
llvm::ArrayRef<llvm::StringRef> fields) {
for (std::size_t i = 0; i < fields.size(); ++i) {
FieldDecl* FD = LookupDataMember(semaRef, RD, fields[i]);
if (!FD)
return false;
if (FD->getType()->isRecordType())
RD = FD->getType()->getAsCXXRecordDecl();
// Current Field declaration is not of record type, therefore
// it cannot have any field within it. And any member access
// ('.') expression would be an invalid path.
else if (i != fields.size() - 1)
return false;
}
return true;
}

clang::QualType
ComputeMemExprPathType(clang::Sema& semaRef, clang::RecordDecl* RD,
llvm::ArrayRef<llvm::StringRef> fields) {
assert(IsValidMemExprPath(semaRef, RD, fields) &&
"Invalid field path specified!");
QualType T = RD->getTypeForDecl()->getCanonicalTypeInternal();
for (auto field : fields) {
auto FD = LookupDataMember(semaRef, RD, field);
if (FD->getType()->isRecordType())
RD = FD->getType()->getAsCXXRecordDecl();
T = FD->getType();
}
return T;
}
} // namespace utils
} // namespace clad
Loading

0 comments on commit a2ed7f2

Please sign in to comment.