This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Integrate LuxTestUtils #5

Merged 1 commit on Mar 31, 2023
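This PR replaces the hand-rolled test helpers in `test/test_utils.jl` (`run_JET_tests`, `test_gradient_correctness`, and the pirated `Base.isapprox` methods) with the `@jet`, `@test_gradients`, and `check_approx` utilities from LuxTestUtils.jl. The testsets splice runtime values such as `fp16` and `on_gpu` into the macro call via `@eval ... $__f $x ...`. A minimal sketch of the new idiom, outside any of the actual testsets; `f` and `x` below are made-up stand-ins, not objects from this diff:

```julia
using LuxTestUtils, Test
using LuxTestUtils: @jet, @test_gradients

f = x -> sum(abs2, x)     # illustrative closure
x = randn(Float32, 4, 3)  # illustrative input

# JET call/optimization checks; reports are filtered to the modules named in
# the `target_modules` preference (see test/LocalPreferences.toml below).
@jet f(x)

# Gradient cross-checks across the AD backends LuxTestUtils supports; the
# keyword spelling mirrors the calls in the diff.
@eval @test_gradients $f $x atol=1.0f-3 rtol=1.0f-3
```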
2 changes: 2 additions & 0 deletions test/LocalPreferences.toml
@@ -0,0 +1,2 @@
[LuxTestUtils]
target_modules = ["LuxLib"]
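The `target_modules` preference scopes the JET reports that `@jet` collects to `LuxLib`, playing the role of the `target_modules=(LuxLib,)` keyword in the `run_JET_tests` helper removed further down. If committing the TOML file is not desired, the same preference can presumably be written programmatically with Preferences.jl (a sketch, not part of this PR):

```julia
using Preferences, LuxTestUtils

# Creates or updates LocalPreferences.toml for the active project.
set_preferences!(LuxTestUtils, "target_modules" => ["LuxLib"]; force=true)
```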
4 changes: 1 addition & 3 deletions test/Project.toml
@@ -1,14 +1,12 @@
[deps]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
12 changes: 6 additions & 6 deletions test/api/batchnorm.jl
@@ -34,7 +34,8 @@ end
y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9))

@inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9))
run_JET_tests(_f, x, scale, bias, rm, rv)

@jet _f(x, scale, bias, rm, rv)

@test y isa aType{T, length(sz)}
@test size(y) == sz
@@ -45,17 +46,16 @@ end
end

if __istraining(training)
fp16 = T == Float16
if affine
__f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training,
momentum=T(0.9))))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2
else
__f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv;
epsilon, training, momentum=T(0.9))))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)

@eval @test_gradients $__f $x gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2
end
end
end
35 changes: 20 additions & 15 deletions test/api/dropout.jl
@@ -22,9 +22,10 @@ rng = MersenneTwister(0)
@test rng != rng_

__f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

@@ -58,9 +59,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

# Try using mask if possible (possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())
@@ -76,9 +78,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

@@ -96,9 +99,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

# Testing Mode
@inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon())
@@ -129,9 +133,10 @@ end end
@test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2)

__f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

25 changes: 12 additions & 13 deletions test/api/groupnorm.jl
@@ -45,7 +45,7 @@ end
bias)

@inferred groupnorm(x, scale, bias; groups, epsilon)
run_JET_tests(_f, x, scale, bias; opt_broken=true)
@jet _f(x, scale, bias) opt_broken=true
@test y isa aType{T, 4}
@test size(y) == sz

@@ -60,14 +60,14 @@ end

# The KA implementation reorders operations manually for maximal
# performance. Hence equality cannot be guaranteed.
@test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3)

test_gradient_correctness((args...) -> sum(_f(args...)), x, scale, bias;
gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3,
soft_fail=T == Float16)
@test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3)

fp16 = T == Float16
__f = sum ∘ _f
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16
end
end end
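The comment in this hunk about the KA (KernelAbstractions) kernel is why the checks use approximate rather than exact comparison: reordering a floating-point computation changes the rounded result, so the manually reordered kernel and the generic fallback agree only up to a tolerance. A one-line illustration of that non-associativity:

```julia
julia> (0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3)
false
```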

@@ -85,17 +85,16 @@ end end

@inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training,
momentum=T(0.9))
run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true)
@jet _f(x, scale, bias, rm, rv) opt_broken=true

@test y isa aType{T, 4}
@test size(y) == sz
@test size(nt.running_mean) == (groups,)
@test size(nt.running_var) == (groups,)

fp16 = T == Float16
__f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training,
momentum=T(0.9))))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16
end
end end
18 changes: 6 additions & 12 deletions test/api/instancenorm.jl
@@ -26,30 +26,24 @@ end
y, nt = instancenorm(x, scale, bias; epsilon, training)

@inferred instancenorm(x, scale, bias; epsilon, training)
run_JET_tests(_f, x, scale, bias)
@jet _f(x, scale, bias)
@test y isa aType{T, length(sz)}
@test size(y) == sz

_target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...)
if length(sz) != 3
@test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2)
else
@test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std;
atol=0.2)
end
@eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), $_target_std;
atol=0.2, rtol=0.2)
@test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2))

if __istraining(training)
fp16 = T == Float16
if affine
__f = (args...) -> sum(first(instancenorm(args...; epsilon, training)))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
else
__f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon,
training)))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)
@eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
end
end
end
17 changes: 8 additions & 9 deletions test/api/layernorm.jl
@@ -26,26 +26,25 @@ end
x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape)

@inferred _f(x, scale, bias)
run_JET_tests(_f, x, scale, bias)
@jet _f(x, scale, bias)

y = _f(x, scale, bias)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape

if affine_shape === nothing
@test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3)
@test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1)
@test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3)
@test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
end

fp16 = T == Float16
if affine_shape === nothing
test_gradient_correctness(x -> sum(_f(x, nothing, nothing)), x;
skip_fdm=T == Float16, gpu_testing=on_gpu,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)
__f = x -> sum(_f(x, nothing, nothing))
@eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
else
test_gradient_correctness(sum ∘ _f, x, scale, bias; skip_fdm=T == Float16,
gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
__f = sum ∘ _f
@eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
end
end
end end
2 changes: 1 addition & 1 deletion test/ext/LuxLibForwardDiffExt.jl
@@ -13,5 +13,5 @@ rng = MersenneTwister(0)
x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1]
x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1])

@test isapprox(x_dropout, x_dual_dropout)
@test check_approx(x_dropout, x_dual_dropout)
end end
105 changes: 2 additions & 103 deletions test/test_utils.jl
@@ -1,6 +1,6 @@
using FiniteDifferences, LuxLib, Test
using LuxLib, LuxTestUtils, Test, Zygote
using LuxCUDA # CUDA Support
using ReverseDiff, Tracker, Zygote # AD Packages
using LuxTestUtils: @jet, @test_gradients, check_approx

const GROUP = get(ENV, "GROUP", "All")

@@ -23,105 +23,4 @@ const MODES = begin
end
end

try
using JET
catch
@warn "JET not not precompiling. All JET tests will be skipped." maxlog=1
global test_call(args...; kwargs...) = nothing
global test_opt(args...; kwargs...) = nothing
end

function Base.isapprox(x, y; kwargs...)
@warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead."
return x == y
end

function Base.isapprox(x::Tuple, y::Tuple; kwargs...)
return all(isapprox.(x, y; kwargs...))
end

function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields};
kwargs...) where {fields}
checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...)
checkapprox(t::Tuple{Nothing, Nothing}) = true
return all(checkapprox, zip(values(nt1), values(nt2)))
end

function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T}
checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...)
checkapprox(t::Tuple{Nothing, Nothing}) = true
return all(checkapprox, zip(t1, t2))
end

Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0
Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0
Base.isapprox(v::Tuple, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(::Nothing, v::Tuple; kwargs...) = length(v) == 0
Base.isapprox(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0

# JET Tests
function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...)
@static if VERSION >= v"1.7"
test_call(f, typeof.(args); broken=call_broken, target_modules=(LuxLib,))
test_opt(f, typeof.(args); broken=opt_broken, target_modules=(LuxLib,))
end
end

__istraining(::Val{training}) where {training} = training

# Test the gradients across AD Frameworks and FiniteDifferences
# TODO: Implement it as a macro so that we get correct line numbers for `@test` failures.
function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false,
skip_fdm::Bool=false, skip_fdm_override::Bool=false,
soft_fail::Bool=false, kwargs...)
gs_ad_zygote = Zygote.gradient(f, args...)
gs_ad_tracker = Tracker.gradient(f, args...)
gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args)

if !skip_fdm_override
arr_len = length.(args)
if any(x -> x >= 25, arr_len) || sum(arr_len) >= 100
@warn "Skipping FiniteDifferences test for large arrays: $(arr_len)."
skip_fdm = true
end
end

gs_fdm = gpu_testing || skip_fdm ? nothing :
FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...)
for idx in 1:length(gs_ad_zygote)
_c1 = isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...)
if soft_fail && !_c1
@test_broken isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx];
kwargs...)
else
@test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...)
end

if !gpu_testing
if !skip_fdm
_c2 = isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
if soft_fail && !_c2
@test_broken isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
else
@test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
end
end

_c3 = isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx];
kwargs...)
if soft_fail && !_c3
@test_broken isapprox(ReverseDiff.value(gs_ad_reversediff[idx]),
gs_ad_zygote[idx]; kwargs...)
else
@test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx];
kwargs...)
end
end
end
return
end
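
For reference, the deleted block above did three jobs that LuxTestUtils now covers: `run_JET_tests` wrapped JET's `test_call`/`test_opt` (now `@jet`), `test_gradient_correctness` cross-checked Zygote, Tracker, ReverseDiff, and FiniteDifferences gradients (now `@test_gradients`), and the `Base.isapprox` overloads compared tuples, NamedTuples, and `nothing` (now `check_approx`, without type piracy). A sketch of the comparisons `check_approx` is assumed to handle in place of those overloads; the exact method coverage is LuxTestUtils' to define:

```julia
using LuxTestUtils: check_approx

# Assumed behaviour, mirroring the deleted overloads above.
check_approx((1.0f0, 2.0f0), (1.0f0, 2.0f0 + 1.0f-6); atol=1.0f-3)  # tuples, element-wise
check_approx((; a=1.0f0), (; a=1.0f0 + 1.0f-6); atol=1.0f-3)        # NamedTuples, field-wise
check_approx(nothing, Float32[])                                    # `nothing` vs. an empty array
```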