-
-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make the rrule's outer product lazy #484
Conversation
After this PR, the gradient wrt the matrix requires only using LinearAlgebra, LinearSolve, Zygote
n = 100; A = rand(n, n); b1 = rand(n); b2 = rand(n);
function invquad(a, A, b)
prob = LinearProblem(A, b)
sol = solve(
prob,
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.RFLUFactorization),
)
return dot(a, sol.u)
end
db1, dA, db2 = Zygote.gradient(invquad, b1, A, b2);
Base.summarysize(dA)
# 1752
Base.summarysize(A)
# 80040 |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #484 +/- ##
===========================================
- Coverage 64.22% 25.10% -39.13%
===========================================
Files 28 28
Lines 2200 2167 -33
===========================================
- Hits 1413 544 -869
- Misses 787 1623 +836 ☔ View full report in Codecov by Sentry. |
Tests seem to pass. I added version 3 to the docs toml file to hopefully fix the docs build. |
The test failure is new and not related to this PR. I only added a version number in the last commit but the test failure is a method ambiguity failure. If you re-run the tests on master, you will probably get the same failure. Could be a dependency that upgraded and broke things between the 2 commits. |
It seems like the tolerance is just set too tight in that test and multithreading in BLAS change it at that level. |
new release? |
I was going to handle some downgrade and test tolerance stuff #485 and release in a little bit. |
Wait, why was this a major? |
To be safe, I bumped the major version because the output type of Zygote.gradient wrt the matrix is changed in this PR. |
I don't think we guaranteed the type on the pullback anywhere, just that it has the right actions in the derivative, and the Zygote overload was just added a release ago, so together I don't think this constitute a major bump but instead a minor. |
Feel free to revert it. I was being safe. I wouldn't want the type to change on me if I am a user. |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
This PR implements the suggestion in https://discourse.julialang.org/t/how-do-you-speed-up-the-linear-sparse-solver-in-zygote/111801/41?u=mohamed82008. To be safe, I also bumped the major version because the output type of
Zygote.gradient
wrt the matrix is changed in this PR. No new documentation was added as this PR is just a performance improvement.