nested gradient fails with "Mutating arrays is not supported" #1244

Closed
tmbdev opened this issue Jun 16, 2022 · 10 comments · Fixed by #1250
Labels: ChainRules (adjoint -> rrule, and further integration); second order (zygote over zygote, or otherwise)

Comments

tmbdev commented Jun 16, 2022

Source code:

using Zygote

println("start")

function f(w, x)
    return sum(maximum((w * x).^2, dims=1))
end

function g(w, x)
    d = gradient(f, w, x)[2]
    return sum(d.^2)
end

function h(w, x)
    t = gradient(g, w, x)[1]
    return t
end

w = rand(10, 100)
x = rand(100, 7)

x = f(w, x)   # note: this rebinds x to a scalar, so g and h below receive x::Float64
println("x=", x)
y = g(w, x)
println("y=", y)
z = h(w, x)
println("z=", z)

print("done")

Output:

$ julia ztest.jl
start
x=4589.3400747042815
y=5.4911019133806146e11
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#437#438"{Matrix{Float64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:71
  [3] (::Zygote.var"#2337#back#439"{Zygote.var"#437#438"{Matrix{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:327 [inlined]
  [5] (::typeof(∂(λ)))(Δ::Tuple{Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:73 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/exp/jadv/ztest.jl:6 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76 [inlined]
 [13] (::typeof(∂(gradient)))(Δ::Tuple{Nothing, Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/exp/jadv/ztest.jl:10 [inlined]
 [15] (::typeof(∂(g)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#52#53"{typeof(∂(g))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [17] gradient(::Function, ::Matrix{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76
 [18] h(w::Matrix{Float64}, x::Float64)
    @ Main ~/exp/jadv/ztest.jl:15
 [19] top-level scope
    @ ~/exp/jadv/ztest.jl:26
in expression starting at /home/tmb/exp/jadv/ztest.jl:26

tmbdev (Author) commented Jun 17, 2022

I believe the problem is that the adjoint definition for maximum is not itself differentiable, because its pullback mutates an array (Δ′[i] = Δ). One fix is to move that mutation into a separate function with its own adjoint in src/lib/array.jl:

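# Scatter `values` into a zero array of size `shape` at positions `indexes`.
# A plain method is needed so the function also works when it is not being differentiated.
function values2array(values, indexes, shape)
    result = zeros(eltype(values), shape)
    result[indexes] = values
    return result
end

# Its adjoint is a gather: read Δ back at the same positions.
@adjoint function values2array(values, indexes, shape)
    return values2array(values, indexes, shape), function(Δ)
        return (Δ[indexes], nothing, nothing)
    end
end

@adjoint function maximum(xs::AbstractArray; dims = :)
  max, i = findmax(xs, dims = dims)
  max, function (Δ)
    Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing
    # Δ′ = zero(xs)  # original version: Zygote cannot differentiate
    # Δ′[i] = Δ      # this in-place write in nested AD
    Δ′ = values2array(Δ, i, size(xs))
    return (Δ′,)
  end
end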
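The proposed pullback `Δ -> Δ[indexes]` is correct because reading values back at `indexes` (a gather) is the adjoint of writing them into a zero array at `indexes` (a scatter). A quick numerical check of that identity, dot(scatter(v), Δ) == dot(v, gather(Δ)), using only the standard library; `scatter` and `gather` are hypothetical names standing in for `values2array` and its pullback:

```julia
using LinearAlgebra: dot

# Hypothetical stand-in for values2array: write `values` into a zero array
# of size `shape` at the positions `indexes`.
function scatter(values, indexes, shape)
    result = zeros(eltype(values), shape)
    result[indexes] = values
    return result
end

# Hypothetical stand-in for the proposed pullback: read Δ back at `indexes`.
gather(Δ, indexes) = Δ[indexes]

xs = rand(4, 3)
_, idx = findmax(xs; dims = 1)   # 1×3 array of CartesianIndex, as in maximum's rule
v = rand(1, 3)                   # plays the role of the incoming gradient of maximum(xs; dims=1)
Δ = rand(4, 3)                   # an arbitrary cotangent for the scattered array

# Adjoint identity: <scatter(v), Δ> == <v, gather(Δ)>
lhs = dot(scatter(v, idx, size(xs)), Δ)
rhs = dot(v, gather(Δ, idx))
println(lhs ≈ rhs)               # prints true
```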

ToucheSir (Member) commented

Can you try deleting the existing @adjoint rule and seeing if the rule at https://github.com/JuliaDiff/ChainRules.jl/blob/a0d86fea0c27d9d9ff2b1872b3f7601bf20b4999/src/rulesets/Base/array.jl#L581-L585 works? Ideally we'd want to improve the rules upstream wherever possible and remove the @adjoint afterwards.

tmbdev (Author) commented Jun 17, 2022

I'll give it a try and let you know.

(In general, I'm noticing a lot of problems with gradients of functions of gradients, and I'm having a hard time debugging them because the stack traces are mostly macro expansions and _pullback calls, with no indication of the actual source of the problem.)

ToucheSir (Member) commented Jun 17, 2022

Yeah, nested AD is a big pain point with Zygote. That's why you see so many recommendations to use mixed forward + reverse mode where possible (it's also more algorithmically efficient).

As for debugging tips: Zygote stacktraces are a bit of a beast, but there are a couple of things you can look for:

julia> gradient(x -> sum(union!(x, x)), ones(1))
ERROR: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#437#438"{Vector{Float64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:71
  [3] (::Zygote.var"#2337#back#439"{Zygote.var"#437#438"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./array.jl:2528 [inlined]
  [5] (::typeof(∂(filter!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./array.jl:2605 [inlined]
  [7] (::typeof(∂(_grow!)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./array.jl:2612 [inlined]
  [9] (::typeof(∂(union!)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[4]:1 [inlined]
 [11] (::typeof(∂(#5)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#52#53"{typeof(∂(#5))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76
 [14] top-level scope
    @ REPL[4]:1

julia> f(x) = error(x)
f (generic function with 1 method)

julia> gradient(f, 1)
ERROR: 1
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::typeof(throw), args::ErrorException)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:9
  [3] _pullback
    @ ./error.jl:42 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::typeof(error), args::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [5] _pullback
    @ ./REPL[6]:1 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::typeof(f), args::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [7] _pullback(f::Function, args::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
  [8] pullback(f::Function, args::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
  [9] gradient(f::Function, args::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [10] top-level scope
    @ REPL[7]:1
  1. Source locations like @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:71. This is a tip-off that some @adjoint rule is being run. If you go to that location, you'll be able to see the rule definition.
  2. rrule(...). These are ChainRules rules and are generally pretty easy to spot. You can also find their source locations relatively easily.
  3. typeof(∂(somefunction)). What matters most here is that an auto-generated pullback for somefunction is running. That means no rule was found for it, so if an error happens further up the stack, you may want to look at the original source for anything that might make Zygote unhappy (and ignore or add rules for intermediate functions if need be).
  4. Frames like [6] _pullback(ctx::Zygote.Context, f::typeof(f), args::Int64). These represent the augmented forward pass of your code. You can mentally rewrite this as f(::Int64), which should help with finding which method of f is being called.
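Following the last tip, a frame such as `_pullback(ctx::Zygote.Context, f::typeof(f), args::Int64)` corresponds to the call `f(::Int64)`, and Julia can report which method that call dispatches to without running it. A small sketch using `@which` from the InteractiveUtils standard library (loaded automatically in the REPL); `f` is the toy function from the session above:

```julia
using InteractiveUtils   # provides @which; loaded automatically in the REPL

f(x) = error(x)

# A frame such as `_pullback(ctx::Zygote.Context, f::typeof(f), args::Int64)`
# stands for the call f(::Int64). @which reports the method that call
# dispatches to, without actually calling f.
m = @which f(1)
println(m)                        # signature plus defining file and line
println(m.file, ":", m.line)
```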

tmbdev (Author) commented Jun 18, 2022

Thanks. In ZygoteRules.gradm(...), couldn't you add more info to the generated code to make tracebacks more informative? It would really be useful to be able to see in the stack trace what function the rrule was invoked for.

For getting this information into the stack trace, you might be able to...

  • use meaningful tags as arguments to gensym()
  • add a keyword "info=" argument to _pullback that contains contextual information (supplied in the macro), but isn't otherwise used
  • have _pullback call another function that performs the actual work and whose name reflects the context of the rrule
  • use a builtin mechanism to annotate the stack frame with context information (if Julia has such a mechanism)
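For reference, `gensym` already accepts an optional tag that is embedded in the generated name, which is what the first suggestion would rely on. A minimal illustration (the numeric counters differ from session to session):

```julia
# gensym with no tag vs. with a tag; the numeric counters vary by session.
plain  = gensym()           # something like Symbol("##312")
tagged = gensym(:maximum)   # something like Symbol("##maximum#313")
println(plain, "  ", tagged)
```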

Also, you have a test suite for differentiation; maybe test cases for gradients-of-functions-of-gradients could just be included into that test suite?

ToucheSir (Member) commented Jun 19, 2022

Thanks for the suggestions. My initial thoughts below:

  1. The only place gensym is used is for the f argument to _pullback(::Context, f, args...). It's not exactly clear to me why the function name isn't used as a tag, so that's something to explore.
  2. _pullback is actually the top-level function in a sense because Zygote's compiler directly emits calls to it, so adding a kwarg would be difficult. Perhaps if you could elaborate on specific types of contextual information to include here, we could brainstorm where best to place them.
  3. Were it not @inlined, ZygoteRules.adjoint is already another function called by _pullback. My concern with having a different name for every function is that both _pullback and adjoint are generic, so messing with the call stack may break downstream code which expects rules to hang methods off those functions.
  4. This ties well into the previous point. If anything, we need to reduce the number of seemingly redundant frames in stacktraces. Part of that is on Base, but we could also try to remove some of the layers generated by gradm. Going the direct route for this is tricky for the reason I mentioned above, but this discussion may have inspired an alternative approach that skirts around them.

maybe test cases for gradients-of-functions-of-gradients could just be included into that test suite?

We already have tests like this in various places, e.g.

Zygote.jl/test/utils.jl

Lines 5 to 20 in dad65a8

@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]
  if hess == hessian_dual
    @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0]
    @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0]  # original docstring version
  else
    @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0]  # can't differentiate ∇getindex
    @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0]
  end
  @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24])
  @test hess(sin, pi/2) ≈ -1
  @test_throws Exception hess(sin, im*pi)
  @test_throws Exception hess(x -> x+im, pi)
  @test_throws Exception hess(identity, randn(2))
end
What we're lacking is a more rigorous, and possibly centralized, place for them. Two difficulties with testing nested AD for everything are that the number of functions that don't work is far beyond what we can tractably implement (and we don't want 80% of our tests to be marked broken because of it), and that it's not always clear what the second derivative rule for certain functions should be. That's not to say we should do nothing; it's more to explain why things are done on an ad-hoc, as-needed basis right now.
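For illustration, a regression test for this issue's case could compare Zygote's reverse-over-reverse gradient with a finite-difference estimate. This is a sketch, not a test from the Zygote suite: `fd_grad` is a hypothetical helper, the sizes are kept small so the finite differences stay cheap, and since `maximum` is only piecewise smooth the check relies on the random data having no near-ties. On versions affected by this issue, the outer `gradient` call errors with "Mutating arrays is not supported".

```julia
using Zygote, Test, Random

f(w, x) = sum(maximum((w * x) .^ 2; dims = 1))
g(w, x) = sum(gradient(f, w, x)[2] .^ 2)   # gradient of f w.r.t. x, reduced to a scalar

# Hypothetical helper: central finite-difference estimate of the gradient
# of a scalar function `fun` at `w`.
function fd_grad(fun, w; h = 1e-6)
    gw = similar(w)
    for k in eachindex(w)
        wp = copy(w); wp[k] += h
        wm = copy(w); wm[k] -= h
        gw[k] = (fun(wp) - fun(wm)) / (2h)
    end
    return gw
end

Random.seed!(0)
w, x = rand(3, 4), rand(4, 5)

@testset "gradient of a gradient through maximum(; dims)" begin
    # Reverse-over-reverse: the case reported in this issue.
    gw = gradient(w -> g(w, x), w)[1]
    @test gw ≈ fd_grad(w -> g(w, x), w) rtol=1e-4
end
```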

FWIW the end goal is to move everything to use rrule. This may be done via direct translation in Zygote.jl or via up/downstreaming rules to ChainRules/ZygoteRules dependents and then deleting the rules in this repo. The process is not too complex, but we are currently bottlenecked by a lack of dev hours from folks who understand particular rules (+ both rule systems) well enough to translate them. That said, there's probably low-hanging fruit which only needs a rough mechanical translation.
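As a rough illustration of what such a mechanical translation looks like, here is the `values2array` rule proposed earlier in this thread written as a ChainRulesCore `rrule` instead of a Zygote `@adjoint`. `values2array` is that comment's own helper, not part of Zygote or ChainRules, and this is a sketch rather than a reviewed rule (for instance, it does no `ProjectTo` projection of the tangent):

```julia
using ChainRulesCore

# Hypothetical helper from the values2array proposal: scatter `values` into a
# zero array of size `shape` at positions `indexes`.
function values2array(values, indexes, shape)
    result = zeros(eltype(values), shape)
    result[indexes] = values
    return result
end

# rrule returns the primal result plus a pullback; the pullback returns one
# tangent per argument, starting with the function itself.
function ChainRulesCore.rrule(::typeof(values2array), values, indexes, shape)
    y = values2array(values, indexes, shape)
    function values2array_pullback(ȳ)
        Δ = unthunk(ȳ)   # the incoming cotangent may be lazy
        return (NoTangent(), Δ[indexes], NoTangent(), NoTangent())
    end
    return y, values2array_pullback
end

# The rule can be exercised directly, without any AD package:
y, back = rrule(values2array, [1.0, 2.0], [1, 3], (4,))
println(y)                                   # [1.0, 0.0, 2.0, 0.0]
println(back([10.0, 20.0, 30.0, 40.0])[2])   # [10.0, 30.0]
```

Zygote picks up `rrule` methods like this automatically, which is what makes deleting a duplicate `@adjoint` possible.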

mcabbott (Member) commented

For maximum, CR already has a rule which should allow second order:

https://github.com/JuliaDiff/ChainRules.jl/blob/a0d86fea0c27d9d9ff2b1872b3f7601bf20b4999/src/rulesets/Base/array.jl#L581-L585

So simply deleting Zygote's rules should solve things:

Zygote.jl/src/lib/array.jl

Lines 322 to 339 in 3239330

@adjoint function maximum(xs::AbstractArray; dims = :)
  max, i = findmax(xs, dims = dims)
  max, function (Δ)
    Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing
    Δ′ = zero(xs)
    Δ′[i] = Δ
    return (Δ′,)
  end
end
@adjoint function minimum(xs::AbstractArray; dims = :)
  min, i = findmin(xs, dims = dims)
  min, function (Δ)
    Δ′ = zero(xs)
    Δ′[i] = Δ
    return (Δ′,)
  end
end
(This includes the line array.jl:327 mentioned in the stack trace.)
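To look at the ChainRules rule on its own, outside Zygote, it can be called directly. This is a sketch assuming a ChainRules release that contains the rule linked above; the tangent for `x` may come back lazily as a thunk, hence the `unthunk`:

```julia
using ChainRules, ChainRulesCore

x = [1.0 5.0; 3.0 2.0]

# ChainRules' reverse rule for maximum with dims, no Zygote involved.
y, back = rrule(maximum, x; dims = 1)   # y holds the column maxima, [3.0 5.0]
_, dx = back(ones(1, 2))                # (tangent for maximum itself, tangent for x)
dx = unthunk(dx)                        # materialize in case it is a thunk
println(dx)                             # 1.0 at each column's argmax, 0.0 elsewhere
```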

tmbdev (Author) commented Jun 24, 2022

Thanks for the responses and help.

I'm trying to convert some Jax code; gradients-of-gradients work quite reliably in Jax, probably because it uses XLA underneath.

It seems like there is initial support for XLA in Julia and Flux; do you think that would help with gradient-of-gradient? What is the performance?

Another option seems to be Enzyme; has Enzyme been integrated with Flux yet? Would that help?

mcabbott (Member) commented

There were experiments with hooking things up to XLA, but I don't think anyone is working on that. I don't think they passed the handling of AD to XLA, just the execution.

The immediate plan was to move from Zygote to Diffractor, which uses the same rrules from ChainRules. There aren't that many rules which don't allow second-order use; mostly what's needed is more hands to fix them, like removing the offending rule here, which is an easy PR.

I am not sure what the status of Enzyme with things like CuBLAS is. This is a much larger deviation from how Flux currently works.

mcabbott added the "second order" and "ChainRules" labels on Jul 4, 2022

fuyangfeng commented Aug 22, 2022

Yeah, nested AD is a big pain point with Zygote. That's why you see so many recommendations to use mixed forward + reverse mode where possible (it's also more algorithmically efficient).

How can one perform mixed forward + reverse mode? Could you please use the example above to illustrate it?
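One way to do this for the example above is forward-over-reverse: Zygote computes the inner gradient and ForwardDiff computes the outer one. This is a sketch, not code from the maintainers; it assumes ForwardDiff.jl is installed and that every rule Zygote hits along the way accepts `ForwardDiff.Dual` numbers, which should hold for the operations used here but is not guaranteed in general. It also uses a matrix `x` rather than the scalar that the original script ends up passing.

```julia
using Zygote, ForwardDiff

f(w, x) = sum(maximum((w * x) .^ 2; dims = 1))

# Inner derivative in reverse mode: Zygote's gradient of f with respect to x.
g(w, x) = sum(gradient(f, w, x)[2] .^ 2)

# Outer derivative in forward mode: ForwardDiff pushes dual numbers for the
# entries of w through g, so Zygote never has to differentiate its own
# pullbacks (and the in-place writes inside them are harmless).
h_fwd(w, x) = ForwardDiff.gradient(w -> g(w, x), w)

w = rand(10, 100)
x = rand(100, 7)
z = h_fwd(w, x)    # same size as w
```

ForwardDiff.gradient makes one forward sweep per chunk of entries of `w`, so the cost of the outer derivative grows with `length(w)`.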

4 participants