This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit f6fe2f4
feat: fast_broadcast implementation made compatible with chainrules
avik-pal committed Jul 13, 2024
1 parent 9f18c16 commit f6fe2f4
Showing 10 changed files with 230 additions and 151 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
@@ -1,6 +1,5 @@
style = "sciml"
whitespace_in_kwargs = false
-always_use_return = true
margin = 92
indent = 4
format_docstrings = true
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -18,9 +19,9 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
-VectorizedReduction = "4ffe575c-65e5-43f4-bc05-e0b500dc3d2c"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
@@ -48,6 +49,7 @@ ComponentArrays = "0.15.8"
DispatchDoctor = "0.4.9"
EnzymeCore = "0.7"
ExplicitImports = "1.9.0"
+FastBroadcast = "0.3.4"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
@@ -63,12 +65,12 @@ Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
+SIMDTypes = "0.1.0"
StableRNGs = "1"
Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
-VectorizedReduction = "0.1.12"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
12 changes: 7 additions & 5 deletions src/LuxLib.jl
@@ -1,13 +1,14 @@
module LuxLib

-using ArrayInterface: ArrayInterface
-using ChainRulesCore: ChainRulesCore, NoTangent
+using ArrayInterface: ArrayInterface, fast_scalar_indexing
+using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig
using DispatchDoctor: @stable
using EnzymeCore: EnzymeCore, EnzymeRules
+using FastBroadcast: @..
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LinearAlgebra: LinearAlgebra, BLAS, mul!
-using LoopVectorization: @turbo
+using LoopVectorization: LoopVectorization, @turbo, vmap
using LuxCore: LuxCore
using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice,
AbstractLuxDevice
@@ -16,15 +17,14 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇conv_data,
∇conv_filter
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
+using SIMDTypes
using Statistics: Statistics, mean, var
using UnrolledUtilities: unrolled_any

@reexport using NNlib

const CRC = ChainRulesCore

const Optional{T} = Union{Nothing, T}

include("utils.jl")

# Low-Level Implementations
@@ -34,6 +34,7 @@ include("impl/fused_conv.jl")
include("impl/fast_activation.jl")
include("impl/forward_diff.jl")
include("impl/broadcast.jl")
include("impl/bias_act.jl")

# User Facing
include("api/batchnorm.jl")
@@ -51,5 +52,6 @@ include("deprecations.jl")
export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
export fused_dense_bias_activation, fused_conv_bias_activation
export fast_activation!!
+export fast_broadcast, fast_broadcast!, fast_broadcast!!

end
42 changes: 41 additions & 1 deletion src/api/broadcast.jl
@@ -1,10 +1,50 @@
"""
fast_broadcast(f::F, x::AbstractArray, args...) where {F <: Function}
Computes `@. f(x, args...)`. If it is possible, we use `LoopVectorization.@turbo` to speed
up the computation.
"""
function fast_broadcast end

function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast),
f::F, x::AbstractArray, args...) where {F <: Function}
return CRC.rrule_via_ad(cfg, broadcast, f, x, args...)
end

"""
fast_broadcast!(f::F, x::AbstractArray, args...) where {F <: Function}
Computes `@. x = f(x, args...)` updating `x` in place. This assumes that `x` has the
matching shape as the broadcasted result.
If it is possible, we use `LoopVectorization.@turbo` to speed up the computation.
This function isn't compatible with ChainRules, use [`fast_broadcast!!`](@ref) instead.
"""
function fast_broadcast! end

for (op, impl) in (
(:fast_broadcast, :__fast_broadcast_impl), (:fast_broadcast!, :__fast_broadcast_impl!))
@eval function $(op)(f::F, x::AbstractArray, args...) where {F <: Function}
-        return $(impl)(get_device_type(x, args...), f, x, args...)
+        return $(impl)(get_device_type((x, args...)), f, x, args...)
end
end

"""
fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
Compues `@. f(x, args...)`. If `x` can be set-indexed, it uses `fast_broadcast!` to
update `x` in place. Otherwise, it falls back to `fast_broadcast`.
Since `x` is not guaranteed to be modified inplace, call the function as
`y = fast_broadcast!!(...)`.
"""
@generated function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
    ArrayInterface.can_setindex(x) && return :(fast_broadcast!(f, x, args...))
return :(fast_broadcast(f, x, args...))
end

function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!),
f::F, x::AbstractArray, args...) where {F <: Function}
return CRC.rrule_via_ad(cfg, fast_broadcast, f, x, args...)
end
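For reference, a minimal usage sketch of the three entry points this commit exports (hypothetical values; `gelu` comes from NNlib, which LuxLib reexports — a sketch, not part of the diff):

```julia
using LuxLib  # reexports NNlib, so `gelu` is in scope

x = randn(Float32, 128, 32)
b = randn(Float32, 128)

# Out-of-place: computes `@. gelu(x + b)`, routed through `@turbo`/`@..` on CPU.
y1 = fast_broadcast(gelu ∘ +, x, b)

# Fully in-place: overwrites its first argument; not ChainRules-compatible.
y2 = fast_broadcast!(gelu ∘ +, copy(x), b)

# Maybe-in-place: mutates only when `x` supports `setindex!`; always rebind the result.
y3 = fast_broadcast!!(gelu ∘ +, copy(x), b)
```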
25 changes: 25 additions & 0 deletions src/impl/bias_act.jl
@@ -0,0 +1,25 @@
# Helper to add bias and apply activation function
## This is only meant to be used inside rrules
function __apply_bias_activation!!(
σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache}
if σ === identity
bias === nothing && return x
return fast_broadcast!(+, x, bias)
end
if !cache
bias === nothing && return fast_broadcast!(σ, x)
        return fast_broadcast!(σ ∘ +, x, bias)
end
bias === nothing && return fast_broadcast(σ, x), x
x = fast_broadcast!(+, x, bias)
return fast_broadcast(σ, x), x
end

function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F}
    return fast_broadcast(σ ∘ +, x, bias)
end
function __apply_bias_activation(::typeof(identity), x, bias::AbstractArray)
return fast_broadcast(+, x, bias)
end
__apply_bias_activation(σ::F, x, ::Nothing) where {F} = fast_broadcast(σ, x)
__apply_bias_activation(::typeof(identity), x, ::Nothing) = x
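A sketch of the contract this internal helper provides (hypothetical values, not a public API; the `Val(true)` path is what the rrules in this commit rely on to cache the pre-activation):

```julia
x = randn(Float32, 4, 4)
b = randn(Float32, 4)

# cache = false: bias-add and activation may be fused and done fully in place.
y = LuxLib.__apply_bias_activation!!(tanh, copy(x), b, Val(false))

# cache = true: returns (tanh.(x .+ b), x .+ b); only the bias-add mutates `x`,
# and the second value is kept around for the pullback.
y, pre_act = LuxLib.__apply_bias_activation!!(tanh, copy(x), b, Val(true))
```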
66 changes: 41 additions & 25 deletions src/impl/broadcast.jl
@@ -1,37 +1,53 @@
-function __fast_broadcast_impl(
-        ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function}
-    if maximum(length, (x, args...)) > 100_000
-        return @turbo thread=true @. f(x, args...)
-    end
-    return @turbo @. f(x, args...)
-end
-
-function __fast_broadcast_impl(
-        ::Type, f::F, x::AbstractArray, args...) where {F <: Function}
-    if __fails_inplace_bcast_gpu(f) && length(args) == 1
-        y = first(args)
-        return @. f.outer(f.inner(x, y))
-    end
-    return @. f(x, args...)
-end
-
-function __fast_broadcast_impl!(
-        ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function}
-    if maximum(length, (x, args...)) > 100_000
-        @turbo thread=true @. x = f(x, args...)
-    else
-        @turbo @. x = f(x, args...)
-    end
-    return x
-end
-
-function __fast_broadcast_impl!(
-        ::Type, f::F, x::AbstractArray, args...) where {F <: Function}
-    if __fails_inplace_bcast_gpu(f) && length(args) == 1
-        y = first(args)
-        @. x = f.outer(f.inner(x, y))
-    else
-        @. x = f(x, args...)
-    end
-    return x
-end
+# NOTE: These functions aren't compatible for use inside ChainRules. Use `fast_broadcast`
+# instead.
+
+# TODO: Enzyme would need a split reverse + forward pass to handle LV
+# xref https://github.com/EnzymeAD/Enzyme.jl/issues/1635
+## CPU -- LV --> FastBroadcast --> Generic Broadcast
+function __fast_broadcast_impl(
+        ::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T}
+    return __fast_broadcast_impl!(T, similar(x), f, x, args...)
+end
+
+function __fast_broadcast_impl!(
+        ::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T}
+    return __fast_broadcast_impl!(T, x, f, x, args...) # aliased
+end
+
+## CPU -- LV --> FastBroadcast --> Generic Broadcast
+function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray{<:LV_ELTYPES},
+        f::F, x::AbstractArray{<:LV_ELTYPES},
+        args::Union{AbstractArray{<:LV_ELTYPES}, <:LV_ELTYPES}...) where {F <: Function}
+    fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...)
+    if maximum(length, (x, args...)) > THREADING_THRESHOLD
+        @turbo thread=true @. y = f(x, args...)
+    else
+        @turbo @. y = f(x, args...)
+    end
+    return y
+end
+
+function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray, f::F,
+        x::AbstractArray, args...) where {F <: Function}
+    fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...)
+    if maximum(length, (x, args...)) > THREADING_THRESHOLD
+        @.. thread=true y=f(x, args...)
+    else
+        @.. y = f(x, args...)
+    end
+    return y
+end
+
+for ffail in (sigmoid_fast ∘ +, swish ∘ +)
+    @eval function __fast_broadcast_impl!(
+            ::Type{T}, y::AbstractArray, f::typeof($ffail), x::AbstractArray, z) where {T}
+        @. y = f.outer(f.inner(x, z))
+        return y
+    end
+end
+
+function __fast_broadcast_impl!(::Type{T}, y::AbstractArray, f::F,
+        x::AbstractArray, args...) where {F <: Function, T}
+    @. y = f(x, args...)
+    return y
+end
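The `ffail` specializations lean on `Base.ComposedFunction` exposing its two halves, so `σ ∘ +` can be split into a plain fused broadcast instead of being handed to `@turbo` (which these compositions presumably fail under, hence the name). A quick sketch of the mechanics, with hypothetical values:

```julia
f = tanh ∘ +            # Base.ComposedFunction{typeof(tanh), typeof(+)}
f.outer === tanh        # true
f.inner === +           # true

x, z = randn(Float32, 8), randn(Float32, 8)
y = similar(x)

# What the `@eval`ed methods above expand to: σ.(x .+ z) in a single fused pass.
@. y = f.outer(f.inner(x, z))
```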
4 changes: 2 additions & 2 deletions src/impl/fast_activation.jl
@@ -6,8 +6,8 @@
return fast_broadcast!(σ, x)
end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T}
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fast_activation_impl!!),
+        σ::F, x::AbstractArray{T}) where {F, T}
    σ === identity && return x, @closure(Δ -> (NoTangent(), NoTangent(), Δ))

if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
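The `isconcretetype(Core.Compiler._return_type(...))` branch above decides the rrule strategy by asking the compiler whether the derivative computation is statically inferable. A standalone sketch of that probe, using hypothetical functions (`_return_type` is internal compiler API, as in the diff itself):

```julia
stable(x) = x > 0 ? x : zero(x)    # inferable: always returns typeof(x)
unstable(x) = x > 0 ? x : 0        # may return Int for non-Int input

isconcretetype(Core.Compiler._return_type(stable, Tuple{Float32}))    # true  -> cached fast path
isconcretetype(Core.Compiler._return_type(unstable, Tuple{Float32}))  # false -> generic fallback
```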
4 changes: 2 additions & 2 deletions src/impl/fused_conv.jl
@@ -121,8 +121,8 @@ end
return __conv_bias_act(x, weight, cdims, bias, act)
end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fused_conv_bias_activation_impl),
+function CRC.rrule(
+        cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl),
act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F}
T = __get_concrete_fba_output_eltype(act, weight, x, bias)
7 changes: 4 additions & 3 deletions src/impl/fused_dense.jl
@@ -31,9 +31,10 @@ end
return __apply_bias_activation!!(act, y, b, Val(false))
end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix,
-        x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F}
+function CRC.rrule(
+        cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl),
+        act::F, weight::AbstractMatrix, x::AbstractMatrix,
+        b::Optional{<:AbstractVector}) where {F}
T = __get_concrete_fba_output_eltype(act, weight, x, b)

# Case I: Activation Function doesn't require caching the intermediate value