Iterated gradients error #1327

Open
Vilin97 opened this issue Nov 10, 2022 · 5 comments
Labels
bug (Something isn't working), ChainRules (adjoint -> rrule, and further integration), second order (zygote over zygote, or otherwise)

Comments

@Vilin97

Vilin97 commented Nov 10, 2022

Package Version

Zygote v0.6.49

Julia Version

julia version 1.8.2

OS / Environment

Windows

Describe the bug

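Iterated gradients (a gradient of a gradient) fail for some functions but work for others, with no obvious pattern. In the examples below, g2, g3, and g4 compute the same dot product, yet only g4 errors.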

Steps to Reproduce

using Zygote: gradient

g1(x,y) = (x+y)[1]
dxg1(x,y) = gradient(x -> g1(x,y), x)[1][1] #partial of g₁ wrt x₁
dxg1(ones(2), ones(2)) # 1.0, as expected
dxyg1(x,y) = gradient(y -> dxg1(x,y), y)[1][1] #partial of dxg1 wrt y₁
dxyg1(ones(2), ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}

g2(x,y) = transpose(x)*y
dxg2(x,y) = gradient(x -> g2(x,y), x)[1][1] 
dxyg2(x,y) = gradient(y -> dxg2(x,y), y)[1][1] 
dxyg2(ones(2), ones(2)) # 1.0, as expected

g3(x,y) = sum(x.*y)
dxg3(x,y) = gradient(x -> g3(x,y), x)[1][1] 
dxyg3(x,y) = gradient(y -> dxg3(x,y), y)[1][1] 
dxyg3(ones(2),ones(2)) # 1.0, as expected

g4(x,y) = x[1]*y[1] + x[2]*y[2]
dxg4(x,y) = gradient(x -> g4(x,y), x)[1][1] #partial wrt x
dxyg4(x,y) = gradient(y -> dxg4(x,y), y)[1][1] #mixed second derivative
dxyg4(ones(2),ones(2)) # ERROR: Need an adjoint for constructor Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}. Gradient is of type Vector{Float64}

Expected Results

I expected the code above to not error.
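
For reference, the mixed partials the four examples are meant to compute can be worked out by hand. This is a sketch of the expected values only, written with 1-based component indices to match the code; it says nothing about how Zygote computes or represents them internally.

```latex
% g_1 depends on x_1 and y_1 additively, so the mixed partial vanishes.
g_1(x, y) = x_1 + y_1, \qquad
\frac{\partial g_1}{\partial x_1} = 1, \qquad
\frac{\partial^2 g_1}{\partial y_1 \, \partial x_1} = 0.

% g_2, g_3 and g_4 are three spellings of the same dot product.
g_2(x, y) = g_3(x, y) = g_4(x, y) = x_1 y_1 + x_2 y_2, \qquad
\frac{\partial g_k}{\partial x_1} = y_1, \qquad
\frac{\partial^2 g_k}{\partial y_1 \, \partial x_1} = 1
\quad (k = 2, 3, 4).
```

The value 1 for the dot-product cases agrees with the 1.0 returned by dxyg2 and dxyg3 above.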

Observed Results

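The g1 and g4 examples throw a "Need an adjoint for constructor Zygote.OneElement" error, while the g2 and g3 examples return 1.0.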

Relevant log output

No response

@Vilin97 Vilin97 added the bug Something isn't working label Nov 10, 2022
@mcabbott mcabbott added the second order zygote over zygote, or otherwise label Nov 10, 2022
@mcabbott
Member

This is #820.

The solution is probably to remove most of Zygote's rules for indexing, as the ones at CR should now work at second order:

https://github.com/JuliaDiff/ChainRules.jl/blob/39c2d17df672836659493d6adb7d4ad8593250a5/src/rulesets/Base/indexing.jl#L63

Someone just has to do it.

@mcabbott mcabbott added the ChainRules adjoint -> rrule, and further integration label Nov 10, 2022
@Vilin97
Author

Vilin97 commented Nov 14, 2022

A temporary workaround is to dev Zygote and comment out the following line. With that change, everything works.

# @adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

Thanks to @mcabbott for the workaround!
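
For anyone repeating the workaround, the steps might look roughly like the sketch below. It assumes the standard Pkg workflow and the default dev location. The file containing the getindex adjoint is not named in this thread, so the sketch only prints the source directory to search.

```julia
using Pkg

# Check out an editable copy of Zygote (by default under ~/.julia/dev/Zygote).
Pkg.develop("Zygote")

using Zygote

# Print where the editable source lives. Search this directory for the
# getindex adjoint quoted above and comment that line out by hand.
println(dirname(pathof(Zygote)))

# After editing, restart Julia so the package is recompiled with the change.
# To go back to the registered release later:
# Pkg.free("Zygote")
```

This leaves the environment tracking a locally modified Zygote, so it is best done in a separate project environment rather than the default one.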

@Vilin97
Author

Vilin97 commented Dec 1, 2022

@ToucheSir, any update on this?

@ToucheSir
Member

Per #1328 (comment), it looks like some changes are required in ChainRules before we can make the Flux-side ones.

@Vilin97
Author

Vilin97 commented Dec 4, 2022

More on this.
I define three equivalent ways to compute a divergence and then take the gradient of each. The results are, respectively, nothing, an ERROR, and the correct answer. The culprit line in Zygote is still commented out!

using Zygote, LinearAlgebra
f(x) = -x
x = [1., 1.]
function divergence1(f, v)
    _, ∂f = pullback(f, v)
    id = I(length(v))
    sum(eachindex(v)) do i
        ∂f( @view id[i,:] )[1][i]
    end
end
divergence2(f,v) = tr(jacobian(f, v)[1])
divergence3(f,v) = sum(gradient(v -> f(v)[i], v)[1][i] for i in eachindex(v))
divergence1(f, x) == divergence2(f, x) # true
divergence1(f, x) == divergence3(f, x) # true
gradient(x -> divergence1(f,x), x) # (nothing,)
gradient(x -> divergence2(f,x), x) # ERROR: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
gradient(x -> divergence3(f,x), x) # correct answer

3 participants