Supporting Linear Algebraic Primitives #10
Just to be clear, let's address your
I would never argue that In order to make such a declaration in DiffRules, however, we still need to add API support for linear algebraic primitives (as I mentioned earlier).
Great, it looks like we're on pretty much the same page then. Possibly our only difference of opinion is that I now can't see why you would ever want a library of "eager" kernels, as opposed to one that provides the code you need to automatically compile your own in a downstream package. I can't think of a situation in which exposing objects that contain
isn't strictly better than providing the first three things plus the corresponding implemented method. A downstream package can clearly reconstruct the method given the code (it doesn't really matter how long the code for any particular sensitivity is), and as you pointed out it may be possible to perform optimisations given a symbolic representation of the sensitivity that you can't when you only have access to a method. For example, it might be useful to perform a CSE optimisation when the sensitivities w.r.t. multiple arguments are required: if you compile a custom sensitivity on the fly using symbolic representations of the sensitivity w.r.t. each argument, then you can do such optimisations.

What are your thoughts on this? I may be missing something obvious. (On a related note, it might be an idea to replace the things in the last two bullet points with a function which accepts the argument names the downstream package wants to use, and returns code using those argument names.)
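To make the CSE point concrete, here is a hedged sketch of what a downstream package could do with symbolic sensitivities, using DiffRules' existing `diffrule` lookup. The exact returned expressions are illustrative, and the manual hoisting below stands in for what an automatic CSE pass would detect:

```julia
using DiffRules

# Symbolic derivatives of atan(y, x) w.r.t. both arguments; for a
# two-argument function, diffrule returns a tuple of expressions,
# roughly (:(x / (x^2 + y^2)), :(-y / (x^2 + y^2))).
dy, dx = DiffRules.diffrule(:Base, :atan, :y, :x)

# With access to the expressions (rather than compiled methods), a
# downstream package can hoist the shared subexpression before
# compiling a combined gradient kernel:
combined = quote
    r = x^2 + y^2          # computed once, shared by both sensitivities
    (x / r, -y / r)
end
```

The same transformation is impossible if all you are handed is two opaque compiled methods, which is the core of the argument above.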
We're in full agreement here.
Definitely. Actually, this made me think of a nice API change for helping with manually-optimized "multiple sensitivities" cases (e.g. where CSE etc. can't/doesn't suffice). Currently, DiffRules requires that the rule author provide a standalone derivative expression for each argument. Instead, we could require that rule authors mark differentiated variables explicitly, for example:

```julia
@define_diffrule M.f(wrt(x), wrt(y)) = expr_for_dfdx_and_dfdy($x, $y)
@define_diffrule M.f(wrt(x), y) = expr_for_dfdx($x, $y)
@define_diffrule M.f(x, wrt(y)) = expr_for_dfdy($x, $y)
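As a concrete (purely hypothetical, since the `wrt` syntax is only proposed here) instance of this marking, a rule for a function whose sensitivities share work could return a single expression computing both:

```julia
# Hypothetical use of the proposed wrt(...) syntax: when both arguments
# are marked, the author returns one expression that computes both
# sensitivities, sharing the expensive intermediate h.
@define_diffrule M.hypot(wrt(x), wrt(y)) = quote
    h = hypot($x, $y)   # shared between d/dx and d/dy
    ($x / h, $y / h)
end
```

This is exactly the manually-optimized "multiple sensitivities" case described above, where CSE alone would not recover the sharing through the `hypot` call.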
Well, it's up to the rule author to decide the level of granularity of the function calls present in the derivative expression. On one extreme end of the spectrum, the rule author can inline as much as possible (i.e. compose the derivative expression using only
Yup, that's the way DiffRules currently works (you can interpolate
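For reference, this is how interpolation works in DiffRules today: a rule is an expression template with interpolated argument names, and `diffrule` substitutes whatever names the caller supplies. The `MyMod`/`sigmoid` names below are illustrative, not part of any package:

```julia
using DiffRules

# Hypothetical scalar primitive; $x marks where the caller's chosen
# argument name gets interpolated into the derivative expression.
DiffRules.@define_diffrule MyMod.sigmoid(x) = :(sigmoid($x) * (1 - sigmoid($x)))

# A downstream package retrieves the expression with its own variable name:
DiffRules.diffrule(:MyMod, :sigmoid, :u)   # e.g. :(sigmoid(u) * (1 - sigmoid(u)))
```

This addresses the parenthetical suggestion above: the lookup already acts as a function from caller-chosen argument names to code using those names.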
I agree with all of the above, and I like your `wrt` proposal. Running with this, a reverse-mode rule could be something like:

```julia
@define_reverse z::Tz z̄::Tz̄ M.f(wrt(x::Tx), y::Ty) = expr_dOdx($z, $z̄, $x, $y)
@define_reverse x̄::Tx̄ z::Tz z̄::Tz̄ M.f(wrt(x::Tx), y::Ty) = expr_dOdx!($x̄, $z, $z̄, $x, $y)
```

where I've just given the macro a different name and added a couple of extra terms at the front end to pass in the output `z` and its adjoint `z̄`; the second rule is for in-place updates, for when a preallocated `x̄` is available. A forward-mode rule would look similar:

```julia
@define_forward ẋ::Tẋ ẏ::Tẏ M.f(x::Tx, y::Ty) = expr_dfdI($x, $y, $ẋ, $ẏ)
```

Does this sound reasonable? The above doesn't directly address more complicated method definitions (e.g. involving diagonal dispatch), but I can't see any reason in principle that it couldn't be extended to handle that kind of thing. Also, I'm not sure about the ordering of the arguments for the in-place variant.
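For what it's worth, here is a hedged sketch of the kind of in-place kernel the second rule form could map to, for the matrix product z = x * y. The name `dOdx!` follows the placeholder above and is not an existing DiffLinearAlgebra API:

```julia
using LinearAlgebra

# Hypothetical in-place reverse-mode kernel for z = x * y:
# accumulates dL/dx = z̄ * y' into the preallocated x̄.
function dOdx!(x̄, z, z̄, x, y)
    # Five-argument mul! computes x̄ = z̄*y'*α + x̄*β; with α = β = true
    # this accumulates into x̄ without allocating a temporary.
    mul!(x̄, z̄, y', true, true)
    return x̄
end
```

A kernel in this shape illustrates the argument-ordering question raised above: whether the destination `x̄` should come first (BLAS-style output-first) or trail the primal inputs.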
(continued from invenia/Nabla.jl#81, cc @willtebbutt)
Agreed, DiffRules only properly handles scalar kernels at the moment. To support linear algebra, we need to add to DiffRules a notion of tensor vs. scalar arguments, support for in-place methods, a way to mark adjoint variables, etc.
I think there might've been a misunderstanding with my previous post 😛 I definitely am not arguing that we should express e.g. complex LAPACK kernels symbolically, and I didn't mean to imply that DiffRules/DiffLinearAlgebra were directly competing approaches. On the contrary, I think they're quite complementary; if DiffLinearAlgebra didn't exist, I eventually would need to make a "DiffKernels.jl" anyway.

DiffRules is useful for mapping primal functions to derivative functions, and is thus useful when generating e.g. instruction primitives/computation graphs within downstream tools (i.e. it solves the problem "what kernels should I call and how should I call them?"). DiffLinearAlgebra (as it stands) is useful for providing implementations of those kernels (i.e. it solves the problem "how do I execute the kernels that I'm calling?"). They're both necessary components of the AD ecosystem.
As for deciding what computations should be primitives, I think we're already on the same page; a computation should be defined as a primitive if either/both of the following applies: