diff --git a/src/api/groupnorm.jl b/src/api/groupnorm.jl index 72f5f8e6..5f713cf3 100644 --- a/src/api/groupnorm.jl +++ b/src/api/groupnorm.jl @@ -33,8 +33,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector sz = size(x) x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) - x_ = _groupnorm_impl( - x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), Val(false), epsilon, σ) + x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, σ) return reshape(x_, sz) end diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index 87cbecf7..dcfc0cdd 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -111,7 +111,8 @@ EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing # code. function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector}, bias::Optional{<:AbstractVector}, reduce_dims::Val, - training::Val, epsilon, act::F=identity) where {F} - (μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, training, nothing) + epsilon, act::F=identity) where {F} + (μ, σ²), _ = _get_batch_statistics( + x, nothing, nothing, reduce_dims, Val(false), nothing) return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon) end diff --git a/test/common_ops/conv_tests.jl b/test/common_ops/conv_tests.jl index f3674d0a..3e2b7616 100644 --- a/test/common_ops/conv_tests.jl +++ b/test/common_ops/conv_tests.jl @@ -53,17 +53,20 @@ @test y≈y_generic atol=atol rtol=rtol @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @test @inferred(fused_conv_bias_activation( + activation, weight, x, bias, cdims)) isa Any @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) __f = (σ, w, x, b, cdims) -> sum( abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) if mode != "amdgpu" && activation !== anonact - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient( + __f, activation, weight, x, bias, cdims)) isa Any else try - @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient( + __f, activation, weight, x, bias, cdims)) isa Any @test true catch @test_broken false diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl index 8b7fcf4d..0ec78459 100644 --- a/test/common_ops/dense_tests.jl +++ b/test/common_ops/dense_tests.jl @@ -25,13 +25,13 @@ @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @inferred fused_dense_bias_activation(activation, w, x, bias) + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any @jet fused_dense_bias_activation(activation, w, x, bias) __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) if activation !== anonact - @inferred Zygote.gradient(__f, activation, w, x, bias) + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any else @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true end diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl index 55aeaa91..ca5e9b9c 100644 --- a/test/common_ops/dropout_tests.jl +++ b/test/common_ops/dropout_tests.jl @@ -9,7 +9,7 @@ x = randn(rng, T, x_shape) |> aType - @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), Colon()) @@ -20,7 +20,7 @@ @test rng != rng_ __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon()))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng, T = T x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) @@ -41,7 +41,7 @@ end @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon()))) - @inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon()) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon()) @@ -68,7 +68,8 @@ end mask = rand(T, x_shape) |> aType # Update mask - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()) @@ -82,7 +83,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon()))) - @test size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -109,7 +110,8 @@ end rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon()))) # Try using mask if possible (possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @@ -124,7 +126,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) # Branching based on runtime values - @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -147,7 +149,8 @@ end mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Try using mask if possible (not possible!!) - @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()) @@ -162,7 +165,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon()))) # Branching based on runtime activity - @test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true __f = let rng = rng, mask = mask x -> sum(first(dropout( @@ -187,7 +190,8 @@ end @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon()))) # Testing Mode - @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon()) @@ -212,7 +216,7 @@ end x = randn(rng, T, x_shape) |> aType - @inferred alpha_dropout(rng, x, T(0.5), Val(true)) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) @@ -223,7 +227,7 @@ end @test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2) __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) - @test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x) + @test @inferred(Zygote.gradient(__f, x)) isa Any __f = let rng = rng x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) @@ -243,8 +247,7 @@ end end @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - - @inferred alpha_dropout(rng, x, T(0.5), Val(false)) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/test/normalization/batchnorm_tests.jl b/test/normalization/batchnorm_tests.jl index 1c5f82f8..fb3a5d3c 100644 --- a/test/normalization/batchnorm_tests.jl +++ b/test/normalization/batchnorm_tests.jl @@ -30,7 +30,8 @@ y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) - @inferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + @test @inferred(batchnorm( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} diff --git a/test/normalization/groupnorm_tests.jl b/test/normalization/groupnorm_tests.jl index 2fc3393e..c499c493 100644 --- a/test/normalization/groupnorm_tests.jl +++ b/test/normalization/groupnorm_tests.jl @@ -8,31 +8,77 @@ return x, scale, bias end + # Bypassing all optimizations + function __groupnorm_basic( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1] + return reshape(x_, sz) + end + @testset "$mode" for (mode, aType, on_gpu) in MODES @testset "eltype $T, size $sz, ngroups $groups, $act" for T in ( Float16, Float32, Float64), - sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + sz in ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), groups in (2, 3), act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3) _f = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm(args..., groups, act, epsilon) epsilon = T(1e-5) x, scale, bias = _setup_groupnorm(aType, T, sz) y = _f(x, scale, bias) - @inferred groupnorm(x, scale, bias, groups, act, epsilon) + y_simple = _f2(x, scale, bias) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient( + sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa + Any + @test y isa aType{T, length(sz)} @test size(y) == sz - fp16 = T == Float16 __f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon)) skip_fd = act === relu allow_unstable() do - @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd) + @eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd) + end + + if !on_gpu + ∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias) + + ∂x_enz = Enzyme.make_zero(x) + ∂scale_enz = Enzyme.make_zero(scale) + ∂bias_enz = Enzyme.make_zero(bias) + Enzyme.autodiff(Reverse, __f, Duplicated(x, ∂x_enz), + Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz)) + + @test ∂x≈∂x_enz rtol=rtol atol=atol + @test ∂scale≈∂scale_enz rtol=rtol atol=atol + @test ∂bias≈∂bias_enz rtol=rtol atol=atol end end end diff --git a/test/normalization/instancenorm_tests.jl b/test/normalization/instancenorm_tests.jl index b135c4ed..e989343e 100644 --- a/test/normalization/instancenorm_tests.jl +++ b/test/normalization/instancenorm_tests.jl @@ -24,7 +24,7 @@ y, nt = instancenorm(x, scale, bias, training, act, epsilon) - @inferred instancenorm(x, scale, bias, training, act, epsilon) + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) @test y isa aType{T, length(sz)} diff --git a/test/normalization/layernorm_tests.jl b/test/normalization/layernorm_tests.jl index 7be16eaf..384470ff 100644 --- a/test/normalization/layernorm_tests.jl +++ b/test/normalization/layernorm_tests.jl @@ -24,7 +24,7 @@ x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape) - @inferred layernorm(x, scale, bias, act, dims, epsilon) + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias)