Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SparseDiffTools v2 for steadystateadjoint #808

Merged
merged 8 commits into from
Aug 22, 2023

Conversation

vpuri3
Copy link
Member

@vpuri3 vpuri3 commented Mar 22, 2023

@vpuri3
Copy link
Member Author

vpuri3 commented Mar 23, 2023

TODO- update OrdinaryDiffEq to use sparsedifftools v2

@@ -98,8 +100,7 @@ end
end

if !needs_jac
# TODO: FixedVecJacOperator should respect the `autojacvec` of the algorithm
operator = FixedVecJacOperator(f, y, p, Val(DiffEqBase.isinplace(sol.prob)))
operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not really equivalent. IIRC VecJac recomputes the pullback everytime a call to mul! is made. In this case, we have a fixed input, only the seeding changes so we compute the pullback once and just reevaluate it multiple times.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually fixed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vpuri3 vpuri3 closed this May 27, 2023
@vpuri3 vpuri3 reopened this May 27, 2023
@vpuri3 vpuri3 closed this May 30, 2023
@vpuri3 vpuri3 reopened this May 30, 2023
@vpuri3
Copy link
Member Author

vpuri3 commented Jun 2, 2023

retriggering CI

@vpuri3 vpuri3 closed this Jun 2, 2023
@vpuri3 vpuri3 reopened this Jun 2, 2023
@vpuri3
Copy link
Member Author

vpuri3 commented Jun 2, 2023

We should add a test case to ensure that Zygote is able to propagate gradients past solve(::LinearProblem)

@avik-pal
Copy link
Member

avik-pal commented Jun 2, 2023

We should add a test case to ensure that Zygote is able to propagate gradients past solve(::LinearProblem)

Does this work yet? I was using LinearSolve.jl, and I think the adjoints are not implemented yet. I need it (quite desperately), so I will probably implement something by next week.

using Zygote, ForwardDiff, SciMLSensitivity, SciMLBase, LinearSolve, ComponentArrays,
      FiniteDiff

function loss_function(θ)
    (; A, b) = θ
    prob = LinearProblem(A, b)
    sol = solve(prob, nothing)
    return sum(sol.u)
end

function loss_function_chainrules(θ)
    (; A, b) = θ
    x = A \ b
    return sum(x)
end

A = Float32[1 0; 1 -2]; b = Float32[32; -4];
θ = ComponentArray(; A, b)

loss_function(θ)  loss_function_chainrules(θ)  # true

Zygote.gradient(loss_function, θ)  # fails
Zygote.gradient(loss_function_chainrules, θ)  # works

ForwardDiff.gradient(loss_function, θ)  # fails
ForwardDiff.gradient(loss_function_chainrules, θ)  # works

FiniteDiff.finite_difference_gradient(loss_function, θ)  # works

@vpuri3
Copy link
Member Author

vpuri3 commented Jun 2, 2023

with this branch, your example is erroring on some try/catch block which is incompatible with zygote.

ERROR: LoadError: Compiling Tuple{LinearSolve.var"##solve#32", Base.Pairs{Symbol, Union
{}, Tuple{}, NamedTuple{(), Tuple{}}}, typeof(solve), LinearSolve.LinearCache{Base.Resh
apedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, tr
ue}, Tuple{}}, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Ve
ctor{Float32}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}
, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{Fl
oat32, Float32, Vector{Float32}}, SciMLOperators.IdentityOperator, SciMLOperators.Ident
ityOperator, Float32, true, LinearSolve.OperatorCondition.IllConditioned}, KrylovJL{typ
eof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(),
 Tuple{}}}}}: try/catch is not supported.                                              
Refer to the Zygote documentation for fixes.                                           
https://fluxml.ai/Zygote.jl/latest/limitations                                         ```

@avik-pal
Copy link
Member

avik-pal commented Jun 3, 2023

Yes, ik I gave that as a testcase that doesn't work. The linear problem doesn't have tests because the adjoints aren't implemented.

Also, taking a look at how linearsolve works, it might a bit refactoring before we can include adjoints since it dispatches on solve instead of __solve

@ChrisRackauckas
Copy link
Member

We can make a higher level in SciMLBase and dispatch it on __solve.

@avik-pal
Copy link
Member

We can get this PR merged, right? The linear solve issue is entirely tangential and needs to be handled downstream first.

@ChrisRackauckas
Copy link
Member

someone needs to resolve the merge.

@ChrisRackauckas ChrisRackauckas merged commit 964a600 into SciML:master Aug 22, 2023
21 of 22 checks passed
@ChrisRackauckas
Copy link
Member

No Enzyme or ReverseDiff though?

@vpuri3 vpuri3 deleted the scimlops branch August 23, 2023 05:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants