Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #36 from LuxDL/ap/fmode
Browse files Browse the repository at this point in the history
Fix ForwardMode tests
  • Loading branch information
avik-pal authored Sep 19, 2023
2 parents ba69eef + ef7b7c0 commit 640cf1c
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.4"
version = "0.3.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 4 additions & 0 deletions ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ for op in [:conv, :depthwiseconv]
end
end

# Strip the ForwardDiff partials from an array of `Dual` numbers, returning
# an array of the underlying primal values.
LuxLib._drop_forwarddiff_partials(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

end
1 change: 0 additions & 1 deletion ext/LuxLibLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64
LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng)

# api/batchnorm.jl

# Input array types accepted by the cuDNN batchnorm path: 2-D, 4-D, or 5-D
# CuArrays with 32- or 64-bit float eltype (cuDNN supports only these ranks).
const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4},
    CuArray{<:FP_32_64, 5}}
# Scale/bias/running-stat parameter type for cuDNN batchnorm; `Nothing` means
# the parameter is absent (e.g. no affine transform or no tracked statistics).
const BNParamType = Union{Nothing, CuVector{<:FP_32_64}}
Expand Down
5 changes: 3 additions & 2 deletions src/api/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ function batchnorm(x::AA{<:Real, N}, scale::NOrAVR, bias::NOrAVR, running_mean::
running_var::NOrAVR; momentum::Real, training::Val, epsilon::Real) where {N}
x_, xm, xv = _normalization(x, running_mean, running_var, scale, bias,
_get_batchnorm_reduce_dims(x), training, momentum, epsilon)

return x_, (; running_mean=xm, running_var=xv)
stats = (; running_mean=_drop_forwarddiff_partials(xm),
running_var=_drop_forwarddiff_partials(xv))
return (x_, stats)
end

@generated function _get_batchnorm_reduce_dims(::AA{T, N}) where {T, N}
Expand Down
10 changes: 10 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,13 @@ function Base.showerror(io::IO, ex::OutdatedNNlibDependencyException)
print(io, "OutdatedNNlibDependencyException: ")
return println(io, "$msg")
end

# Dropping ForwardDiff partials: recursively strip dual-number partials from
# (possibly nested) containers, leaving plain primal values. The methods below
# are the generic fallbacks — identity for ordinary arrays and `nothing`, and
# structural recursion for tuples/named tuples. The ForwardDiff package
# extension adds the method that actually extracts values from `Dual` arrays.
function _drop_forwarddiff_partials end

_drop_forwarddiff_partials(x::AbstractArray) = x
_drop_forwarddiff_partials(::Nothing) = nothing
_drop_forwarddiff_partials(t::Tuple) = map(_drop_forwarddiff_partials, t)
function _drop_forwarddiff_partials(nt::NamedTuple{names}) where {names}
    return NamedTuple{names}(_drop_forwarddiff_partials.(values(nt)))
end
17 changes: 9 additions & 8 deletions test/jvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ function jvp_zygote(f, x, u)
return Jₓ * vec(u)
end

function test_jvp_computation(f, x, u)
function test_jvp_computation(f, x, u, on_gpu)
jvp₁ = jvp_forwarddiff(f, x, u)
if !(x isa ComponentArray)
if !(x isa ComponentArray && on_gpu)
# ComponentArray + ForwardDiff on GPU don't play nice
jvp₂ = jvp_forwarddiff_concrete(f, x, u)
@test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
end

jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
end
end

@testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES
Expand All @@ -66,9 +66,10 @@ end
uw = randn(Float32, size(w)...) |> aType
u = randn(Float32, length(x) + length(w)) |> aType

test_jvp_computation(x -> op(x, w; flipped), x, ux)
test_jvp_computation(w -> op(x, w; flipped), w, uw)
test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u)
test_jvp_computation(x -> op(x, w; flipped), x, ux, on_gpu)
test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu)
test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u,
on_gpu)
end
end
end

2 comments on commit 640cf1c

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91720

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.5 -m "<description of version>" 640cf1c7137a3a09a2a886517b6d88d780071187
git push origin v0.3.5

Please sign in to comment.