Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move a bunch of no_grad to ChainRules #780

Merged
merged 3 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
ChainRules = "0.7.0"
ChainRules = "0.7.15"
DiffRules = "1.0"
FillArrays = "0.8, 0.9"
ForwardDiff = "0"
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
"""
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
Expand Down
9 changes: 1 addition & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@ using Distributed: pmap
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)

@nograd size, length, eachindex, Base.OneTo, axes, Colon(), findfirst, findlast, findall, ones, zeros, one, zero, any, all
@nograd randn, randexp, randn!, randexp!
@static if VERSION > v"1.3"
@nograd Random.default_rng
end

@adjoint Base.rand(rng::AbstractRNG, ::Type{T}, dims...) where {T<:Number} =
rand(rng, T, dims...), _ -> nothing
@nograd ones, zeros, Base.OneTo, Colon(), one, zero

@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)

Expand Down
7 changes: 1 addition & 6 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
@nograd readline, Base.gc_num, Base.time_ns, Base.print, Base.println, Base.show,
Core.show, Core.print, Core.println, string, repr, Threads.nthreads, Threads.threadid

# Gradient of AD stacks

grad_mut(::AbstractVector) = []
Expand Down Expand Up @@ -47,11 +44,9 @@ end
end
end

@nograd haskey

# Channels

@nograd Channel, schedule
@nograd Channel

grad_mut(ch::Channel) = Channel(ch.sz_max)

Expand Down
2 changes: 0 additions & 2 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ using Base.Broadcast
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
using NNlib

@nograd Broadcast.combine_styles, Broadcast.result_style

# There's a saying that debugging code is about twice as hard as writing it in
# the first place. So if you're as clever as you can be when writing code, how
# will you ever debug it?
Expand Down
5 changes: 1 addition & 4 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ function accum(x::RefValue, y::RefValue)
end

# Core functions

@nograd Core.apply_type, Core.typeof, nfields, fieldtype, Core.TypeVar, Core.UnionAll,
(==), (===), (<=), (>=), (<), (>), isempty, supertype, Base.typename,
eps, Meta.parse, Base.eval, sleep, isassigned
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll

@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

Expand Down
2 changes: 1 addition & 1 deletion src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@nograd floor, ceil, trunc, round, hash, div
@nograd floor, ceil, trunc, round, div

@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Expand Down
16 changes: 16 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ using Zygote, Test, ChainRules
@test mimo_pullback_hitcount[] == 1
end

@testset "all AbstractZero partials" begin
# while ChainRules always has a partial for every input, Zygote combined them all
# to a single `nothing` if they are all zero-like.

not_diff_eg(x, i) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_eg), x, i)
function not_diff_eg_pullback(Δ)
return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
end
return not_diff_eg(x, i), not_diff_eg_pullback
end

_, pb = Zygote.pullback(not_diff_eg, 10.4, 2)
@test pb(1.2) === nothing
end

@testset "nested AD hitting identity(::Tuple) pullback" begin
# This is is a particularly fiddly case.
# Its kind of a simplified version of `sin'''(0.5)` but different in some places.
Expand Down
5 changes: 5 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1538,9 +1538,14 @@ end
end

@testset "@nograd" begin
@test gradient(x->eachindex([10,20,30])[1], 11) == (nothing,)

#These are defined in ChainRules, we test them here to check we are handling them right
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)


@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
@test gradient(1) do x
Expand Down