Skip to content

Commit

Permalink
Proper differentiation of lambdas (no captures) & code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Nov 9, 2024
1 parent fa89545 commit 111dff3
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 15 deletions.
10 changes: 10 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy,
#define CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam clang::ParsedAttributesView::none(),
#endif

#if CLANG_VERSION_MAJOR > 12
#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \
LAMBDACXXRECORDDECL) \
LAMBDACXXRECORDDECL->getLambdaDependencyKind()
#else
#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \
LAMBDACXXRECORDDECL) \
LAMBDACXXRECORDDECL->isDependentLambda()
#endif

// Clang 12 add one extra param (FPO) that we get from Node in Create method of:
// ImplicitCastExpr, CStyleCastExpr, CXXStaticCastExpr and CXXFunctionalCastExpr

Expand Down
6 changes: 6 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ namespace clad {
llvm::SmallVectorImpl<clang::Expr*>& outputArgs,
clang::Expr* CUDAExecConfig = nullptr);

clang::CXXRecordDecl*
diffLambdaCXXRecordDecl(const clang::CXXRecordDecl* Original);
clang::CXXMethodDecl*
DifferentiateCallOperatorIfLambda(const clang::CXXRecordDecl* RD);

public:
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
virtual ~ReverseModeVisitor();
Expand All @@ -383,6 +388,7 @@ namespace clad {
DerivativeAndOverload DerivePullback();
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitLambdaExpr(const clang::LambdaExpr* LE);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
Expand Down
Loading

0 comments on commit 111dff3

Please sign in to comment.