Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 26, 2024
1 parent 6c6b123 commit 9e343b5
Show file tree
Hide file tree
Showing 30 changed files with 518 additions and 620 deletions.
35 changes: 26 additions & 9 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,14 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest);

if (pushforwardFD) {
llvm::errs() << "check0\n";
if (baseDiff.getExpr()) {
llvm::errs() << "check1\n";
FD->dump();
baseDiff.getExpr()->dump();
pushforwardFD->dump();
for (auto& arg : pushforwardFnArgs)
arg->dump();
callDiff =
BuildCallExprToMemFn(baseDiff.getExpr(), pushforwardFD->getName(),
pushforwardFnArgs, CE->getBeginLoc());
Expand All @@ -1167,7 +1174,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
}

if (!callDiff)
llvm::errs() << "NULL\n";
llvm::errs() << "\n\n\n";
// If clad failed to derive it, try finding its derivative using
// numerical diff.
if (!callDiff) {
Expand Down Expand Up @@ -1293,9 +1302,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
// If the LHS has a non-differentiable type, Ldiff.getExpr_dx() will be 0.
// Don't create a warning then.
if (IsDifferentiableType(BinOp->getLHS()->getType()))
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
opDiff = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
} else if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expand Down Expand Up @@ -1372,10 +1384,13 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VarDecl* VDDerived = nullptr;
if (IsDifferentiableType(VD->getType())) {
VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1437,7 +1452,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx())
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down Expand Up @@ -1576,7 +1592,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
// ...
// ...
// }
if (condVarClone) {
if (condVarRes.getDecl_dx()) {
bodyResult = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), cast<CompoundStmt>(bodyResult),
BuildDeclStmt(condVarRes.getDecl_dx()));
Expand Down Expand Up @@ -1655,7 +1671,8 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
if (condVarDecl) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
if (condVarDeclDiff.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
}

StmtDiff initVarRes = (SS->getInit() ? Visit(SS->getInit()) : StmtDiff());
Expand Down
9 changes: 9 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,15 @@ namespace clad {
DeclStmts.push_back(BuildDeclStmt(argDecl));
}
}
else
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
Expr* zero = getZeroInit(argTy);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i, zero);
}
}
auto MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);

if (noOverloadExists(UnresolvedLookup, MARargs))
Expand Down
30 changes: 9 additions & 21 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ double addArr(const double *arr, int n) {
}

//CHECK: double addArr_darg0_1(const double *arr, int n) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_ret;
//CHECK-NEXT: }
Expand All @@ -59,25 +55,17 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
// CHECK-NEXT: bool flag = false;
// CHECK-NEXT: size_t _d_idx = 0;
// CHECK-NEXT: size_t idx = 0;
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_i = 0;
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: _d_flag = 0;
// CHECK-NEXT: flag = true;
// CHECK-NEXT: _d_idx = _d_i;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: flag = true;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return flag ? _d_idx * x + idx * _d_x : 0;
// CHECK-NEXT: return flag ? 0 * x + idx * _d_x : 0;
// CHECK-NEXT: }

int main() {
Expand Down
27 changes: 9 additions & 18 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ double sum(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand All @@ -55,21 +52,15 @@ double sum_squares(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_squares[3];
//CHECK-NEXT: double squares[3];
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand Down
122 changes: 61 additions & 61 deletions test/FirstDerivative/BasicArithmeticAddSub.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,90 +7,90 @@

extern "C" int printf(const char* fmt, ...);

int a_1(int x) {
int y = 4;
float a_1(float x) {
float y = 4;
return y + y; // == 0
}
// CHECK: int a_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float a_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_y + _d_y;
// CHECK-NEXT: }

int a_2(int x) {
float a_2(float x) {
return 1 + 1; // == 0
}
// CHECK: int a_2_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float a_2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return 0 + 0;
// CHECK-NEXT: }

int a_3(int x) {
float a_3(float x) {
return x + x; // == 2
}
// CHECK: int a_3_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float a_3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return _d_x + _d_x;
// CHECK-NEXT: }

int a_4(int x) {
int y = 4;
float a_4(float x) {
float y = 4;
return x + y + x + 3 + x; // == 3x
}
// CHECK: int a_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float a_4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x + _d_y + _d_x + 0 + _d_x;
// CHECK-NEXT: }

int s_1(int x) {
int y = 4;
float s_1(float x) {
float y = 4;
return y - y; // == 0
}
// CHECK: int s_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float s_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_y - _d_y;
// CHECK-NEXT: }

int s_2(int x) {
float s_2(float x) {
return 1 - 1; // == 0
}
// CHECK: int s_2_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float s_2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return 0 - 0;
// CHECK-NEXT: }

int s_3(int x) {
float s_3(float x) {
return x - x; // == 0
}
// CHECK: int s_3_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK: float s_3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: return _d_x - _d_x;
// CHECK-NEXT: }

int s_4(int x) {
int y = 4;
float s_4(float x) {
float y = 4;
return x - y - x - 3 - x; // == -1
}
// CHECK: int s_4_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float s_4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x - _d_y - _d_x - 0 - _d_x;
// CHECK-NEXT: }

int as_1(int x) {
int y = 4;
float as_1(float x) {
float y = 4;
return x + x - x + y - y + 3 - 3; // == 1
}
// CHECK: int as_1_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: int _d_y = 0;
// CHECK-NEXT: int y = 4;
// CHECK: float as_1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: float y = 4;
// CHECK-NEXT: return _d_x + _d_x - _d_x + _d_y - _d_y + 0 - 0;
// CHECK-NEXT: }

Expand All @@ -103,45 +103,45 @@ float IntegerLiteralToFloatLiteral(float x, float y) {
// CHECK-NEXT: return _d_x * x + x * _d_x - _d_y;
// CHECK-NEXT: }

int a_1_darg0(int x);
int a_2_darg0(int x);
int a_3_darg0(int x);
int a_4_darg0(int x);
int s_1_darg0(int x);
int s_2_darg0(int x);
int s_3_darg0(int x);
int s_4_darg0(int x);
int as_1_darg0(int x);
float a_1_darg0(float x);
float a_2_darg0(float x);
float a_3_darg0(float x);
float a_4_darg0(float x);
float s_1_darg0(float x);
float s_2_darg0(float x);
float s_3_darg0(float x);
float s_4_darg0(float x);
float as_1_darg0(float x);
float IntegerLiteralToFloatLiteral_darg0(float x, float y);

int main () { // expected-no-diagnostics
int x = 4;
float x = 4;
clad::differentiate(a_1, 0);
printf("Result is = %d\n", a_1_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", a_1_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(a_2, 0);
printf("Result is = %d\n", a_2_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", a_2_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(a_3, 0);
printf("Result is = %d\n", a_3_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", a_3_darg0(1)); // CHECK-EXEC: Result is = 2

clad::differentiate(a_4, 0);
printf("Result is = %d\n", a_4_darg0(1)); // CHECK-EXEC: Result is = 3
printf("Result is = %f\n", a_4_darg0(1)); // CHECK-EXEC: Result is = 3

clad::differentiate(s_1, 0);
printf("Result is = %d\n", s_1_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_1_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_2, 0);
printf("Result is = %d\n", s_2_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_2_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_3, 0);
printf("Result is = %d\n", s_3_darg0(1)); // CHECK-EXEC: Result is = 0
printf("Result is = %f\n", s_3_darg0(1)); // CHECK-EXEC: Result is = 0

clad::differentiate(s_4, 0);
printf("Result is = %d\n", s_4_darg0(1)); // CHECK-EXEC: Result is = -1
printf("Result is = %f\n", s_4_darg0(1)); // CHECK-EXEC: Result is = -1

clad::differentiate(as_1, 0);
printf("Result is = %d\n", as_1_darg0(1)); // CHECK-EXEC: Result is = 1
printf("Result is = %f\n", as_1_darg0(1)); // CHECK-EXEC: Result is = 1

clad::differentiate(IntegerLiteralToFloatLiteral, 0);
printf("Result is = %f\n", IntegerLiteralToFloatLiteral_darg0(5., 0.)); // CHECK-EXEC: Result is = 10
Expand Down
Loading

0 comments on commit 9e343b5

Please sign in to comment.