This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Merge pull request #3 from LuxDL/ap/luxcuda
Start using LuxCUDA
avik-pal authored Mar 27, 2023
2 parents c4cdcc7 + 27abc9e commit c1bb495
Showing 19 changed files with 415 additions and 783 deletions.
12 changes: 4 additions & 8 deletions Project.toml
@@ -1,17 +1,15 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.12"
version = "0.1.13"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -29,13 +27,11 @@ LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerExt = "Tracker"

[compat]
CUDA = "3, 4"
CUDAKernels = "0.3, 0.4"
ChainRulesCore = "1"
ForwardDiff = "0.10"
KernelAbstractions = "0.7, 0.8"
KernelAbstractions = "0.9"
LuxCUDA = "0.1"
NNlib = "0.8"
NNlibCUDA = "0.2"
Requires = "1"
ReverseDiff = "1"
Tracker = "0.2"
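Note on the dependency change: the three direct GPU dependencies (CUDA, CUDAKernels, NNlibCUDA) are replaced by the single LuxCUDA meta-package, and KernelAbstractions moves to the 0.9 series. A minimal sketch of what loading GPU support looks like after this change, assuming LuxCUDA re-exports CUDA.jl (the `CUDA.functional()` and `cu` calls come from CUDA.jl and are not part of this diff):

```julia
# Sketch only: assumes LuxCUDA re-exports CUDA.jl, so `CUDA` and `cu` are in scope.
using LuxCUDA

# Pick an array type depending on whether a CUDA device is actually usable.
x = CUDA.functional() ? cu(rand(Float32, 8)) : rand(Float32, 8)
println(typeof(x))
```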
2 changes: 1 addition & 1 deletion ext/LuxLibReverseDiffExt.jl
@@ -27,7 +27,7 @@ end
@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedArray)
@grad_from_chainrules LuxLib._copy_autodiff_barrier(x::TrackedReal)

LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(value(x))
LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(value(x))

# api/dropout.jl
LuxLib._dropout_fptype(x::TrackedArray) = LuxLib._dropout_fptype(value(x))
18 changes: 9 additions & 9 deletions ext/LuxLibTrackerExt.jl
@@ -8,7 +8,7 @@ else
import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector,
TrackedReal
end
using CUDA, NNlibCUDA
using LuxCUDA
using NNlib, LuxLib
using LuxLib: _CUDNN_BATCHNORM_FLOAT, _GROUPNORM_IMPL_FLOAT
import ChainRulesCore as CRC
@@ -61,7 +61,7 @@ function LuxLib._copy_autodiff_barrier(x::Union{TrackedArray, TrackedReal})
return LuxLib._copy_autodiff_barrier(data(x))
end

LuxLib._get_device(x::TrackedArray) = LuxLib._get_device(data(x))
LuxLib._get_backend(x::TrackedArray) = LuxLib._get_backend(data(x))

# api/batchnorm.jl
_TR_BN = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:_CUDNN_BATCHNORM_FLOAT, 2}},
@@ -81,16 +81,16 @@ function LuxLib.batchnorm(x::_TR_BN, scale::Union{_TR_BN_VEC, Nothing},
return x_, (; running_mean=rm, running_var=rv)
end

for RM in (:TrackedVector, :AbstractVector),
RV in (:TrackedVector, :AbstractVector),
for RM in (:TrackedVector, :Nothing, :AbstractVector),
RV in (:TrackedVector, :Nothing, :AbstractVector),
S in (:TrackedVector, :Nothing, :AbstractVector),
B in (:TrackedVector, :Nothing, :AbstractVector),
XT in (:TrackedArray, :AbstractArray)

RM == :AbstractVector &&
RV == :AbstractVector &&
(S == :AbstractVector || S == Nothing) &&
(B == :AbstractVector || B == Nothing) &&
(RM == :AbstractVector || RM == :Nothing) &&
(RV == :AbstractVector || RV == :Nothing) &&
(S == :AbstractVector || S == :Nothing) &&
(B == :AbstractVector || B == :Nothing) &&
XT == :AbstractArray &&
continue

@@ -133,7 +133,7 @@ end
@grad function LuxLib.groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T},
bias::AbstractVector{T}; groups::Int,
epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT}
LuxLib._assert_same_device(data(x), data(scale), data(bias))
LuxLib._assert_same_backend(data(x), data(scale), data(bias))
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of
channels (N - 1 dim of the input array)."))
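The widened loop above now also generates `batchnorm` methods for `Nothing` running statistics, while still skipping the one combination the base package already handles. A toy, self-contained sketch of that dispatch-generation pattern (not LuxLib's actual extension code; `Tracked`, `value`, and `mycombine` are made up for illustration):

```julia
# Toy wrapper standing in for Tracker's TrackedArray.
struct Tracked{T}
    value::T
end
value(t::Tracked) = t.value
value(x) = x

# "Base" method, analogous to the one defined in LuxLib itself.
mycombine(a::AbstractArray, b::AbstractArray) = a .+ b

for A in (:Tracked, :AbstractArray), B in (:Tracked, :AbstractArray)
    # Skip the combination the base method already covers; defining it again
    # would overwrite the base method and recurse forever.
    A == :AbstractArray && B == :AbstractArray && continue
    @eval mycombine(a::$A, b::$B) = mycombine(value(a), value(b))
end

mycombine(Tracked([1, 2]), [3, 4])  # unwraps, then hits the base method => [4, 6]
```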
2 changes: 1 addition & 1 deletion src/LuxLib.jl
@@ -6,7 +6,7 @@ import ChainRulesCore as CRC
using KernelAbstractions
import KernelAbstractions as KA

using CUDA, CUDAKernels, NNlibCUDA # CUDA Support
using LuxCUDA # CUDA Support

# Extensions
if !isdefined(Base, :get_extension)
5 changes: 4 additions & 1 deletion src/api/dropout.jl
@@ -107,7 +107,10 @@ end
function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B)
rng = _replicate(rng)
noise = rand!(rng, similar(x, _dropout_fptype(x)))
return (A .* ifelse.(noise .> p, x, α) .+ B), rng
# NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker
# on GPU
y = ifelse.(noise .> p, x, α)
return (A .* y .+ B), rng
end

alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng)
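The `alpha_dropout` change only splits one fused broadcast into two; the math is unchanged. A standalone sketch of the same masking-plus-affine pattern on plain arrays (the `p`, `α`, `A`, `B` values below are illustrative placeholders, not the SELU-preserving constants LuxLib derives):

```julia
using Random

# Illustrative constants; LuxLib computes α, A, B from the dropout probability p.
p, α, A, B = 0.2f0, -1.7581f0, 1.1f0, 0.35f0

rng = Random.default_rng()
x = randn(rng, Float32, 4, 3)
noise = rand(rng, Float32, size(x))

# Keep entries where the noise exceeds p, replace the rest with α ...
y = ifelse.(noise .> p, x, α)
# ... then apply the affine correction as a second, unfused broadcast.
out = A .* y .+ B
```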
6 changes: 3 additions & 3 deletions src/api/groupnorm.jl
@@ -62,7 +62,7 @@ interface.
function groupnorm(x::AbstractArray{T, 4}, scale::AbstractVector{T},
bias::AbstractVector{T}; groups::Int,
epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT}
_assert_same_device(x, scale, bias)
_assert_same_backend(x, scale, bias)
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " *
"channels (N - 1 dim of the input array)."))
@@ -97,7 +97,7 @@ function groupnorm(x::AbstractArray{<:Real, N},
running_mean::Union{Nothing, AbstractVector{<:Real}},
running_var::Union{Nothing, AbstractVector{<:Real}}; groups::Int,
momentum::Real, training::Val, epsilon::Real) where {N}
_assert_same_device(x, scale, bias, running_mean, running_var)
_assert_same_backend(x, scale, bias, running_mean, running_var)
if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " *
"channels (N - 1 dim of the input array)."))
@@ -124,7 +124,7 @@ end
function CRC.rrule(::typeof(groupnorm), x::AbstractArray{T, 4}, scale::AbstractVector{T},
bias::AbstractVector{T}; groups::Int,
epsilon::Real) where {T <: _GROUPNORM_IMPL_FLOAT}
_assert_same_device(x, scale, bias)
_assert_same_backend(x, scale, bias)
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of " *
"channels (N - 1 dim of the input array)."))
6 changes: 2 additions & 4 deletions src/api/layernorm.jl
@@ -31,10 +31,8 @@ Normalized Array of same size as `x`.
"""
function layernorm(x::AbstractArray{<:Real, N}, scale::AbstractArray{<:Real, N},
bias::AbstractArray{<:Real, N}; dims, epsilon) where {N}
_mean = mean(x; dims)
_rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)

return scale .* (x .- _mean) .* _rstd .+ bias
x_norm = layernorm(x, nothing, nothing; dims, epsilon)
return scale .* x_norm .+ bias
end

function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
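The `layernorm` refactor expresses the affine variant as "standardize first, then scale and shift", reusing the `(x, nothing, nothing)` method. A sketch of the formula it implements, using plain Statistics calls (this mirrors the removed inline code, it is not the LuxLib source):

```julia
using Statistics

# Standardize over the given dims, matching the removed inline implementation
# (epsilon is added to the standard deviation before dividing).
function layernorm_sketch(x; dims, epsilon)
    mu = mean(x; dims)
    sigma = std(x; dims, mean=mu, corrected=false)
    return (x .- mu) ./ (sigma .+ epsilon)
end

x = randn(Float32, 4, 8)
scale = ones(Float32, 4, 1)
bias = zeros(Float32, 4, 1)

# After the refactor, the affine method is just the normalized array with a
# broadcasted scale and bias on top.
x_norm = layernorm_sketch(x; dims=1, epsilon=1.0f-5)
y = scale .* x_norm .+ bias
```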
6 changes: 3 additions & 3 deletions src/deprecated.jl
@@ -1,8 +1,8 @@
function _normalization(x, running_mean, running_var, scale, bias, reduce_dims, training,
momentum, epsilon)
Base.depwarn("`LuxLib._normalization` with `reduce_dims` of type " *
"$(typeof(reduce_dims)) has been deprecated and will be removed in v0.2" *
". Pass `reduce_dims` as `Val(Tuple(reduce_dims))`", :_normalization)
Base.depwarn("""`LuxLib._normalization` with `reduce_dims` of type
$(typeof(reduce_dims)) has been deprecated and will be removed in v0.2.
Pass `reduce_dims` as `Val(Tuple(reduce_dims))`""", :_normalization)
return _normalization(x, running_mean, running_var, scale, bias,
Val(Tuple(reduce_dims)), training, momentum, epsilon)
end
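The reworded deprecation still asks for the same migration: wrap `reduce_dims` in a `Val` so the reduction dimensions become part of the type. The call-site change is just:

```julia
# Before (deprecated): reduce_dims passed as a plain collection
reduce_dims = (1, 2)

# After: encode the dimensions in the type domain, as the warning suggests
reduce_dims_val = Val(Tuple(reduce_dims))  # Val{(1, 2)}()
```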
40 changes: 21 additions & 19 deletions src/impl/groupnorm.jl
@@ -1,6 +1,5 @@
# Launch Heuristics
_linear_threads_groupnorm(::CPU) = Threads.nthreads()
_linear_threads_groupnorm(::CUDADevice) = (16, 16)
_linear_threads_groupnorm(::GPU) = 256

_GROUPNORM_IMPL_FLOAT = Union{Float32, Float64}
@@ -66,15 +65,17 @@ end
_scale = similar(X, (C, N))
_bias = similar(X, (C, N))

device = get_device(X)
backend = KA.get_backend(X)

n = _linear_threads_groupnorm(device)
compute_fixed_params! = _compute_fused_params_kernel!(device, n, size(_scale))
groupnorm_forward! = _groupnorm_forward_kernel!(device, n, size(X))
n = _linear_threads_groupnorm(backend)
compute_fixed_params! = _compute_fused_params_kernel!(backend, n, size(_scale))
groupnorm_forward! = _groupnorm_forward_kernel!(backend, n, size(X))

wait(compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta;
ndrange=size(_scale)))
wait(groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y)))
compute_fixed_params!(_scale, _bias, C, K, mu, rsig, gamma, beta; ndrange=size(_scale))
KA.synchronize(backend)

groupnorm_forward!(Y, W * H, X, _scale, _bias; ndrange=size(Y))
KA.synchronize(backend)

return Y, mu, rsig
end
@@ -86,35 +87,36 @@ end
W, H, C, N = size(X)
K = div(C, G)
WxH = W * H
device = get_device(X)
n = _linear_threads_groupnorm(device)
backend = KA.get_backend(X)
n = _linear_threads_groupnorm(backend)

dbias = reshape(sum(dY; dims=(1, 2)), (1, 1, K, G, N))
dscale = reshape(sum(X .* dY; dims=(1, 2)), (1, 1, K, G, N))

dY_dscale = similar(X, (C, N))
groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(device, n, size(dY_dscale))
ev = groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale))
groupnorm_dy_dscale! = _groupnorm_dy_dscale_kernel!(backend, n, size(dY_dscale))
groupnorm_dy_dscale!(dY_dscale, C, K, rsig, gamma; ndrange=size(dY_dscale))

gamma_ = reshape(gamma, (1, 1, K, G, 1))
db_sum = sum(gamma_ .* dbias; dims=3)
ds_sum = sum(gamma_ .* dscale; dims=3)
wait(ev)
KA.synchronize(backend)

X_scale = similar(X, (G, N))
bias = similar(X, (G, N))

groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(device, n,
groupnorm_xscale_and_bias! = _groupnorm_xscale_and_bias_kernel!(backend, n,
size(X_scale))
wait(groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum,
db_sum; ndrange=size(X_scale)))
groupnorm_xscale_and_bias!(X_scale, bias, T(1 / (K * WxH)), mu, rsig, ds_sum, db_sum;
ndrange=size(X_scale))
KA.synchronize(backend)

dX = similar(X)
groupnorm_dx! = _groupnorm_dx_kernel!(device, n, size(dX))
ev = groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX))
groupnorm_dx! = _groupnorm_dx_kernel!(backend, n, size(dX))
groupnorm_dx!(dX, WxH, K, dY_dscale, dY, X_scale, X, bias; ndrange=size(dX))
dgamma = vec(sum((-dbias .* mu .+ dscale) .* rsig; dims=5))
dbeta = vec(sum(dbias; dims=5))
wait(ev)
KA.synchronize(backend)

return dX, dgamma, dbeta
end
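The changes in this file are the KernelAbstractions 0.9 migration: `get_device`/`CUDADevice` become `KA.get_backend`/`GPU` backends, and the event-based `wait(kernel(...))` pattern becomes launch-then-`KA.synchronize(backend)`. A self-contained sketch of the new launch pattern with a toy kernel (not one of the groupnorm kernels above):

```julia
using KernelAbstractions
import KernelAbstractions as KA

# Toy elementwise kernel written against the KernelAbstractions 0.9 API.
@kernel function scale_kernel!(y, x, a)
    i = @index(Global)
    @inbounds y[i] = a * x[i]
end

x = rand(Float32, 1024)
y = similar(x)

backend = KA.get_backend(x)            # CPU() here; CUDABackend() for CuArrays
kernel! = scale_kernel!(backend, 256)  # workgroup size replaces the old (device, threads) pair
kernel!(y, x, 2.0f0; ndrange=length(x))
KA.synchronize(backend)                # replaces wait(event) from KA 0.8

@assert y ≈ 2 .* x
```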
5 changes: 4 additions & 1 deletion src/impl/normalization.jl
@@ -47,7 +47,10 @@ end
@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST, scale::A,
bias::A, epsilon::Real) where {ST, A}
if A != Nothing
return :(return scale .* (x .- xmean) ./ sqrt.(xvar .+ epsilon) .+ bias)
return quote
x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon)
return scale .* x_norm .+ bias
end
else
return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon))
end
34 changes: 10 additions & 24 deletions src/utils.jl
@@ -1,37 +1,23 @@
_div_idx(idx, n) = div(idx - 1, n) + 1
_mod_idx(idx, n) = mod(idx - 1, n) + 1

@static if VERSION >= v"1.7"
get_device(x) = KA.get_device(x)
else
# KA.get_device is not present in <= v0.7 but that is what works on julia 1.6
get_device(x::CuArray) = CUDADevice()
get_device(x::Array) = CPU()
get_device(x::SubArray) = CPU()
function get_device(x)
throw(ArgumentError("get_device not implemented for $(typeof(x)). This is an" *
"undesirable codepath. Please use julia 1.7+ for more " *
"meaningful error messages using KA.jl."))
end
end

_get_device(::Nothing) = nothing
_get_device(d) = hasmethod(get_device, (typeof(d),)) ? get_device(d) : nothing
_get_device(t::Tuple) = filter(!isnothing, _get_device.(t))
_get_backend(::Nothing) = nothing
_get_backend(d) = hasmethod(KA.get_backend, (typeof(d),)) ? KA.get_backend(d) : nothing
_get_backend(t::Tuple) = filter(!isnothing, _get_backend.(t))

CRC.@non_differentiable _get_device(::Any)
CRC.@non_differentiable _get_backend(::Any)

function _assert_same_device(args...)
devs = _get_device(args)
function _assert_same_backend(args...)
devs = _get_backend(args)
if !all(devs .== (first(devs),))
throw(ArgumentError("All arguments must be on the same device. This error is
encountered if you are calling a function with a mix of CPU
and GPU arrays."))
throw(ArgumentError("""All arguments must be on the same backend. This error is
encountered if you are calling a function with a mix of CPU
and GPU arrays."""))
end
return
end

CRC.@non_differentiable _assert_same_device(::Any...)
CRC.@non_differentiable _assert_same_backend(::Any...)

@inline @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x
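With the Julia 1.6 `get_device` shim gone, the backend checks lean directly on `KA.get_backend`. A simplified sketch of what `_assert_same_backend` checks (the real helper also tolerates `nothing` and non-array arguments via `hasmethod`, which this sketch omits):

```julia
using KernelAbstractions
import KernelAbstractions as KA

# Every array argument must resolve to the same KernelAbstractions backend,
# e.g. all CPU() or all CUDABackend().
function assert_same_backend_sketch(args...)
    backends = map(KA.get_backend, args)
    if !all(==(first(backends)), backends)
        throw(ArgumentError("All arguments must be on the same backend; got a " *
                            "mix of CPU and GPU arrays."))
    end
    return nothing
end

assert_same_backend_sketch(rand(3), rand(3, 3))  # passes: both are CPU arrays
```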

2 changes: 1 addition & 1 deletion test/Project.toml
@@ -1,8 +1,8 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

2 comments on commit c1bb495

@avik-pal
Member Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/80423

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.13 -m "<description of version>" c1bb495c1a7051e9220db21a1663e0a007706a80
git push origin v0.1.13
