Skip to content

Commit

Permalink
Fix recursive call differentiation and add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Mar 13, 2024
1 parent 8dbc74b commit 51afd38
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
12 changes: 5 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1790,18 +1790,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!OverloadedDerivedFn) {
if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) {
// Recursive call.
auto* selfRef =
Expr* selfRef =
m_Sema
.BuildDeclarationNameExpr(
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();

OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, noLoc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
noLoc)
.get();
OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
noLoc, pullbackCallArgs, noLoc)
.get();
} else {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
Expand Down
46 changes: 46 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,50 @@ double fn15(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double recFun (double x, double y) {
if (x > y)
return recFun(x-1, y);
return x * y;
}

//CHECK: void recFun_pullback(double x, double y, double _d_y0, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_y0 * y;
//CHECK-NEXT: * _d_y += x * _d_y0;
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: recFun_pullback(x - 1, y, _d_y0, &_r0, &_r1);
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: * _d_y += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn16(double x, double y) {
return recFun(x, y);
}

//CHECK: void fn16_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: recFun_pullback(x, y, 1, &_r0, &_r1);
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: * _d_y += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -799,4 +843,6 @@ int main() {
TEST2(fn14, 3, 5); // CHECK-EXEC: {1.00, 1.00}
INIT(fn15);
TEST2(fn15, 6, -2) // CHECK-EXEC: {1.00, 1.00}
INIT(fn16);
TEST2(fn16, 12, 8) // CHECK-EXEC: {8.00, 8.00}
}

0 comments on commit 51afd38

Please sign in to comment.