nested gradient fails with "Mutating arrays is not supported" #1244
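I believe the problem is that the adjoint definition for
using Zygote: @adjoint

# Primal definition: @adjoint only attaches a pullback, so the function itself must exist.
function values2array(values, indexes, shape)
    result = fill(0.0, shape...)
    result[indexes] = values
    return result
end

# The mutation stays inside the primal; the pullback only reads from Δ.
@adjoint function values2array(values, indexes, shape)
    result = fill(0.0, shape...)
    result[indexes] = values
    return result, function (Δ)
        return (Δ[indexes], nothing, nothing)
    end
end

@adjoint function maximum(xs::AbstractArray; dims = :)
    max, i = findmax(xs, dims = dims)
    max, function (Δ)
        Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing
        # Original, mutating version:
        # Δ′ = zero(xs)
        # Δ′[i] = Δ
        Δ′ = values2array(Δ, i, size(xs))
        return (Δ′,)
    end
end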
Can you try deleting the existing
I'll give it a try and let you know. (In general, I'm noticing a lot of problems with gradients of functions of gradients, and I'm having a hard time debugging them because it's mostly macro expansions and _pullback calls, with no indication of the actual source of the problem.)
Yeah, nested AD is a big pain point with Zygote. That's why you see so many recommendations to use mixed forward + reverse mode where possible (it's also more algorithmically efficient). On debugging tips, Zygote stacktraces are a bit of a beast, but there are a couple of things you can look for:
julia> gradient(x -> sum(union!(x, x)), ones(1))
ERROR: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#437#438"{Vector{Float64}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:71
[3] (::Zygote.var"#2337#back#439"{Zygote.var"#437#438"{Vector{Float64}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./array.jl:2528 [inlined]
[5] (::typeof(∂(filter!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[6] Pullback
@ ./array.jl:2605 [inlined]
[7] (::typeof(∂(_grow!)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[8] Pullback
@ ./array.jl:2612 [inlined]
[9] (::typeof(∂(union!)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[4]:1 [inlined]
[11] (::typeof(∂(#5)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[12] (::Zygote.var"#52#53"{typeof(∂(#5))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76
[14] top-level scope
@ REPL[4]:1
julia> f(x) = error(x)
f (generic function with 1 method)
julia> gradient(f, 1)
ERROR: 1
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context, f::typeof(throw), args::ErrorException)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:9
[3] _pullback
@ ./error.jl:42 [inlined]
[4] _pullback(ctx::Zygote.Context, f::typeof(error), args::Int64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[5] _pullback
@ ./REPL[6]:1 [inlined]
[6] _pullback(ctx::Zygote.Context, f::typeof(f), args::Int64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[7] _pullback(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
[8] pullback(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
[9] gradient(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
[10] top-level scope
@ REPL[7]:1
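The mixed forward + reverse approach recommended above can be sketched as follows. This is a minimal sketch, not a fix for this issue: the function f and the input x0 are made up for illustration, and it assumes both Zygote and ForwardDiff are installed. Zygote.hessian is documented as forward-over-reverse (ForwardDiff over Zygote); the ForwardDiff.jacobian call spells the same idea out by hand.

```julia
using Zygote, ForwardDiff

# Hypothetical scalar-valued function, used only for illustration.
f(x) = sum(x .^ 3)

x0 = [1.0, 2.0]

# Reverse mode (Zygote) computes the inner gradient.
grad_f(x) = Zygote.gradient(f, x)[1]

# Forward mode (ForwardDiff) computes the outer derivative: the Jacobian of
# the gradient is the Hessian. Zygote never has to differentiate its own pullbacks.
H_manual = ForwardDiff.jacobian(grad_f, x0)

# Zygote's built-in helper performs the same forward-over-reverse combination.
H_builtin = Zygote.hessian(f, x0)

# Analytically, the Hessian of sum(x .^ 3) is the diagonal matrix with entries 6 .* x.
```

For a scalar input, ForwardDiff.derivative can play the same role as ForwardDiff.jacobian does here.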
Thanks. In ZygoteRules.gradm(...), couldn't you add more info to the generated code to make tracebacks more informative? It would really be useful to be able to see in the stack trace what function the rrule was invoked for. For getting this information into the stack trace, you might be able to...
Also, you have a test suite for differentiation; maybe test cases for gradients-of-functions-of-gradients could just be included in that test suite?
Thanks for the suggestions. My initial thoughts below:
We already have tests like this in various places, e.g. Lines 5 to 20 in dad65a8
FWIW the end goal is to move everything to use
For So simply deleting Zygote's rules should solve things: Lines 322 to 339 in 3239330
(See array.jl:327, mentioned in the stack trace.)
Thanks for the responses and help. I'm trying to convert some JAX code; gradients-of-gradients work quite reliably in JAX, probably because it uses XLA underneath. It seems like there is initial support for XLA in Julia and Flux; do you think that would help with gradient-of-gradient? What is the performance? Another option seems to be Enzyme; has Enzyme been integrated with Flux yet? Would that help?
There were experiments with hooking things up to XLA, but I don't think anyone is working on that. I don't think they passed the handling of AD to XLA, just the execution. The immediate plan was to move from Zygote to Diffractor, which uses the same I am not sure what the status of Enzyme with things like CuBLAS is. This is a much larger deviation from how Flux currently works.
How do I perform mixed forward + reverse mode? Could you please demonstrate it on the example above?
Source code:
Output: