Allow ForwardDiff in BatchNorm's track_stats (#2127)
* allow ForwardDiff in BatchNorm's track_stats

* second test

* add comments

* Update test/layers/normalisation.jl
mcabbott authored Dec 8, 2022
1 parent 815deaa commit c850df5
Showing 3 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -11,6 +11,7 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions

using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback, @nograd
+using Zygote.ForwardDiff: value
export gradient

# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
5 changes: 3 additions & 2 deletions src/layers/normalise.jl
@@ -275,8 +275,9 @@ function _track_stats!(
  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

-  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
-  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  # ForwardDiff.value removes Dual, was an error, issue #2122
+  bn.μ .= value.(res_mtm .* bn.μ .+ mtm .* μnew)
+  bn.σ² .= value.(res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new)
  return nothing
end
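
For context, a minimal sketch of the failure this commit fixes (assuming only ForwardDiff is loaded; the buffer name μ below is illustrative, not Flux code): ForwardDiff.jacobian pushes Dual numbers through the layer, and storing a Dual-valued result into the layer's Float32 statistics buffers fails with a conversion MethodError, while value keeps only the primal part.

using ForwardDiff
using ForwardDiff: value

μ = zeros(Float32, 3)               # stands in for the bn.μ buffer
d = ForwardDiff.Dual(0.5f0, 1.0f0)  # the number type jacobian pushes through the layer

# μ .= d .* ones(Float32, 3)        # errors: a Dual cannot be converted to Float32
μ .= value.(d .* ones(Float32, 3))  # value strips the partials, leaving plain 0.5f0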

16 changes: 15 additions & 1 deletion test/layers/normalisation.jl
@@ -1,5 +1,5 @@
using Flux, Test, Statistics
-using Zygote: pullback
+using Zygote: pullback, ForwardDiff

evalwgrad(f, x...) = pullback(f, x...)[1]

@@ -462,4 +462,18 @@ end
@testset "second derivatives" begin
  m1 = Dropout(0.5)
  @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
+
+  m2 = Chain(BatchNorm(3), sum)
+  @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6)
end

@testset "ForwardDiff" begin
bn = BatchNorm(3)
@test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
# iszero(bn.μ) # is true. But ideally would not be, if Flux would automatically choose trainmode
Flux.trainmode!(bn)
# This was an error, https://github.com/FluxML/Flux.jl/issues/2122
@test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
@test !iszero(bn.μ)
end
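
A short usage sketch of what these tests exercise (assumes this patched Flux plus ForwardDiff; not part of the diff):

using Flux, ForwardDiff

bn = BatchNorm(3)
Flux.trainmode!(bn)              # make the forward pass run _track_stats!

x = rand(Float32, 3, 4)
J = ForwardDiff.jacobian(bn, x)  # 12×12 Matrix{Float32}; threw a MethodError before this fix
@assert !iszero(bn.μ)            # the running mean updated, using primal values only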
