From c850df5409ca545be433dec835034cffa8486aa4 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 7 Dec 2022 23:38:51 -0500
Subject: [PATCH] Allow ForwardDiff in BatchNorm's track_stats (#2127)

* allow ForwardDiff in BatchNorm's track_stats

* second test

* add comments

* Update test/layers/normalisation.jl
---
 src/Flux.jl                  |  1 +
 src/layers/normalise.jl      |  5 +++--
 test/layers/normalisation.jl | 16 +++++++++++++++-
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 3853712f7b..66796491dd 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -11,6 +11,7 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne
 
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
+using Zygote.ForwardDiff: value
 export gradient
 
 # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 0f2696a50a..89eee976ee 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -275,8 +275,9 @@ function _track_stats!(
   μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
   σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
 
-  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
-  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
+  # ForwardDiff.value removes Dual, was an error, issue #2122
+  bn.μ .= value.(res_mtm .* bn.μ .+ mtm .* μnew)
+  bn.σ² .= value.(res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new)
   return nothing
 end
 
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 7ae15aeff9..859d703368 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -1,5 +1,5 @@
 using Flux, Test, Statistics
-using Zygote: pullback
+using Zygote: pullback, ForwardDiff
 
 evalwgrad(f, x...) = pullback(f, x...)[1]
 
@@ -462,4 +462,18 @@ end
 @testset "second derivatives" begin
   m1 = Dropout(0.5)
   @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
+
+  m2 = Chain(BatchNorm(3), sum)
+  @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6)
+end
+
+@testset "ForwardDiff" begin
+  bn = BatchNorm(3)
+  @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
+  # iszero(bn.μ) # is true. But ideally would not be, if Flux would automatically choose trainmode
+  Flux.trainmode!(bn)
+  # This was an error, https://github.com/FluxML/Flux.jl/issues/2122
+  @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
+  @test !iszero(bn.μ)
 end
+