
Commit

test: more test fixes
avik-pal committed Jul 21, 2024
1 parent 3d60e94 commit e52d6b1
Showing 5 changed files with 41 additions and 34 deletions.
4 changes: 2 additions & 2 deletions ext/LuxLibTrackerAMDGPUExt.jl
@@ -1,8 +1,8 @@
 module LuxLibTrackerAMDGPUExt

 using AMDGPU: AMDGPU
-using LuxLib: LuxLib, Optional
-using NNlib: NNlib, ConvDims, PoolDims
+using LuxLib: LuxLib
+using NNlib: NNlib, PoolDims
 using Tracker: Tracker, TrackedArray

 const ROCTrackedArray{T, N} = TrackedArray{T, N, <:AMDGPU.ROCArray{T, N}}
10 changes: 4 additions & 6 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -1,6 +1,6 @@
 module LuxLibcuDNNExt

-using LuxLib: LuxLib, Optional
+using LuxLib: LuxLib, Optional, ∂∅
 using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray
 using ChainRulesCore: ChainRulesCore
 using cuDNN: cuDNN, cudnnBatchNormalizationBackward,
@@ -44,11 +44,9 @@ function CRC.rrule(::typeof(LuxLib.batchnorm_cudnn), running_mean, running_var,
     proj_b = CRC.ProjectTo(bias)
     proj_x = CRC.ProjectTo(x)
     ∇batchnorm_cudnn_internal = @closure Δ -> begin
-        ∂y = CRC.unthunk(first(Δ))
-        ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(
-            scale, bias, x, ∂y, running_mean, running_var, xmean, xivar; ϵ=epsilon)
-        return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), proj_g(∂g), proj_b(∂b),
-            proj_x(∂x), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
+        ∂g, ∂b, ∂x = LuxLib.∇batchnorm_cudnn(scale, bias, x, CRC.unthunk(first(Δ)),
+            running_mean, running_var, xmean, xivar; ϵ=epsilon)
+        return ∂∅, ∂∅, ∂∅, proj_g(∂g), proj_b(∂b), proj_x(∂x), ∂∅, ∂∅, ∂∅
     end
     return (y, xmean, xivar), ∇batchnorm_cudnn_internal
 end
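
Note on the shortened return value: `∂∅` is now imported from LuxLib alongside `Optional` and is presumably an internal alias for `ChainRulesCore.NoTangent()`. A minimal sketch under that assumption, showing that the compact tuple is equivalent to the spelled-out one it replaces:

using ChainRulesCore: NoTangent

const ∂∅ = NoTangent()  # assumed alias, mirroring the `using LuxLib: ..., ∂∅` import above

# The rrule's non-differentiable slots can be written either way:
long_form = (NoTangent(), NoTangent(), NoTangent())
short_form = (∂∅, ∂∅, ∂∅)
@assert long_form == short_form  # NoTangent() is a singleton, so the tuples are identical
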
7 changes: 4 additions & 3 deletions test/common_ops/dropout_tests.jl
@@ -96,7 +96,7 @@ end
 end

 # Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651
-if !on_gpu && !(Sys.iswindows() && T == Float16)
+if !on_gpu && !Sys.iswindows()
     ∂x_zyg = only(Zygote.gradient(__f, x))
     ∂x_enz = zero.(x)
     Enzyme.autodiff(
@@ -138,7 +138,7 @@ end
         Float16)
 end

-if !on_gpu
+if !on_gpu && !Sys.iswindows()
     ∂x_zyg = only(Zygote.gradient(__f, x))
     ∂x_enz = Enzyme.gradient(Reverse, __f, x)
     @test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3
@@ -177,7 +177,8 @@ end
         Float16)
 end

-if !on_gpu
+# Upstream bug: https://github.com/EnzymeAD/Enzyme.jl/issues/1651
+if !on_gpu && !Sys.iswindows()
     ∂x_zyg = only(Zygote.gradient(__f, x))
     ∂x_enz = zero.(x)
     Enzyme.autodiff(
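Note: each of the guarded blocks above compares a Zygote gradient against an Enzyme reverse-mode gradient. A self-contained sketch of that comparison, using an illustrative function and input in place of the test's own `__f` and `x`:

using Enzyme, Test, Zygote

__f = x -> sum(abs2, x)        # stand-in for the dropout-based closure in the tests
x = randn(Float32, 4, 4)

∂x_zyg = only(Zygote.gradient(__f, x))
∂x_enz = Enzyme.gradient(Reverse, __f, x)
∂x_enz isa Tuple && (∂x_enz = only(∂x_enz))  # newer Enzyme versions return a tuple
@test ∂x_zyg≈∂x_enz atol=1.0f-3 rtol=1.0f-3
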
29 changes: 16 additions & 13 deletions test/normalization/batchnorm_tests.jl
@@ -31,7 +31,8 @@
 anonact = x -> x^3

 @testset "$mode" for (mode, aType, on_gpu) in MODES
-    @testset "eltype $T, size $sz, $act" for T in (Float16, Float32, Float64),
+    @testset "eltype $T, size $sz, $act $affine $track_stats" for T in (
+            Float16, Float32, Float64),
         sz in ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
         training in (Val(true), Val(false)),
         affine in (true, false),
@@ -56,18 +57,20 @@
         end

         # Check the rrules
-        _f = (args...) -> sum(first(batchnorm(
-            args..., rm, rv, training, act, T(0.9), epsilon)))
-        _f2 = (args...) -> sum(first(__batchnorm_basic(
-            args..., rm, rv, training, act, T(0.9), epsilon)))
-
-        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(
-            sum ∘ _f2, x, scale, bias)
-        @test ∂x≈∂x_simple atol=atol rtol=rtol
-        if affine
-            @test ∂scale≈∂scale_simple atol=atol rtol=rtol
-            @test ∂bias≈∂bias_simple atol=atol rtol=rtol
+        if __istraining(training)
+            _f = (args...) -> sum(first(batchnorm(
+                args..., rm, rv, training, act, T(0.9), epsilon)))
+            _f2 = (args...) -> sum(first(__batchnorm_basic(
+                args..., rm, rv, training, act, T(0.9), epsilon)))
+
+            ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
+            ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(
+                sum ∘ _f2, x, scale, bias)
+            @test ∂x≈∂x_simple atol=atol rtol=rtol
+            if affine
+                @test ∂scale≈∂scale_simple atol=atol rtol=rtol
+                @test ∂bias≈∂bias_simple atol=atol rtol=rtol
+            end
         end

         @test @inferred(batchnorm(
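Note: the new `if __istraining(training)` guard assumes a helper that lowers the `Val`-typed `training` flag to a `Bool`. A hypothetical definition, consistent with how the flag is constructed in this test (`Val(true)` / `Val(false)`) but not part of this commit:

# Hypothetical helper mirroring the one referenced by the test.
__istraining(::Val{true}) = true
__istraining(::Val{false}) = false

@assert __istraining(Val(true))
@assert !__istraining(Val(false))
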
25 changes: 15 additions & 10 deletions test/runtests.jl
@@ -10,13 +10,7 @@ const EXTRA_PKGS = String[]

 if !isempty(EXTRA_PKGS)
     @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
-    for pkg in EXTRA_PKGS
-        if pkg == "AMDGPU"
-            Pkg.add(; name=pkg, rev="master") # FIXME: remove before merge
-        else
-            Pkg.add(; name=pkg)
-        end
-    end
+    Pkg.add(EXTRA_PKGS)
     Pkg.update()
     Base.retry_load_extensions()
     Pkg.instantiate()
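
Note: the collapsed call relies on `Pkg.add` accepting a vector of package names directly, which is what makes the per-package loop unnecessary. A small sketch with placeholder package names standing in for EXTRA_PKGS:

using Pkg

extra_pkgs = ["Example", "JSON"]  # placeholder names; the real list holds entries like "CUDA"

Pkg.add(extra_pkgs)               # one call for the whole list, as in the new code

# equivalent to the removed loop (minus the AMDGPU-specific rev="master" workaround):
for pkg in extra_pkgs
    Pkg.add(; name=pkg)
end
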
@@ -26,6 +20,17 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all")
 @info "Running tests for group: $LUXLIB_TEST_GROUP"
 const RETESTITEMS_NWORKERS = parse(Int, get(ENV, "RETESTITEMS_NWORKERS", "0"))

-ReTestItems.runtests(
-    @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]),
-    nworkers=ifelse(BACKEND_GROUP ∈ ("cuda", "amdgpu"), 0, RETESTITEMS_NWORKERS))
+if BACKEND_GROUP ∈ ("cuda", "amdgpu")
+    # Upstream bug: https://github.com/JuliaTesting/ReTestItems.jl/issues/164
+    if LUXLIB_TEST_GROUP == "all"
+        ReTestItems.runtests(@__DIR__; name=r"^(?!.*Normalization$).*")
+        ReTestItems.runtests(@__DIR__; name=r".*Normalization$", nworkers=0)
+    elseif LUXLIB_TEST_GROUP == "normalization"
+        ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)], nworkers=0)
+    else
+        ReTestItems.runtests(@__DIR__; tags=[Symbol(LUXLIB_TEST_GROUP)])
+    end
+else
+    ReTestItems.runtests(
+        @__DIR__; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]))
+end
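
Note: the GPU branch works around the linked ReTestItems.jl issue by running every test item whose name ends in "Normalization" in-process (`nworkers=0`) and the remaining items with worker processes, split via the `name` regex filter. A quick check of how the two patterns partition item names (the names below are illustrative, not the actual test-item names):

item_names = ["Dropout", "Batch Normalization", "Group Normalization", "Fused Dense"]

norm_items = filter(n -> occursin(r".*Normalization$", n), item_names)       # run with nworkers=0
rest_items = filter(n -> occursin(r"^(?!.*Normalization$).*", n), item_names) # run with workers

@assert norm_items == ["Batch Normalization", "Group Normalization"]
@assert rest_items == ["Dropout", "Fused Dense"]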
