This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit: Integrate LuxTestUtils
avik-pal committed Mar 31, 2023
1 parent 2e34962 commit 18a81d5
Showing 9 changed files with 58 additions and 162 deletions.
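The change applies one pattern across the whole test suite: the hand-rolled helpers from test/test_utils.jl are swapped for their LuxTestUtils counterparts. A minimal sketch of that pattern (illustrative only — `_f`, `x`, `scale`, and `bias` are placeholders, not code from this commit):

# Illustrative sketch, not part of this diff: a toy stand-in for a LuxLib call.
using Test, Zygote
using LuxTestUtils: @jet, @test_gradients, check_approx

x, scale, bias = randn(Float32, 8, 4), randn(Float32, 8), randn(Float32, 8)
_f = (x, scale, bias) -> scale .* x .+ bias

# Replaces run_JET_tests(_f, x, scale, bias): JET call/optimization analysis.
@jet _f(x, scale, bias)

# Replaces test_gradient_correctness(__f, x, scale, bias; ...): cross-backend
# gradient comparison, with values and flags spliced in via @eval.
__f = sum ∘ _f
fp16 = eltype(x) == Float16
@eval @test_gradients $__f $x $scale $bias atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16

# Replaces the Base.isapprox overloads that test_utils.jl used to define.
@test check_approx(_f(x, scale, bias), _f(x, scale, bias); atol=1.0f-3, rtol=1.0f-3)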
2 changes: 2 additions & 0 deletions test/LocalPreferences.toml
@@ -0,0 +1,2 @@
[LuxTestUtils]
target_modules = ["LuxLib"]
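LuxTestUtils reads this preference: with target_modules = ["LuxLib"], the @jet macro used throughout the diffs below only surfaces JET reports for code owned by LuxLib — the same scoping the removed run_JET_tests helper achieved by passing target_modules=(LuxLib,) to JET.test_call and JET.test_opt (see test/test_utils.jl at the bottom of this commit). A hedged sketch of the call sites that rely on it (placeholder function, not from the diff):

# Illustrative sketch, not part of this diff.
using LuxLib, LuxTestUtils, Random

x  = randn(Float32, 4, 4, 2)
_f = x -> first(dropout(MersenneTwister(0), x, 0.5f0, Val(true); dims=:))

@jet _f(x)                   # static call + optimization analysis, scoped to LuxLib
@jet _f(x) opt_broken=true   # mark only the optimization analysis as known-broken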
4 changes: 1 addition & 3 deletions test/Project.toml
@@ -1,14 +1,12 @@
[deps]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
12 changes: 6 additions & 6 deletions test/api/batchnorm.jl
@@ -34,7 +34,8 @@ end
y, nt = batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9))

@inferred batchnorm(x, scale, bias, rm, rv; epsilon, training, momentum=T(0.9))
run_JET_tests(_f, x, scale, bias, rm, rv)

@jet _f(x, scale, bias, rm, rv)

@test y isa aType{T, length(sz)}
@test size(y) == sz
@@ -45,17 +46,16 @@
end

if __istraining(training)
fp16 = T == Float16
if affine
__f = (args...) -> sum(first(batchnorm(args..., rm, rv; epsilon, training,
momentum=T(0.9))))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2
else
__f = (args...) -> sum(first(batchnorm(args..., scale, bias, rm, rv;
epsilon, training, momentum=T(0.9))))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)

@eval @test_gradients $__f $x gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2
end
end
end
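The @eval @test_gradients $__f $x ... lines above are the recurring idiom of this commit: since @test_gradients is a macro, @eval with $-interpolation splices in the concrete closure, inputs, and flags; soft_fail downgrades Float16 mismatches to broken tests, and gpu_testing (like the removed helper's keyword of the same name) skips the comparisons that cannot run on GPU arrays. The idiom in isolation, with toy data (illustrative only):

# Illustrative sketch, not part of this diff.
using Test, Zygote
using LuxTestUtils: @test_gradients

x, scale, bias = randn(Float32, 8, 4), randn(Float32, 8), randn(Float32, 8)
__f = (args...) -> sum(abs2, first(args) .* scale .+ bias)   # scalar output, like sum(first(batchnorm(...)))

fp16   = eltype(x) == Float16   # false here; true only on Float16 runs
on_gpu = false                  # the test suite sets this per backend

@eval @test_gradients $__f $x gpu_testing=$on_gpu soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2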
35 changes: 20 additions & 15 deletions test/api/dropout.jl
@@ -22,9 +22,10 @@ rng = MersenneTwister(0)
@test rng != rng_

__f = x -> sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

@@ -58,9 +59,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

# Try using mask if possible (possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())
@@ -76,9 +78,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

@@ -96,9 +99,10 @@ end end

__f = x -> sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false);
dims=Colon())))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

# Testing Mode
@inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon())
@@ -129,9 +133,10 @@ end end
@test_broken isapprox(std(y), std(x); atol=1.0f-2, rtol=1.0f-2)

__f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
run_JET_tests(__f, x)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

25 changes: 12 additions & 13 deletions test/api/groupnorm.jl
@@ -45,7 +45,7 @@ end
bias)

@inferred groupnorm(x, scale, bias; groups, epsilon)
run_JET_tests(_f, x, scale, bias; opt_broken=true)
@jet _f(x, scale, bias) opt_broken=true
@test y isa aType{T, 4}
@test size(y) == sz

@@ -60,14 +60,14 @@

# The KA implementation reorders operations manually for maximal
# performance. Hence equality cannot be guaranteed.
@test isapprox(y, y_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3)
@test isapprox(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3)

test_gradient_correctness((args...) -> sum(_f(args...)), x, scale, bias;
gpu_testing=on_gpu, atol=1.0f-3, rtol=1.0f-3,
soft_fail=T == Float16)
@test check_approx(y, y_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_x, gs_x_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_scale, gs_scale_; atol=1.0f-3, rtol=1.0f-3)
@test check_approx(gs_bias, gs_bias_; atol=1.0f-3, rtol=1.0f-3)

fp16 = T == Float16
__f = sum ∘ _f
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-3 rtol=1.0f-3 soft_fail=$fp16
end
end end
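check_approx is LuxTestUtils' stand-in for the Base.isapprox overloads this commit deletes from test/test_utils.jl: it compares numbers and arrays like isapprox and — assuming it mirrors the removed overloads — also tolerates nothing entries and Tuple/NamedTuple containers as returned by gradient calls. A brief sketch (illustrative values only):

# Illustrative sketch, not part of this diff; container behavior is assumed to
# mirror the isapprox overloads removed from test/test_utils.jl.
using LuxTestUtils: check_approx

check_approx([1.0, 2.0], [1.0, 2.0 + 1e-9]; atol=1e-6)                      # plain arrays
check_approx((a=[1.0], b=nothing), (a=[1.0 + 1e-9], b=nothing); atol=1e-6)  # NamedTuples with nothing fields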

@@ -85,17 +85,16 @@ end end

@inferred groupnorm(x, scale, bias, rm, rv; groups, epsilon, training,
momentum=T(0.9))
run_JET_tests(_f, x, scale, bias, rm, rv; opt_broken=true)
@jet _f(x, scale, bias, rm, rv) opt_broken=true

@test y isa aType{T, 4}
@test size(y) == sz
@test size(nt.running_mean) == (groups,)
@test size(nt.running_var) == (groups,)

fp16 = T == Float16
__f = (args...) -> sum(first(groupnorm(args..., rm, rv; groups, epsilon, training,
momentum=T(0.9))))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias gpu_testing=$on_gpu atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16
end
end end
18 changes: 6 additions & 12 deletions test/api/instancenorm.jl
@@ -26,30 +26,24 @@ end
y, nt = instancenorm(x, scale, bias; epsilon, training)

@inferred instancenorm(x, scale, bias; epsilon, training)
run_JET_tests(_f, x, scale, bias)
@jet _f(x, scale, bias)
@test y isa aType{T, length(sz)}
@test size(y) == sz

_target_std = ones(ntuple(_ -> 1, length(sz) - 2)..., size(x)[(end - 1):end]...)
if length(sz) != 3
@test isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std; atol=0.2)
else
@test_broken isapprox(std(Array(y); dims=1:(length(sz) - 2)), _target_std;
atol=0.2)
end
@eval @test check_approx(std(Array($y); dims=1:($(length(sz) - 2))), $_target_std;
atol=0.2) broken=$(length(sz) == 3)
@test std(y; dims=1:(length(sz) - 2)) != std(x; dims=1:(length(sz) - 2))

if __istraining(training)
fp16 = T == Float16
if affine
__f = (args...) -> sum(first(instancenorm(args...; epsilon, training)))
test_gradient_correctness(__f, x, scale, bias; gpu_testing=on_gpu,
skip_fdm=T == Float16, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
@eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
else
__f = (args...) -> sum(first(instancenorm(args..., scale, bias; epsilon,
training)))
test_gradient_correctness(__f, x; gpu_testing=on_gpu, skip_fdm=T == Float16,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)
@eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
end
end
end
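The @eval @test check_approx(...) broken=$(length(sz) == 3) line above folds the old branch between @test and @test_broken into a single assertion whose broken flag is computed from the input shape and spliced in. The same idiom in isolation (toy values, illustrative only):

# Illustrative sketch, not part of this diff: one assertion, with `broken`
# decided at splice time instead of duplicating the test across branches.
using Statistics, Test

sz = (4, 4, 3, 2)
y  = randn(Float32, sz)
@eval @test isapprox(std($y), 1.0; atol=0.2) broken=$(length(sz) == 3)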
17 changes: 8 additions & 9 deletions test/api/layernorm.jl
@@ -26,26 +26,25 @@ end
x, scale, bias = _setup_layernorm(aType, T, x_shape, affine_shape)

@inferred _f(x, scale, bias)
run_JET_tests(_f, x, scale, bias)
@jet _f(x, scale, bias)

y = _f(x, scale, bias)

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape

if affine_shape === nothing
@test isapprox(mean(y; dims), 0; atol=1e-3, rtol=1e-3)
@test isapprox(std(y; dims), 1; atol=1e-1, rtol=1e-1)
@test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3)
@test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
end

fp16 = T == Float16
if affine_shape === nothing
test_gradient_correctness(x -> sum(_f(x, nothing, nothing)), x;
skip_fdm=T == Float16, gpu_testing=on_gpu,
atol=1.0f-2, rtol=1.0f-2, soft_fail=T == Float16)
__f = x -> sum(_f(x, nothing, nothing))
@eval @test_gradients $__f $x soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
else
test_gradient_correctness(sum ∘ _f, x, scale, bias; skip_fdm=T == Float16,
gpu_testing=on_gpu, atol=1.0f-2, rtol=1.0f-2,
soft_fail=T == Float16)
__f = sum ∘ _f
@eval @test_gradients $__f $x $scale $bias soft_fail=$fp16 atol=1.0f-2 rtol=1.0f-2 gpu_testing=$on_gpu
end
end
end end
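Because layernorm returns the output array directly (there is no state to drop with first), the affine branch simply composes: sum ∘ _f is a scalar-valued callable that @test_gradients can differentiate without an extra closure. A tiny sketch of the composition (placeholder function, not from the diff):

# Illustrative sketch, not part of this diff.
_f  = (x, scale, bias) -> scale .* x .+ bias   # stand-in for the layernorm closure above
__f = sum ∘ _f

x, scale, bias = randn(Float32, 4, 3), randn(Float32, 4), randn(Float32, 4)
__f(x, scale, bias)   # a single Float32, ready for @test_gradients / Zygote.gradient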
2 changes: 1 addition & 1 deletion test/ext/LuxLibForwardDiffExt.jl
@@ -13,5 +13,5 @@ rng = MersenneTwister(0)
x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1]
x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1])

@test isapprox(x_dropout, x_dual_dropout)
@test check_approx(x_dropout, x_dual_dropout)
end end
105 changes: 2 additions & 103 deletions test/test_utils.jl
@@ -1,6 +1,6 @@
using FiniteDifferences, LuxLib, Test
using LuxLib, LuxTestUtils, Test, Zygote
using LuxCUDA # CUDA Support
using ReverseDiff, Tracker, Zygote # AD Packages
using LuxTestUtils: @jet, @test_gradients, check_approx

const GROUP = get(ENV, "GROUP", "All")

@@ -23,105 +23,4 @@ const MODES = begin
end
end

try
using JET
catch
@warn "JET not not precompiling. All JET tests will be skipped." maxlog=1
global test_call(args...; kwargs...) = nothing
global test_opt(args...; kwargs...) = nothing
end

function Base.isapprox(x, y; kwargs...)
@warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead."
return x == y
end

function Base.isapprox(x::Tuple, y::Tuple; kwargs...)
return all(isapprox.(x, y; kwargs...))
end

function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields};
kwargs...) where {fields}
checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...)
checkapprox(t::Tuple{Nothing, Nothing}) = true
return all(checkapprox, zip(values(nt1), values(nt2)))
end

function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T}
checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...)
checkapprox(t::Tuple{Nothing, Nothing}) = true
return all(checkapprox, zip(t1, t2))
end

Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0
Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0
Base.isapprox(v::Tuple, ::Nothing; kwargs...) = length(v) == 0
Base.isapprox(::Nothing, v::Tuple; kwargs...) = length(v) == 0
Base.isapprox(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0
Base.isapprox(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0

# JET Tests
function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...)
@static if VERSION >= v"1.7"
test_call(f, typeof.(args); broken=call_broken, target_modules=(LuxLib,))
test_opt(f, typeof.(args); broken=opt_broken, target_modules=(LuxLib,))
end
end

__istraining(::Val{training}) where {training} = training

# Test the gradients across AD Frameworks and FiniteDifferences
# TODO: Implement it as a macro so that we get correct line numbers for `@test` failures.
function test_gradient_correctness(f::Function, args...; gpu_testing::Bool=false,
skip_fdm::Bool=false, skip_fdm_override::Bool=false,
soft_fail::Bool=false, kwargs...)
gs_ad_zygote = Zygote.gradient(f, args...)
gs_ad_tracker = Tracker.gradient(f, args...)
gs_ad_reversediff = gpu_testing ? nothing : ReverseDiff.gradient(f, args)

if !skip_fdm_override
arr_len = length.(args)
if any(x -> x >= 25, arr_len) || sum(arr_len) >= 100
@warn "Skipping FiniteDifferences test for large arrays: $(arr_len)."
skip_fdm = true
end
end

gs_fdm = gpu_testing || skip_fdm ? nothing :
FiniteDifferences.grad(FiniteDifferences.central_fdm(8, 1), f, args...)
for idx in 1:length(gs_ad_zygote)
_c1 = isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...)
if soft_fail && !_c1
@test_broken isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx];
kwargs...)
else
@test isapprox(Tracker.data(gs_ad_tracker[idx]), gs_ad_zygote[idx]; kwargs...)
end

if !gpu_testing
if !skip_fdm
_c2 = isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
if soft_fail && !_c2
@test_broken isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
else
@test isapprox(gs_ad_zygote[idx], gs_fdm[idx]; kwargs...)
end
end

_c3 = isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx];
kwargs...)
if soft_fail && !_c3
@test_broken isapprox(ReverseDiff.value(gs_ad_reversediff[idx]),
gs_ad_zygote[idx]; kwargs...)
else
@test isapprox(ReverseDiff.value(gs_ad_reversediff[idx]), gs_ad_zygote[idx];
kwargs...)
end
end
end
return
end
