This repository has been archived by the owner on Nov 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: fast_broadcast implementation made compatible with chainrules
- Loading branch information
Showing
10 changed files
with
230 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,50 @@ | ||
""" | ||
fast_broadcast(f::F, x::AbstractArray, args...) where {F <: Function} | ||
Computes `@. f(x, args...)`. If it is possible, we use `LoopVectorization.@turbo` to speed | ||
up the computation. | ||
""" | ||
function fast_broadcast end | ||
|
||
# Reverse-mode rule for `fast_broadcast`: rather than differentiating the accelerated
# kernels, hand the equivalent plain `broadcast(f, x, args...)` call back to the AD system
# identified by `cfg` and return whatever `rrule_via_ad` produces for it.
function CRC.rrule(
        cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast), f::F,
        x::AbstractArray, args...) where {F <: Function}
    return CRC.rrule_via_ad(cfg, broadcast, f, x, args...)
end
|
||
""" | ||
fast_broadcast!(f::F, x::AbstractArray, args...) where {F <: Function} | ||
Computes `@. x = f(x, args...)` updating `x` in place. This assumes that `x` has the | ||
same shape as the broadcasted result. | ||
If it is possible, we use `LoopVectorization.@turbo` to speed up the computation. | ||
This function isn't compatible with ChainRules, use [`fast_broadcast!!`](@ref) instead. | ||
""" | ||
function fast_broadcast! end | ||
|
||
for (op, impl) in ( | ||
(:fast_broadcast, :__fast_broadcast_impl), (:fast_broadcast!, :__fast_broadcast_impl!)) | ||
@eval function $(op)(f::F, x::AbstractArray, args...) where {F <: Function} | ||
return $(impl)(get_device_type(x, args...), f, x, args...) | ||
return $(impl)(get_device_type((x, args...)), f, x, args...) | ||
end | ||
end | ||
|
||
""" | ||
fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function} | ||
Computes `@. f(x, args...)`. If `x` can be set-indexed, it uses `fast_broadcast!` to | ||
update `x` in place. Otherwise, it falls back to `fast_broadcast`. | ||
Since `x` is not guaranteed to be modified inplace, call the function as | ||
`y = fast_broadcast!!(...)`. | ||
""" | ||
@generated function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
    # Inside the generator `x` is bound to the *type* of the array argument, so the
    # mutability query below is answered once per call signature, not per call.
    #
    # The selected expression has to be `return`ed from the generator itself. Merely
    # evaluating a quoted expression on the right-hand side of `&&` builds the `Expr` and
    # throws it away, which would make every call fall through to the out-of-place
    # branch.
    ArrayInterface.can_setindex(x) && return :(fast_broadcast!(f, x, args...))
    # `x` does not support `setindex!`: compute the result out of place instead.
    return :(fast_broadcast(f, x, args...))
end
|
||
# Reverse-mode rule for `fast_broadcast!!`: the primal call is re-expressed as the
# non-mutating `fast_broadcast(f, x, args...)` and differentiated through the AD system
# identified by `cfg`, so the rule never goes through the in-place `fast_broadcast!`.
function CRC.rrule(
        cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), f::F,
        x::AbstractArray, args...) where {F <: Function}
    return CRC.rrule_via_ad(cfg, fast_broadcast, f, x, args...)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Helper that adds the (optional) bias and applies the activation, reusing `x` as storage
# wherever it can.
## This is only meant to be used inside rrules
#
# What is returned depends on `σ`, `bias` and the `cache` flag:
#   * `σ === identity`: `x` when there is no bias, otherwise the result of
#     `fast_broadcast!(+, x, bias)` -- a single array, regardless of `cache`;
#   * `cache == false`: the result of `fast_broadcast!` applying `σ` (or `σ ∘ +` when a
#     bias is given) to `x`;
#   * `cache == true`: the tuple `(fast_broadcast(σ, z), z)`, where `z` is `x` after the
#     optional in-place bias addition.
function __apply_bias_activation!!(
        σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache}
    if σ === identity
        return bias === nothing ? x : fast_broadcast!(+, x, bias)
    elseif !cache
        return bias === nothing ? fast_broadcast!(σ, x) : fast_broadcast!(σ ∘ +, x, bias)
    end
    # Caching requested: keep the biased pre-activation next to the activated output.
    z = bias === nothing ? x : fast_broadcast!(+, x, bias)
    return fast_broadcast(σ, z), z
end
|
||
# Out-of-place "add bias, then activate", with one method per combination of
# {generic activation, `identity`} x {array bias, no bias}.

# Generic activation with a bias: one broadcast of the composed function `σ ∘ +`.
__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} =
    fast_broadcast(σ ∘ +, x, bias)

# `identity` with a bias: only the addition is left to do.
__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) =
    fast_broadcast(+, x, bias)

# Generic activation without a bias: broadcast `σ` over `x` alone.
function __apply_bias_activation(σ::F, x, ::Nothing) where {F}
    return fast_broadcast(σ, x)
end

# `identity` without a bias: nothing to compute, `x` is handed back as is (not copied).
function __apply_bias_activation(::typeof(identity), x, ::Nothing)
    return x
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,53 @@ | ||
# CPU implementation of the out-of-place broadcast `@. f(x, args...)` using `@turbo`.
# The threaded variant (`thread=true`) is chosen only when the longest of `x` and `args`
# has more than 100_000 elements.
function __fast_broadcast_impl(
        ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function}
    longest = maximum(length, (x, args...))
    longest ≤ 100_000 && return @turbo @. f(x, args...)
    return @turbo thread=true @. f(x, args...)
end
# NOTE: These functions aren't compatible for use inside ChainRules. Use `fast_broadcast` | ||
# instead. | ||
|
||
# TODO: Enzyme would need a split reverse + forward pass to handle LV | ||
# xref https://github.com/EnzymeAD/Enzyme.jl/issues/1635 | ||
## CPU -- LV --> FastBroadcast --> Generic Broadcast | ||
function __fast_broadcast_impl( | ||
::Type, f::F, x::AbstractArray, args...) where {F <: Function} | ||
if __fails_inplace_bcast_gpu(f) && length(args) == 1 | ||
y = first(args) | ||
return @. f.outer(f.inner(x, y)) | ||
end | ||
return @. f(x, args...) | ||
::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T} | ||
return __fast_broadcast_impl!(T, similar(x), f, x, args...) | ||
end | ||
|
||
function __fast_broadcast_impl!( | ||
::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function} | ||
if maximum(length, (x, args...)) > 100_000 | ||
@turbo thread=true @. x = f(x, args...) | ||
::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T} | ||
return __fast_broadcast_impl!(T, x, f, x, args...) # aliased | ||
end | ||
|
||
## CPU -- LV --> FastBroadcast --> Generic Broadcast | ||
function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray{<:LV_ELTYPES}, | ||
f::F, x::AbstractArray{<:LV_ELTYPES}, | ||
args::Union{AbstractArray{<:LV_ELTYPES}, <:LV_ELTYPES}...) where {F <: Function} | ||
fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...) | ||
if maximum(length, (x, args...)) > THREADING_THRESHOLD | ||
@turbo thread=true @. y = f(x, args...) | ||
else | ||
@turbo @. x = f(x, args...) | ||
@turbo @. y = f(x, args...) | ||
end | ||
return x | ||
return y | ||
end | ||
|
||
function __fast_broadcast_impl!( | ||
::Type, f::F, x::AbstractArray, args...) where {F <: Function} | ||
if __fails_inplace_bcast_gpu(f) && length(args) == 1 | ||
y = first(args) | ||
@. x = f.outer(f.inner(x, y)) | ||
function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray, f::F, | ||
x::AbstractArray, args...) where {F <: Function} | ||
fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...) | ||
if maximum(length, (x, args...)) > THREADING_THRESHOLD | ||
@.. thread=true y=f(x, args...) | ||
else | ||
@. x = f(x, args...) | ||
@.. y = f(x, args...) | ||
end | ||
return y | ||
end | ||
|
||
for ffail in (sigmoid_fast ∘ +, swish ∘ +) | ||
@eval function __fast_broadcast_impl!( | ||
::Type{T}, y::AbstractArray, f::typeof($ffail), x::AbstractArray, z) where {T} | ||
@. y = f.outer(f.inner(x, z)) | ||
return y | ||
end | ||
return x | ||
end | ||
|
||
# Generic fallback: write `f.(x, args...)` into the preallocated `y` with an ordinary
# fused broadcast and return `y`.
#
# The first argument is only a dispatch tag (a device type, or `Nothing` when a
# specialised method bails out), so it must not be tied to the element type of the
# arrays. Sharing one type parameter between `::Type{T}` and `AbstractArray{T}` would
# restrict this method to arrays whose eltype *is* the tag type, i.e. it would never be
# selected for real data.
function __fast_broadcast_impl!(
        ::Type, y::AbstractArray, f::F, x::AbstractArray, args...) where {F <: Function}
    @. y = f(x, args...)
    return y
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.