Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: more comprehensive norm testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 20, 2024
1 parent bba8d8a commit 80363da
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 30 deletions.
3 changes: 1 addition & 2 deletions src/api/groupnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector

sz = size(x)
x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
x_ = _groupnorm_impl(
x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), Val(false), epsilon, σ)
x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, σ)

return reshape(x_, sz)
end
Expand Down
5 changes: 3 additions & 2 deletions src/impl/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ EnzymeRules.inactive_noinl(::typeof(_get_norm_reshape_dims), ::Any...) = nothing
# code.
function _groupnorm_impl(x::AbstractArray, scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, reduce_dims::Val,
training::Val, epsilon, act::F=identity) where {F}
(μ, σ²), _ = _get_batch_statistics(x, nothing, nothing, reduce_dims, training, nothing)
epsilon, act::F=identity) where {F}
(μ, σ²), _ = _get_batch_statistics(
x, nothing, nothing, reduce_dims, Val(false), nothing)
return _affine_normalize_gn(act, x, μ, σ², scale, bias, epsilon)
end
9 changes: 6 additions & 3 deletions test/common_ops/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,20 @@
@test yy_generic atol=atol rtol=rtol
@test eltype(y) == promote_type(Tw, Tx)

@inferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@test @inferred(fused_conv_bias_activation(
activation, weight, x, bias, cdims)) isa Any
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

__f = (σ, w, x, b, cdims) -> sum(
abs2, fused_conv_bias_activation(σ, w, x, b, cdims))

if mode != "amdgpu" && activation !== anonact
@inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
@test @inferred(Zygote.gradient(
__f, activation, weight, x, bias, cdims)) isa Any
else
try
@inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
@test @inferred(Zygote.gradient(
__f, activation, weight, x, bias, cdims)) isa Any
@test true
catch
@test_broken false
Expand Down
4 changes: 2 additions & 2 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
@test y y_generic
@test eltype(y) == promote_type(Tw, Tx)

@inferred fused_dense_bias_activation(activation, w, x, bias)
@test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any
@jet fused_dense_bias_activation(activation, w, x, bias)

__f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b))

if activation !== anonact
@inferred Zygote.gradient(__f, activation, w, x, bias)
@test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any
else
@test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true
end
Expand Down
31 changes: 17 additions & 14 deletions test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

x = randn(rng, T, x_shape) |> aType

@inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon())
@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), Colon())

Expand All @@ -20,7 +20,7 @@
@test rng != rng_

__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, Colon())))
@test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x)
@test @inferred(Zygote.gradient(__f, x)) isa Any

__f = let rng = rng, T = T
x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
Expand All @@ -41,7 +41,7 @@
end

@jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
@inferred dropout(rng, x, T(0.5), Val(true), T(2), Colon())
@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon())

Expand All @@ -68,7 +68,8 @@ end
mask = rand(T, x_shape) |> aType

# Update mask
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())
Expand All @@ -82,7 +83,7 @@ end

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, Colon())))
@test size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x)
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
Expand All @@ -109,7 +110,8 @@ end
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))

# Try using mask if possible (possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
Expand All @@ -124,7 +126,7 @@ end
__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
# Branching based on runtime values
@test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x)
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
Expand All @@ -147,7 +149,8 @@ end
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Try using mask if possible (not possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
Expand All @@ -162,7 +165,7 @@ end
__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
# Branching based on runtime activity
@test_broken size(first(@inferred(Zygote.gradient(__f, x, mask)))) == size(x)
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
Expand All @@ -187,7 +190,8 @@ end
@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
# Testing Mode
@inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())
Expand All @@ -212,7 +216,7 @@ end

x = randn(rng, T, x_shape) |> aType

@inferred alpha_dropout(rng, x, T(0.5), Val(true))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

Expand All @@ -223,7 +227,7 @@ end
@test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2)

__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
@test size(only(@inferred(Zygote.gradient(__f, x)))) == size(x)
@test @inferred(Zygote.gradient(__f, x)) isa Any

__f = let rng = rng
x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
Expand All @@ -243,8 +247,7 @@ end
end

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))

@inferred alpha_dropout(rng, x, T(0.5), Val(false))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

Expand Down
3 changes: 2 additions & 1 deletion test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

@inferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)
@test @inferred(batchnorm(
x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa Any
@jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

@test y isa aType{T, length(sz)}
Expand Down
54 changes: 50 additions & 4 deletions test/normalization/groupnorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,77 @@
return x, scale, bias
end

# Bypassing all optimizations
function __groupnorm_basic(
x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector},
bias::LuxLib.Optional{<:AbstractVector}, groups::Int,
σ::F=identity, epsilon::Real=1.0f-5) where {F, N}
sz = size(x)
x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
x_ = LuxLib._normalization(x_reshaped, nothing, nothing, scale, bias,
LuxLib._get_groupnorm_reduce_dims(x), Val(false), nothing, epsilon, σ)[1]
return reshape(x_, sz)
end

@testset "$mode" for (mode, aType, on_gpu) in MODES
@testset "eltype $T, size $sz, ngroups $groups, $act" for T in (
Float16, Float32, Float64),
sz in ((4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
sz in ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
(4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
groups in (2, 3),
act in (identity, relu, tanh_fast, sigmoid_fast, x -> x^3)

_f = (args...) -> groupnorm(args..., groups, act, epsilon)
_f2 = (args...) -> groupnorm(args..., groups, act, epsilon)

epsilon = T(1e-5)
x, scale, bias = _setup_groupnorm(aType, T, sz)
y = _f(x, scale, bias)

@inferred groupnorm(x, scale, bias, groups, act, epsilon)
y_simple = _f2(x, scale, bias)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3

@test yy_simple atol=atol rtol=rtol

# Check the rrules
∂x, ∂scale, ∂bias = Zygote.gradient(sum _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(
sum _f2, x, scale, bias)
@test ∂x∂x_simple atol=atol rtol=rtol
@test ∂scale∂scale_simple atol=atol rtol=rtol
@test ∂bias∂bias_simple atol=atol rtol=rtol

@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@jet groupnorm(x, scale, bias, groups, act, epsilon)

lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa
Any

@test y isa aType{T, length(sz)}
@test size(y) == sz

fp16 = T == Float16
__f = (args...) -> sum(groupnorm(x, args..., groups, act, epsilon))
skip_fd = act === relu
allow_unstable() do
@eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 skip_finite_differences=$(skip_fd)
@eval @test_gradients $__f $scale $bias gpu_testing=$on_gpu atol=$atol rtol=$rtol soft_fail=$fp16 skip_finite_differences=$(skip_fd)
end

if !on_gpu
∂x, ∂scale, ∂bias = Zygote.gradient(__f, x, scale, bias)

∂x_enz = Enzyme.make_zero(x)
∂scale_enz = Enzyme.make_zero(scale)
∂bias_enz = Enzyme.make_zero(bias)
Enzyme.autodiff(Reverse, __f, Duplicated(x, ∂x_enz),
Duplicated(scale, ∂scale_enz), Duplicated(bias, ∂bias_enz))

@test ∂x∂x_enz rtol=rtol atol=atol
@test ∂scale∂scale_enz rtol=rtol atol=atol
@test ∂bias∂bias_enz rtol=rtol atol=atol
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

y, nt = instancenorm(x, scale, bias, training, act, epsilon)

@inferred instancenorm(x, scale, bias, training, act, epsilon)
@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)

@test y isa aType{T, length(sz)}
Expand Down
2 changes: 1 addition & 1 deletion test/normalization/layernorm_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape)

@inferred layernorm(x, scale, bias, act, dims, epsilon)
@test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any
@jet layernorm(x, scale, bias, act, dims, epsilon)

y = _f(x, scale, bias)
Expand Down

0 comments on commit 80363da

Please sign in to comment.