-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify pullback calls in the reverse mode #802
Simplify pullback calls in the reverse mode #802
Conversation
3c552ec
to
b573371
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
b573371
to
3aa40e5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
The work in this pull request seems good. Thank you for proactively trying to improve one of the most complicated components in the codebase. However, I think we may need to look at the issue and the solution more closely. Please correct me if I am wrong anywhere. My autodiff skills are a bit rusty now. Given a function: double do_something(double p_u, double p_v) { ... } Considering the primal code is double _r_d0 = _d_res;
double _grad0 = 0;
double _grad1 = 0;
do_something_pullback(u, v, _r_d0, _grad0, _grad1);
double _r0 = _grad0;
double _r1 = _grad1;
*_d_u += _r0;
*_d_v += _r1 With your patch, the same call will be differentiated as: double _r_d0 = _d_res;
do_something_pullback(u, v, _r_d0, &*_d_u, &*_d_v); The above two differentiated codes look similar and it appears that they should bring the same result mathematically. However, that's not the case. This is because of two reasons:
Let's look at what happens in a function call more closely. Primal code r = do_something(u, v); // u and v are passed by value This call can be made more clear by expanding it as follows: double p_u = u;
double p_v = v;
do_something(p_u, p_v); // p_u, and p_v are passed by reference In the original case, A practical example where this distinction matters: #include "clad/Differentiator/Differentiator.h"
#include <iostream>
#define show(x) std::cout << #x << ": " << x <<"\n";
double reset(double u) {
u = 0;
return u;
}
double fn(double u, double v) {
double res = u + v;
res += reset(u);
res += u;
return res;
}
int main() {
auto fn_grad = clad::gradient(fn);
double u = 3, v = 5, du = 0, dv = 0;
fn_grad.execute(u, v, &du, &dv);
show(du);
show(dv);
} With this patch, the above code (incorrectly) outputs:
The correct output is:
Please let me know your thoughts on this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the above comment for details.
@parth-07 |
cb3f1a4
to
e779baf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
That's great! It's important to simplify the call differentiation component. It has grown really complex over the years. |
836aa4f
to
b813c44
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #802 +/- ##
==========================================
+ Coverage 94.86% 94.95% +0.08%
==========================================
Files 49 49
Lines 7357 7468 +111
==========================================
+ Hits 6979 7091 +112
+ Misses 378 377 -1
... and 2 files with indirect coverage changes
|
20c252f
to
19e5e9e
Compare
clang-tidy review says "All clean, LGTM! 👍" |
0044e43
to
067ab71
Compare
clang-tidy review says "All clean, LGTM! 👍" |
1 similar comment
clang-tidy review says "All clean, LGTM! 👍" |
70a186d
to
51afd38
Compare
clang-tidy review says "All clean, LGTM! 👍" |
clang-tidy review says "All clean, LGTM! 👍" |
Can you please add more details in the pull-request description / commit-message regarding how the pullback calls are being simplified? |
f8e50b7
to
80034ea
Compare
clang-tidy review says "All clean, LGTM! 👍" |
…riables and not storing pointer args
80034ea
to
453043e
Compare
clang-tidy review says "All clean, LGTM! 👍" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good apart from the few minor comments.
// The argument is passed by reference if it's passed as an L-value. | ||
// However, if arg is a MaterializeTemporaryExpr, then arg is a | ||
// temporary variable passed as a const reference. | ||
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why is lvalue a reference type?
int a = b; // a is an l-value, but not a reference.
int &a_ref = a; // a_ref is an l-value and a reference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
arg
is supposed to be the argument expression passed to the function. If the function expects a ref-type argument, then arg
is an l-value (usually a DeclRefExpr
). But when it expects a non-ref type argument, it is implicitly converted to an r-value. The AST of arg
will look somewhat like this:
ImplicitCastExpr <l-value to r-value>
-DeclRefExpr
So arg will be an r-value. At least this is my understanding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-tidy made some suggestions
49a1593
to
b340d05
Compare
clang-tidy review says "All clean, LGTM! 👍" |
This PR simplifies both the code of
RMV::VisitCallExpr
and the code generated by it. In particular, it replaces the_grad
/_r
variable pairs with single_r
variables. The PR also removes some dead code fromRMV::VisitCallExpr
, addresses clang-tidy warnings triggered by it, and improves its test coverage.Fixes #801.