Skip to content

Commit

Permalink
use NNlib.within_gradient (#2152)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Jan 7, 2023
1 parent aba285c commit 7997174
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
Functors = "0.3, 0.4"
MLUtils = "0.2, 0.3.1, 0.4"
MacroTools = "0.5"
NNlib = "0.8.9"
NNlib = "0.8.14"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.12"
Expand Down
2 changes: 1 addition & 1 deletion src/cuda/cudnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
training=Flux._isactive(BN, x)))
end

function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
Expand Down
11 changes: 11 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed.

@deprecate rng_from_array() default_rng_value()

function istraining()
Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
false
end
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

function _isactive(m)
Base.depwarn("_isactive(m) is deprecated, use _isactive(m,x)", :_isactive, force=true)
_isactive(m, 1:0)
end

#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
Expand Down
11 changes: 4 additions & 7 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
istraining() = false

ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

_isactive(m) = isnothing(m.active) ? istraining() : m.active
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i dims ? 1 : si for (i, si) enumerate(size(s)))...)
Expand Down Expand Up @@ -107,7 +104,7 @@ end
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
_isactive(a, x) || return x
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
end

Expand Down Expand Up @@ -162,7 +159,7 @@ AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
_isactive(a, x) || return x
p = a.p
iszero(p) && return x
isone(p) && return sign.(x) .* T(0)
Expand Down Expand Up @@ -242,7 +239,7 @@ end
function _norm_layer_forward(
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
) where {T, N}
if !_isactive(l) && l.track_stats # testmode with tracked stats
if !_isactive(l, x) && l.track_stats # testmode with tracked stats
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
μ = reshape(l.μ, stats_shape)
σ² = reshape(l.σ², stats_shape)
Expand Down
15 changes: 15 additions & 0 deletions test/layers/normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,5 +475,20 @@ end
# This was an error, https://github.com/FluxML/Flux.jl/issues/2122
@test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
@test !iszero(bn.μ)

# Easy case of 2122, gradient with x
x5 = rand(Float32, 5, 3)
bn1 = BatchNorm(5, relu)
bn2 = BatchNorm(5, relu)
g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1]
g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5)
@test g1 g2

# Harder case?
v1, re1 = Flux.destructure(BatchNorm(5, relu));
g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1]

v2, re2 = Flux.destructure(BatchNorm(5, relu));
g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2)
end

0 comments on commit 7997174

Please sign in to comment.