From f6fe2f470a2c138bb157aa3b5e1a391462029863 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 13 Jul 2024 15:32:11 -0700
Subject: [PATCH] feat: fast_broadcast implementation made compatible with chainrules

---
 .JuliaFormatter.toml        |   1 -
 Project.toml                |   6 +-
 src/LuxLib.jl               |  12 +-
 src/api/broadcast.jl        |  42 ++++++-
 src/impl/bias_act.jl        |  25 +++++
 src/impl/broadcast.jl       |  66 ++++++-----
 src/impl/fast_activation.jl |   4 +-
 src/impl/fused_conv.jl      |   4 +-
 src/impl/fused_dense.jl     |   7 +-
 src/utils.jl                | 214 ++++++++++++++++++------------------
 10 files changed, 230 insertions(+), 151 deletions(-)
 create mode 100644 src/impl/bias_act.jl

diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index f1f84c1c..22c3407c 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,6 +1,5 @@
 style = "sciml"
 whitespace_in_kwargs = false
-always_use_return = true
 margin = 92
 indent = 4
 format_docstrings = true
diff --git a/Project.toml b/Project.toml
index 2368081f..5451f121 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -18,9 +19,9 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
-VectorizedReduction = "4ffe575c-65e5-43f4-bc05-e0b500dc3d2c"

 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
@@ -48,6 +49,7 @@ ComponentArrays = "0.15.8"
 DispatchDoctor = "0.4.9"
 EnzymeCore = "0.7"
 ExplicitImports = "1.9.0"
+FastBroadcast = "0.3.4"
 FastClosures = "0.3.2"
 ForwardDiff = "0.10.36"
 LinearAlgebra = "1.10"
@@ -63,12 +65,12 @@ Random = "1.10"
 ReTestItems = "1.23.1"
 Reexport = "1"
 ReverseDiff = "1.15"
+SIMDTypes = "0.1.0"
 StableRNGs = "1"
 Statistics = "1.10"
 Test = "1.10"
 Tracker = "0.2.34"
 UnrolledUtilities = "0.1.2"
-VectorizedReduction = "0.1.12"
 Zygote = "0.6.69"
 cuDNN = "1.3"
 julia = "1.10"
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index 066e99fa..9389f6c5 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -1,13 +1,14 @@
 module LuxLib

-using ArrayInterface: ArrayInterface
-using ChainRulesCore: ChainRulesCore, NoTangent
+using ArrayInterface: ArrayInterface, fast_scalar_indexing
+using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig
 using DispatchDoctor: @stable
 using EnzymeCore: EnzymeCore, EnzymeRules
+using FastBroadcast: @..
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff
 using LinearAlgebra: LinearAlgebra, BLAS, mul!
-using LoopVectorization: @turbo
+using LoopVectorization: LoopVectorization, @turbo, vmap
 using LuxCore: LuxCore
 using LuxDeviceUtils: get_device_type, LuxCUDADevice, LuxCPUDevice, AbstractLuxGPUDevice,
                       AbstractLuxDevice
@@ -16,6 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, sigmoid_fast, swish, σ, ∇con
              ∇conv_filter
 using Random: Random, AbstractRNG, rand!
 using Reexport: @reexport
+using SIMDTypes
 using Statistics: Statistics, mean, var
 using UnrolledUtilities: unrolled_any

@@ -23,8 +25,6 @@ using UnrolledUtilities: unrolled_any

 const CRC = ChainRulesCore

-const Optional{T} = Union{Nothing, T}
-
 include("utils.jl")

 # Low-Level Implementations
@@ -34,6 +34,7 @@ include("impl/fused_conv.jl")
 include("impl/fast_activation.jl")
 include("impl/forward_diff.jl")
 include("impl/broadcast.jl")
+include("impl/bias_act.jl")

 # User Facing
 include("api/batchnorm.jl")
@@ -51,5 +52,6 @@ include("deprecations.jl")
 export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
 export fused_dense_bias_activation, fused_conv_bias_activation
 export fast_activation!!
+export fast_broadcast, fast_broadcast!, fast_broadcast!!

 end
diff --git a/src/api/broadcast.jl b/src/api/broadcast.jl
index 6a6e6b92..c44a1251 100644
--- a/src/api/broadcast.jl
+++ b/src/api/broadcast.jl
@@ -1,10 +1,50 @@
+"""
+    fast_broadcast(f::F, x::AbstractArray, args...) where {F <: Function}
+
+Computes `@. f(x, args...)`. If possible, we use `LoopVectorization.@turbo` to speed
+up the computation.
+"""
 function fast_broadcast end

+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast),
+        f::F, x::AbstractArray, args...) where {F <: Function}
+    return CRC.rrule_via_ad(cfg, broadcast, f, x, args...)
+end
+
+"""
+    fast_broadcast!(f::F, x::AbstractArray, args...) where {F <: Function}
+
+Computes `@. x = f(x, args...)`, updating `x` in place. This assumes that `x` has the
+same shape as the broadcasted result.
+
+If possible, we use `LoopVectorization.@turbo` to speed up the computation.
+
+This function isn't compatible with ChainRules; use [`fast_broadcast!!`](@ref) instead.
+"""
 function fast_broadcast! end

 for (op, impl) in (
     (:fast_broadcast, :__fast_broadcast_impl), (:fast_broadcast!, :__fast_broadcast_impl!))
     @eval function $(op)(f::F, x::AbstractArray, args...) where {F <: Function}
-        return $(impl)(get_device_type(x, args...), f, x, args...)
+        return $(impl)(get_device_type((x, args...)), f, x, args...)
     end
 end
+
+"""
+    fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
+
+Computes `@. f(x, args...)`. If `x` can be set-indexed, it uses `fast_broadcast!` to
+update `x` in place. Otherwise, it falls back to `fast_broadcast`.
+
+Since `x` is not guaranteed to be modified in place, call the function as
+`y = fast_broadcast!!(...)`.
+"""
+@generated function fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
+    ArrayInterface.can_setindex(x) && return :(fast_broadcast!(f, x, args...))
+    return :(fast_broadcast(f, x, args...))
+end
+
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!),
+        f::F, x::AbstractArray, args...) where {F <: Function}
+    return CRC.rrule_via_ad(cfg, fast_broadcast, f, x, args...)
+end
diff --git a/src/impl/bias_act.jl b/src/impl/bias_act.jl
new file mode 100644
index 00000000..e9e36a60
--- /dev/null
+++ b/src/impl/bias_act.jl
@@ -0,0 +1,25 @@
+# Helper to add bias and apply activation function
+## This is only meant to be used inside rrules
+function __apply_bias_activation!!(
+        σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache}
+    if σ === identity
+        bias === nothing && return x
+        return fast_broadcast!(+, x, bias)
+    end
+    if !cache
+        bias === nothing && return fast_broadcast!(σ, x)
+        return fast_broadcast!(σ ∘ +, x, bias)
+    end
+    bias === nothing && return fast_broadcast(σ, x), x
+    x = fast_broadcast!(+, x, bias)
+    return fast_broadcast(σ, x), x
+end
+
+function __apply_bias_activation(σ::F, x, bias::AbstractArray) where {F}
+    return fast_broadcast(σ ∘ +, x, bias)
+end
+function __apply_bias_activation(::typeof(identity), x, bias::AbstractArray)
+    return fast_broadcast(+, x, bias)
+end
+__apply_bias_activation(σ::F, x, ::Nothing) where {F} = fast_broadcast(σ, x)
+__apply_bias_activation(::typeof(identity), x, ::Nothing) = x
diff --git a/src/impl/broadcast.jl b/src/impl/broadcast.jl
index 7c216721..584c1dea 100644
--- a/src/impl/broadcast.jl
+++ b/src/impl/broadcast.jl
@@ -1,37 +1,53 @@
-function __fast_broadcast_impl(
-        ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function}
-    if maximum(length, (x, args...)) > 100_000
-        return @turbo thread=true @. f(x, args...)
-    end
-    return @turbo @. f(x, args...)
-end
+# NOTE: These functions aren't compatible for use inside ChainRules. Use `fast_broadcast`
+# instead.
+# TODO: Enzyme would need a split reverse + forward pass to handle LV
+# xref https://github.com/EnzymeAD/Enzyme.jl/issues/1635
+## CPU -- LV --> FastBroadcast --> Generic Broadcast
 function __fast_broadcast_impl(
-        ::Type, f::F, x::AbstractArray, args...) where {F <: Function}
-    if __fails_inplace_bcast_gpu(f) && length(args) == 1
-        y = first(args)
-        return @. f.outer(f.inner(x, y))
-    end
-    return @. f(x, args...)
+        ::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T}
+    return __fast_broadcast_impl!(T, similar(x), f, x, args...)
 end

 function __fast_broadcast_impl!(
-        ::Type{LuxCPUDevice}, f::F, x::AbstractArray, args...) where {F <: Function}
-    if maximum(length, (x, args...)) > 100_000
-        @turbo thread=true @. x = f(x, args...)
+        ::Type{T}, f::F, x::AbstractArray, args...) where {F <: Function, T}
+    return __fast_broadcast_impl!(T, x, f, x, args...) # aliased
+end
+
+## CPU -- LV --> FastBroadcast --> Generic Broadcast
+function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray{<:LV_ELTYPES},
+        f::F, x::AbstractArray{<:LV_ELTYPES},
+        args::Union{AbstractArray{<:LV_ELTYPES}, <:LV_ELTYPES}...) where {F <: Function}
+    fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...)
+    if maximum(length, (x, args...)) > THREADING_THRESHOLD
+        @turbo thread=true @. y = f(x, args...)
     else
-        @turbo @. x = f(x, args...)
+        @turbo @. y = f(x, args...)
     end
-    return x
+    return y
 end

-function __fast_broadcast_impl!(
-        ::Type, f::F, x::AbstractArray, args...) where {F <: Function}
-    if __fails_inplace_bcast_gpu(f) && length(args) == 1
-        y = first(args)
-        @. x = f.outer(f.inner(x, y))
+function __fast_broadcast_impl!(::Type{LuxCPUDevice}, y::AbstractArray, f::F,
+        x::AbstractArray, args...) where {F <: Function}
+    fast_scalar_indexing(x) || return __fast_broadcast_impl!(Nothing, y, f, x, args...)
+    if maximum(length, (x, args...)) > THREADING_THRESHOLD
+        @.. thread=true y=f(x, args...)
     else
-        @. x = f(x, args...)
+        @.. y = f(x, args...)
+    end
+    return y
+end
+
+for ffail in (sigmoid_fast ∘ +, swish ∘ +)
+    @eval function __fast_broadcast_impl!(
+            ::Type{T}, y::AbstractArray, f::typeof($ffail), x::AbstractArray, z) where {T}
+        @. y = f.outer(f.inner(x, z))
+        return y
     end
-    return x
+end
+
+function __fast_broadcast_impl!(::Type{T}, y::AbstractArray, f::F,
+        x::AbstractArray, args...) where {F <: Function, T}
+    @. y = f(x, args...)
+    return y
 end
diff --git a/src/impl/fast_activation.jl b/src/impl/fast_activation.jl
index 222fecf4..e5e390ba 100644
--- a/src/impl/fast_activation.jl
+++ b/src/impl/fast_activation.jl
@@ -6,8 +6,8 @@
     return fast_broadcast!(σ, x)
 end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fast_activation_impl!!), σ::F, x::AbstractArray{T}) where {F, T}
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fast_activation_impl!!),
+        σ::F, x::AbstractArray{T}) where {F, T}
     σ === identity && return x, @closure(Δ->(NoTangent(), NoTangent(), Δ))

     if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
diff --git a/src/impl/fused_conv.jl b/src/impl/fused_conv.jl
index 29c747e0..ae53f744 100644
--- a/src/impl/fused_conv.jl
+++ b/src/impl/fused_conv.jl
@@ -121,8 +121,8 @@ end
     return __conv_bias_act(x, weight, cdims, bias, act)
 end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fused_conv_bias_activation_impl),
+function CRC.rrule(
+        cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_conv_bias_activation_impl),
         act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
         bias::Optional{<:AbstractArray}, cdims::ConvDims) where {wT, xT, N, F}
     T = __get_concrete_fba_output_eltype(act, weight, x, bias)
diff --git a/src/impl/fused_dense.jl b/src/impl/fused_dense.jl
index 8be559a7..1d87f8be 100644
--- a/src/impl/fused_dense.jl
+++ b/src/impl/fused_dense.jl
@@ -31,9 +31,10 @@ end
     return __apply_bias_activation!!(act, y, b, Val(false))
 end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__fused_dense_bias_activation_impl), act::F, weight::AbstractMatrix,
-        x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F}
+function CRC.rrule(
+        cfg::RuleConfig{>:HasReverseMode}, ::typeof(__fused_dense_bias_activation_impl),
+        act::F, weight::AbstractMatrix, x::AbstractMatrix,
+        b::Optional{<:AbstractVector}) where {F}
     T = __get_concrete_fba_output_eltype(act, weight, x, b)

     # Case I: Activation Function doesn't require caching the intermediate value
diff --git a/src/utils.jl b/src/utils.jl
index 7dbdb58e..ed379e67 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,5 +1,78 @@
+const THREADING_THRESHOLD = 100_000
+
+const LV_ELTYPES = Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64,
+    Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit}
+
+const Optional{T} = Union{Nothing, T}
+
+# Bias Gradient -- can't be used inside gradient rules
+__added_bias_gradient(::Nothing, Δ::AbstractArray) = NoTangent()
+__added_bias_gradient(b::AbstractArray, Δ::AbstractArray) = __reduce_sum(b, Δ)
+
+# Common Activation Gradient
+function __activation_gradient(Δ, out, act::F, x) where {F}
+    only_deriv = @closure (oᵢ, xᵢ) -> only_derivative(oᵢ, act, xᵢ)
+    if fast_scalar_indexing(out) && eltype(out) <: LV_ELTYPES
+        return @turbo @. Δ * only_deriv(out, x)
+    end
+    return @. Δ * only_deriv(out, x)
+end
+
+## Needed for reverse over reverse mode AD
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode},
+        ::typeof(__activation_gradient), Δ, out, act::F, x) where {F}
+    return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x)
+end
+
+function __activation_gradient_simple(Δ, out, act::F, x) where {F}
+    return @. Δ * only_derivative(out, act, x)
+end
+
+# Operations that most AD won't be able to differentiate
+## If possible, we use loop vectorization for faster CPU operations
+function __reduce_sum(x::AbstractArray, y::AbstractArray)
+    return __reduce_sum(get_device_type((x, y)), x, y)
+end
+function __reduce_sum(::Type{<:LuxCPUDevice}, x::AbstractArray, y::AbstractArray)
+    if fast_scalar_indexing(x) && fast_scalar_indexing(y) && ndims(x) == 1
+        @assert length(x) == size(y, 1)
+        z, y_ = vmap(zero, x), reshape(y, length(x), :)
+        @turbo for i in eachindex(z), j in axes(y_, 2)
+            z[i] += y_[i, j]
+        end
+        return z
+    end
+    return __reduce_sum(Nothing, x, y)
+end
+__reduce_sum(::Type{T}, x::AbstractArray, y::AbstractArray) where {T} = sum!(similar(x), y)
+
+# Simple Operations -- no rrules needed
 @generated _vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x

+_reshape_into_proper_shape(::Nothing, y) = nothing
+_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x)))
+
+## Maybe typecast the array
+_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x
+_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x)
+_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing
+
+__materialize_subarray(x::AbstractArray) = x
+__materialize_subarray(x::SubArray) = copy(x)
+
+__value(x::Number) = x
+__value(x::AbstractArray) = x
+__value(::Type{T}) where {T <: Number} = T
+
+__value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
+__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
+__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T)
+
+__value(::Nothing) = nothing
+
+__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl
+
+# Non-differentiable functions
 @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N}
     if ly == sx[N - 1]
         return ntuple(i -> i == N - 1 ? ly : 1, N)
@@ -12,34 +85,35 @@ end
 CRC.@non_differentiable _get_reshape_dims(::Any...)
 EnzymeRules.inactive_noinl(::typeof(_get_reshape_dims), ::Any...) = nothing

-_reshape_into_proper_shape(::Nothing, y) = nothing
-_reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(x)))
+## Reduce BLAS threads if we are going to use a native Julia implementation
+function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int
+    if fast_scalar_indexing(x)
+        old_threads = BLAS.get_num_threads()
+        BLAS.set_num_threads(1)
+        return old_threads
+    end
+    return -1
+end
+
+CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
+EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing

-# Copy and don't allow gradient propagation
+function __reset_BLAS_threads(old_threads::Int)
+    old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
+    return nothing
+end
+
+CRC.@non_differentiable __reset_BLAS_threads(::Int)
+EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing
+
+## Copy and don't allow gradient propagation
 _copy_autodiff_barrier(x) = copy(__value(x))
 _copy_autodiff_barrier(::Nothing) = nothing

 CRC.@non_differentiable _copy_autodiff_barrier(::Any)
 EnzymeRules.inactive_noinl(::typeof(_copy_autodiff_barrier), ::Any...) = nothing

-# Meta Programming Utilities
-__is_tracked(x) = x == :TrackedArray || x == :TrackedVector
-__is_tracked(args...) = any(__is_tracked, args)
-
-# Maybe typecast the array
-_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x
-_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x)
-_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing
-
-## This part is taken from NNlib.jl
-# This just saves typing `only.(only.(` many times:
-only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x)))
-
-# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
-# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
-struct NotaNumber <: Real end
-
-# Check no setindexing
+## Check no setindexing
 __is_immutable_array(x::AbstractArray) = !ArrayInterface.can_setindex(x)
 __is_immutable_array(::Nothing) = false
 __is_immutable_array_val(x) = Val(__is_immutable_array(x))
@@ -59,11 +133,6 @@ end
 CRC.@non_differentiable __is_immutable_array_or_dual_val(::Any...)
 EnzymeRules.inactive_noinl(::typeof(__is_immutable_array_or_dual_val), ::Any...) = nothing

-function __expand_conv_bias_dims(bias::AbstractVector, ::AbstractArray{T, N}) where {T, N}
-    @assert N ≥ 2
-    return reshape(bias, (ntuple(Returns(1), N - 2)..., length(bias), 1))
-end
-
 function __get_concrete_fba_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx},
         b::Optional{<:AbstractArray}) where {F, Tw, Tx}
     if b === nothing
@@ -79,89 +148,14 @@ end
 CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...)
 EnzymeRules.inactive_noinl(::typeof(__get_concrete_fba_output_eltype), ::Any...) = nothing

-# Helper to add bias and apply activation function
-## This is only meant to be used inside rrules
-function __apply_bias_activation!!(
-        σ::F, x, bias::Optional{<:AbstractArray}, ::Val{cache}) where {F, cache}
-    if σ === identity
-        bias === nothing && return x
-        return fast_broadcast!(+, x, bias)
-    end
-    if !cache
-        bias === nothing && return fast_broadcast!(σ, x)
-        return fast_broadcast!(σ ∘ +, x, bias)
-    end
-    bias === nothing && return fast_broadcast(σ, x), x
-    x = fast_broadcast!(+, x, bias)
-    return fast_broadcast(σ, x), x
-end
-
-__fails_inplace_bcast_gpu(::ComposedFunction{typeof(sigmoid_fast), typeof(+)}) = true
-__fails_inplace_bcast_gpu(::ComposedFunction{typeof(swish), typeof(+)}) = true
-__fails_inplace_bcast_gpu(::F) where {F} = false
-
-__apply_bias_activation(σ::F, x, bias::AbstractArray) where {F} = @. σ(x + bias)
-__apply_bias_activation(::typeof(identity), x, bias::AbstractArray) = @. x + bias
-__apply_bias_activation(σ::F, x, ::Nothing) where {F} = @. σ(x)
-__apply_bias_activation(::typeof(identity), x, ::Nothing) = x
-
-__added_bias_gradient(::Nothing, _) = NoTangent()
-function __added_bias_gradient(b::AbstractArray, Δ)
-    ∂b = similar(b, promote_type(eltype(b), eltype(Δ)))
-    sum!(∂b, Δ)
-    return ∂b
-end
-
-function __activation_gradient(Δ, out, act::F, x) where {F}
-    only_deriv = @closure (oᵢ, xᵢ) -> only_derivative(oᵢ, act, xᵢ)
-    if ArrayInterface.fast_scalar_indexing(out)
-        return @turbo @. Δ * only_deriv(out, x)
-    end
-    return @. Δ * only_deriv(out, x)
-end
-
-function __activation_gradient_simple(Δ, out, act::F, x) where {F}
-    return @. Δ * only_derivative(out, act, x)
-end
-
-# Needed for reverse over reverse mode AD
-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
-        ::typeof(__activation_gradient), Δ, out, act::F, x) where {F}
-    return CRC.rrule_via_ad(cfg, __activation_gradient_simple, Δ, out, act, x)
-end
-
-# Reduce BLAS threads if we are going to use a native Julia implementation
-function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int
-    if ArrayInterface.fast_scalar_indexing(x)
-        old_threads = BLAS.get_num_threads()
-        BLAS.set_num_threads(1)
-        return old_threads
-    end
-    return -1
-end
-
-CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
-EnzymeRules.inactive_noinl(::typeof(__maybe_reduce_BLAS_threads), ::AbstractArray) = nothing
-
-function __reset_BLAS_threads(old_threads::Int)
-    old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
-    return nothing
-end
-
-CRC.@non_differentiable __reset_BLAS_threads(::Int)
-EnzymeRules.inactive_noinl(::typeof(__reset_BLAS_threads), ::Int) = nothing
-
-__materialize_subarray(x::AbstractArray) = x
-__materialize_subarray(x::SubArray) = copy(x)
-
-__value(x::Number) = x
-__value(x::AbstractArray) = x
-__value(::Type{T}) where {T <: Number} = T
-
-__value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
-__value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
-__value(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = __value(T)
+# Meta Programming Utilities
+__is_tracked(x) = x == :TrackedArray || x == :TrackedVector
+__is_tracked(args...) = any(__is_tracked, args)

-__value(::Nothing) = nothing
+# This part is taken from NNlib.jl
+## This just saves typing `only.(only.(` many times:
+only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x)))

-__aos_to_soa(x::AbstractArray) = x # FIXME: Upstream this to ArrayInterface.jl
+## This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
+## is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
+struct NotaNumber <: Real end
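
Usage sketch (reviewer note, not part of the patch): assuming this change is applied, the newly exported API can be exercised as below on CPU arrays. The array shapes, the `+`/`abs2` functions, and the Zygote-based gradient call are illustrative choices and do not come from the PR itself.

    using LuxLib, Zygote

    x = randn(Float32, 256, 16)
    b = randn(Float32, 256)

    # Out-of-place fused broadcast: computes x .+ b, dispatching to @turbo /
    # FastBroadcast for CPU inputs and to plain broadcasting elsewhere.
    y = fast_broadcast(+, x, b)

    # Maybe-in-place variant: mutates `x` only when it supports `setindex!`,
    # so always rebind the result.
    x = fast_broadcast!!(abs2, x)

    # The new rrules route reverse-mode AD through `broadcast`, so Zygote can
    # differentiate straight through the fast path.
    loss(z) = sum(fast_broadcast(abs2, z))
    g = only(Zygote.gradient(loss, x))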