Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ChainRules zero types internally #1389

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
import ZygoteRules: ZygoteRules, @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
Expand Down
53 changes: 48 additions & 5 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
Expand All @@ -125,6 +121,8 @@ end
wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs
wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs)


#=
# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
Expand Down Expand Up @@ -152,6 +150,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs)
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, dxs)
# This produces Tangent{Any} since it does not get to see the primal, `x`.
Expand Down Expand Up @@ -186,9 +185,12 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
wrap_chainrules_output(ProjectTo(x)(zygote2differential(dx, x)))
differential2zygote(ProjectTo(x)(zygote2differential(dx, x)))
end

_project(_, dx::Nothing) = nothing
_project(x::Tuple, dx::Tuple) = map(_project, x, dx)

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

Expand Down Expand Up @@ -350,3 +352,44 @@ z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Nu
end

z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs


"""
differential2zygote(dx)

Convert input `dx` from ChainRules differential types to the Zygote format.
This is similar to `wrap_chainrules_output(dx)`, but converts zero types,
and recursively converts Tangents.
"""
@inline differential2zygote(@nospecialize(x)) = x
@inline differential2zygote(::AbstractZero) = nothing
@inline differential2zygote(::ChainRulesCore.NotImplemented) = nothing
@inline differential2zygote(x::AbstractThunk) = differential2zygote(unthunk(x)) # For now we are just not going to deal with thunks
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline differential2zygote(x::$T_outer) = map(differential2zygote, x)
@eval @inline function differential2zygote(x::Tangent{<:Any, <:$T_outer})
# this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
inner = ChainRulesCore.backing(canonicalize(x))
return differential2zygote(inner)
end
end
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline differential2zygote(::Tuple{Vararg{AbstractZero}}) = nothing
@inline differential2zygote(::Tuple{}) = () # Edge case split off from the above method

differential2zygote(dxs::AbstractArray{<:Number}) = dxs
differential2zygote(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
differential2zygote(dxs::AbstractArray) = map(differential2zygote, dxs)
differential2zygote(dxs::Dict) = Dict(k => differential2zygote(v) for (k, v) in dxs)

# Mostly used in rule genfuncs
_iszerotype(T) = T === Nothing || T <: AbstractZero

# Note: safe piracy to make @adjoint definitions work
ZygoteRules.gradtuple0(x::AbstractZero) = x
ZygoteRules.gradtuple1(x::AbstractZero) = x
ZygoteRules.gradtuple2(x::AbstractZero) = x
ZygoteRules.gradtuple3(x::AbstractZero) = x
10 changes: 5 additions & 5 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ tailmemaybe(x::Tuple) = Base.tail(x)
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
y, Δ -> tailmemaybe(back(Δ))
wrapped_back(Δ) = tailmemaybe(differential2zygote(back(Δ)))
y, wrapped_back
end
function pullback(cx::Context, f, args...)
ChainRulesCore.ignore_derivatives() do
Expand Down Expand Up @@ -95,7 +96,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
return _project(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand Down Expand Up @@ -131,8 +132,7 @@ julia> res.grad[w]
function withgradient(f, args...)
y, back = pullback(f, args...)
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
(val=y, grad=_project(args, grad))
end

# Param-style wrappers
Expand Down Expand Up @@ -184,7 +184,7 @@ Params(xs::Tuple) = Params(collect(xs))

Base.in(x, ps::Params) = x in ps.params

Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
_project(::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)

function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
Expand Down
1 change: 1 addition & 0 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
@inline tuple_va(N, xs) = xs
@inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...)
@inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N))
@inline tuple_va(::Val{N}, x::AbstractZero) where N = ntuple(_ -> x, Val(N))

iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n)

Expand Down
3 changes: 3 additions & 0 deletions src/compiler/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ function funcname(T)
end

Base.show(io::IO, j::Pullback{S}) where S = print(io, "∂($(funcname(S.parameters[1])))")
function Base.show(io::IO, P::Type{<:Pullback{S}}) where S
@isdefined(S) ? print(io, "Pullback{", S, ", ...}") : print(io, "Pullback{S, T}")
end

56 changes: 28 additions & 28 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,6 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
if inds isa NTuple{N,Int} && T <: Number
dx = OneElement(dy, inds, axes(x))
elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
OneElement(val, ind, axes) <: AbstractArray

Expand Down Expand Up @@ -247,10 +229,10 @@ reconstruct_if_dict(x̄, _keys::Nothing) = x̄

function reconstruct_if_dict(x̄, _keys)
# This reverses `collect_if_dict`, which returns `_keys::Nothing` if x is not a Dict
@assert x̄ isa AbstractVector{<:Union{Nothing, NamedTuple{(:first,:second)}}}
@assert x̄ isa AbstractVector # {<:Union{Nothing, AbstractZero, NamedTuple{(:first,:second)}}}
# we don't compute gradients with respect to keys
# @assert all(x -> x === nothing || x[1] == 0 || x[1] === nothing, x̄)
d̄ = Dict(k => isnothing(x) ? nothing : x[2] for (x, k) in zip(x̄, _keys))
d̄ = Dict(k => x === nothing || x isa AbstractZero ? x : x[2] for (x, k) in zip(x̄, _keys))
return d̄
end

Expand Down Expand Up @@ -296,8 +278,9 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
nd = _ndims(xs[n])
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
d += nd
first(dy)[n] === nothing && return nothing
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
dy_1n = first(dy)[n]
(dy_1n === nothing || dy_1n isa AbstractZero) && return dy_1n
init = zero.(dy_1n) # allows for tuples, which accum can add:
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
return _project(xs[n], reshape(red, axes(xs[n])))
end
Expand Down Expand Up @@ -332,8 +315,16 @@ function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
end

@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),)
@adjoint function real(x::AbstractArray)
real_array_pullback(r̄::AbstractZero) = (r̄,)
real_array_pullback(r̄) = (real(r̄),)
return real(x), real_array_pullback
end
@adjoint function conj(x::AbstractArray)
conj_array_pullback(r̄::AbstractZero) = (r̄,)
conj_array_pullback(r̄) = (conj(r̄),)
return conj(x), conj_array_pullback
end
@adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),)


Expand Down Expand Up @@ -445,6 +436,7 @@ _symmetric_back(Δ::LowerTriangular, uplo) = collect(uplo == 'U' ? transpose(Δ)

@adjoint function Symmetric(A::AbstractMatrix, uplo=:U)
S = Symmetric(A, uplo)
back(Δ::AbstractZero) = (Δ, nothing)
back(Δ::AbstractMatrix) = (_symmetric_back(Δ, S.uplo), nothing)
back(Δ::NamedTuple) = (_symmetric_back(Δ.data, S.uplo), nothing)
return S, back
Expand All @@ -469,15 +461,23 @@ end

@adjoint function LinearAlgebra.Hermitian(A::AbstractMatrix, uplo=:U)
H = Hermitian(A, uplo)
back(Δ::AbstractZero) = (Δ, nothing)
back(Δ::AbstractMatrix) = (_hermitian_back(Δ, H.uplo), nothing)
back(Δ::NamedTuple) = (_hermitian_back(Δ.data, H.uplo), nothing)
return H, back
end

@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A),
Δ -> (nothing, convert(S, Δ),)
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)
@adjoint function convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array}
convert_Array_HermOrSym_callback(Δ::AbstractZero) = (nothing, Δ)
convert_Array_HermOrSym_callback(Δ) = (nothing, convert(S, Δ))
return convert(R, A), convert_Array_HermOrSym_callback
end

@adjoint function Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S}
Matrix_HermOrSym_pullback(Δ::AbstractZero) = (Δ,)
Matrix_HermOrSym_pullback(Δ) = (convert(S, Δ),)
return Matrix(A), Matrix_HermOrSym_pullback
end

@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
Expand Down
4 changes: 4 additions & 0 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ end
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
Expand All @@ -297,6 +298,7 @@ end
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
Expand All @@ -311,6 +313,7 @@ end
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out))
Expand All @@ -335,6 +338,7 @@ end
@inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
bc_fwd_back(ȳ::AbstractZero) = ȳ
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out))
Expand Down
40 changes: 26 additions & 14 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ end
first(xs), Δ -> ((Δ, drest...),)
end

@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
@adjoint function Base.tail(xs::Tuple)
Tuple_tail_pullback(x̄s::AbstractZero) = (x̄s,)
Tuple_tail_pullback(x̄s) = ((nothing, x̄s...),)
return tail(xs), Tuple_tail_pullback
end

_empty(x) = length(x)
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
Expand Down Expand Up @@ -202,11 +206,12 @@ if VERSION >= v"1.4.0-DEV.304"
@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
function _apply_iterate_pullback(Δ)
Δ = back(Δ)
Δ === nothing ? nothing :
(nothing, first(Δ), unapply(st, Base.tail(Δ))...)
Δ isa Union{Nothing,AbstractZero} && return Δ
return (nothing, first(Δ), unapply(st, Base.tail(Δ))...)
end
return y, _apply_iterate_pullback
end
end

Expand All @@ -229,11 +234,14 @@ end
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
# Const properties on modules are considered non-differentiable
x isa Module && isconst(x, f) && return
if isimmutable(x)
dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...)
(_project(x, dx), nothing)
else
dx = grad_mut(__context__, x)
# @show dx
dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
return (dx,nothing)
end
Expand Down Expand Up @@ -305,24 +313,28 @@ end
end

# TODO captured mutables + multiple calls to `back`
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
!ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ :
Δ <: RefValue ? :(back.g[]) :
:(accum(back.g[], Δ))
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G}
!ismutabletype(T) && _iszerotype(Δ) && return :Δ
Δ = if _iszerotype(G)
elseif Δ <: RefValue
:(back.g[])
else
:(accum(back.g[], Δ))
end
quote
x̄ = $Δ
$(G == Nothing || :(back.g[] = nt_nothing($Δ)))
$(_iszerotype(G) || :(back.g[] = nt_nothing($Δ)))
(nothing, $(map(f -> :(x̄.$f), fieldnames(T))...))
end
end

@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
!ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ : :(back.g)
@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G}
!ismutabletype(T) && _iszerotype(Δ) && return :Δ
Δ = _iszerotype(G) ? :Δ : :(back.g)
quote
x̄ = $Δ
$(G == Nothing || :($Δ = nt_nothing($Δ)))
$(_iszerotype(G) || :($Δ = nt_nothing($Δ)))
(nothing, ($(map(f -> :(x̄.$f), fieldnames(T))...),))
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ using Zygote: ZygoteRuleConfig
@test Zygote.gradient(f_notimplemented, 0.1) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,)
if isdefined(Base, :only)
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
@test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === ((nothing,),)
@test_broken Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
end
end

Expand Down
Loading