This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit e71d55f

fix: bypass dispatch doctor in the reverse pass
avik-pal committed Jul 15, 2024
1 parent 08ef2af commit e71d55f
Showing 11 changed files with 119 additions and 52 deletions.
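The changes to the files under src/api all follow a single refactor: each public entry point keeps its DispatchDoctor `@stable` annotation but becomes a thin forwarder to an underscore-prefixed implementation, and a `ChainRulesCore.rrule` hands reverse-mode AD to that implementation via `rrule_via_ad`, so the dispatch-doctor instrumentation only runs in the forward pass. A minimal, self-contained sketch of the pattern (the `my_op`/`_my_op` pair is hypothetical, not part of LuxLib):

using ChainRulesCore: ChainRulesCore as CRC, RuleConfig, HasReverseMode
using DispatchDoctor: @stable

# Public API: DispatchDoctor checks type stability of the forward pass only.
@stable default_mode="warn" function my_op(args...)
    return _my_op(args...)
end

# Private implementation carries the real method signatures.
_my_op(x::AbstractArray, scale::Real) = scale .* abs2.(x)

# Reverse pass: differentiate the private implementation directly, bypassing the
# instrumented public wrapper (the "Needed for Zygote type-stability" note
# repeated throughout the diff).
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(my_op), args...)
    return CRC.rrule_via_ad(cfg, _my_op, args...)
end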
12 changes: 10 additions & 2 deletions src/api/batchnorm.jl
@@ -37,8 +37,16 @@ fallback is used which is not highly optimized.
 training by reducing internal covariate shift." International conference on machine
 learning. PMLR, 2015.
 """
-@stable default_mode="warn" function batchnorm(
-    x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
+@stable default_mode="warn" function batchnorm(args...)
+    return _batchnorm(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm), args...)
+    return CRC.rrule_via_ad(cfg, _batchnorm, args...)
+end
+
+function _batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
     bias::Optional{<:AbstractVector}, running_mean::Optional{<:AbstractVector},
     running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity,
    momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N}

24 changes: 20 additions & 4 deletions src/api/broadcast.jl
@@ -29,6 +29,13 @@ generic implementation.
     return fast_broadcast!!(σ, x)
 end
 
+## bypass a type instability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!),
+    σ::F, x::AbstractArray{T}) where {F, T}
+    return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x)
+end
+
+
 """
     fast_broadcast!!(f::F, x::AbstractArray, args...) where {F}
@@ -38,16 +45,25 @@ if `x` is an immutable array, it computes `@. f(x, args...)`. Otherwise, it comp
 Additionally, whether `x` is updated in-place, depends on whether this function is being
 called inside a differentiated function.
 """
-@stable default_mode="warn" function fast_broadcast!!(
+@stable default_mode="warn" function fast_broadcast!!(args...)
+    return _fast_broadcast!!(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), args...)
+    return CRC.rrule_via_ad(cfg, _fast_broadcast!!, args...)
+end
+
+function _fast_broadcast!!(
     f::F, x::AbstractArray, args...) where {F <: Function}
-    return fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...)
+    return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...)
 end
 
-function fast_broadcast!!(
+function _fast_broadcast!!(
     ::Val{true}, f::F, x::AbstractArray, args...) where {F <: Function}
     return _fast_broadcast!(f, x, args...)
 end
-function fast_broadcast!!(
+function _fast_broadcast!!(
     ::Val{false}, f::F, x::AbstractArray, args...) where {F <: Function}
     return _fast_broadcast(f, x, args...)
 end
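
As an aside for readers following the `Val` dispatch above: `ArrayInterface.can_setindex` is what selects between the in-place `_fast_broadcast!` and the allocating `_fast_broadcast` paths. A small illustrative check (the value shown is the expected one for ordinary `Array`s):

using ArrayInterface

# Mutable arrays support setindex! and take the in-place `_fast_broadcast!` path;
# immutable array types (static arrays, ranges, ...) are expected to return false
# and fall back to the allocating `_fast_broadcast`.
ArrayInterface.can_setindex(Vector{Float32})  # true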
17 changes: 14 additions & 3 deletions src/api/conv.jl
@@ -27,16 +27,27 @@ reallocations by reusing the output buffer for multiple operations.
 - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision,
   with a warning.
 """
-@stable default_mode="warn" function fused_conv_bias_activation(
+@stable default_mode="warn" function fused_conv_bias_activation(args...)
+    return _fused_conv_bias_activation(args...)
+end
+
+function _fused_conv_bias_activation(
     σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N},
     b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N}
-    return fused_conv_bias_activation(
+    return _fused_conv_bias_activation(
        σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims)
 end
 
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fused_conv_bias_activation),
+    σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N},
+    b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N}
+    return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, σ, weight, x, b, cdims)
+end
+
 for (check, fop) in (
     (false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation))
-    @eval function fused_conv_bias_activation(
+    @eval function _fused_conv_bias_activation(
         σ::F, ::Val{$(check)}, weight::AbstractArray{<:Number, N},
         x::AbstractArray{<:Number, N},
         b::Optional{<:AbstractArray}, cdims::ConvDims) where {F, N}

18 changes: 14 additions & 4 deletions src/api/dense.jl
@@ -26,16 +26,26 @@ multiple operations.
 fallback to the generic implementation.
 - For CUDA Arrays, this uses a special fused implementation via cuBLASLt.
 """
-@stable default_mode="warn" function fused_dense_bias_activation(
-    σ::F, weight::AbstractMatrix, x::AbstractMatrix,
+@stable default_mode="warn" function fused_dense_bias_activation(args...)
+    return _fused_dense_bias_activation(args...)
+end
+
+function _fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix,
     b::Optional{<:AbstractVector}) where {F}
-    return fused_dense_bias_activation(
+    return _fused_dense_bias_activation(
        σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b)
 end
 
+# Needed for Zygote type-stability
+function CRC.rrule(
+    cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), σ::F,
+    weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F}
+    return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, σ, weight, x, b)
+end
+
 for (check, fop) in (
     (false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation))
-    @eval function fused_dense_bias_activation(
+    @eval function _fused_dense_bias_activation(
         σ::F, ::Val{$(check)}, weight::AbstractMatrix,
         x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F}
         return $(fop)(σ, weight, x, b)

45 changes: 27 additions & 18 deletions src/api/dropout.jl
@@ -29,33 +29,39 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
 [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
 overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
 """
-@stable default_mode="warn" function dropout(
+@stable default_mode="warn" function dropout(args...)
+    return _dropout(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(dropout), args...)
+    return CRC.rrule_via_ad(cfg, _dropout, args...)
+end
+
+function _dropout(
     rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T}
     rng = LuxCore.replicate(rng)
     mask = _generate_dropout_mask(rng, x, p, invp; dims)
     return (x .* CRC.ignore_derivatives(mask), mask, rng)
 end
 
-@stable default_mode="warn" function dropout(
+function _dropout(
     rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T}
     return (x, x, rng)
 end
 
-@stable default_mode="warn" function dropout(
-    rng::AbstractRNG, x::AbstractArray, ::AbstractArray,
+function _dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray,
     p::T, t::Val, ::Val{true}, invp::T, dims) where {T}
     return dropout(rng, x, p, t, invp, dims)
 end
 
-@stable default_mode="warn" function dropout(
-    rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
+function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
     p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N}
     size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims)
     return x .* CRC.ignore_derivatives(mask), mask, rng
 end
 
-@stable default_mode="warn" function dropout(
-    rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
+function _dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
     p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N}
     return (x, mask, rng)
 end
@@ -89,27 +95,30 @@ for a fixed dropout probability.
 [1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural
 information processing systems 30 (2017).
 """
-@stable default_mode="warn" function alpha_dropout(
-    rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T}
+@stable default_mode="warn" function alpha_dropout(args...)
+    return _alpha_dropout(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(alpha_dropout), args...)
+    return CRC.rrule_via_ad(cfg, _alpha_dropout, args...)
+end
+
+function _alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T}
     α = T(-1.7580993408473766)
     A = T(inv(sqrt((1 - p) * (1 + p * α^2))))
     B = T(-A * α * p)
     return alpha_dropout(rng, x, p, t, α, A, B)
 end
 
-@stable default_mode="warn" function alpha_dropout(
-    rng::AbstractRNG, x::AbstractArray, p, t::Val{false})
+function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false})
     return alpha_dropout(rng, x, p, t, 0, 0, 0)
 end
 
-@stable default_mode="warn" function alpha_dropout(
-    rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B)
+function _alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B)
     noise, rng = _alpha_dropout_noise(rng, x)
     y = _alpha_dropout_kernel(noise, p, x, α)
     return broadcast(muladd, A, y, B), rng
 end
 
-@stable default_mode="warn" function alpha_dropout(
-    rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B)
-    return (x, rng)
-end
+_alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng)
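
The constants hard-coded in `_alpha_dropout` are not arbitrary; they follow from the SELU parameters in the Klambauer et al. reference cited above. A quick sanity check (the SELU constants are the standard published values; the remaining lines mirror the code):

# SELU constants from Klambauer et al. (2017)
λ = 1.0507009873554805
α_selu = 1.6732632423543772

# Dropped entries are set to α = -λ * α_selu, and (A, B) rescale the result so the
# output keeps zero mean and unit variance for dropout probability p.
α = -λ * α_selu                         # ≈ -1.7580993408473766, as hard-coded above
p = 0.3
A = inv(sqrt((1 - p) * (1 + p * α^2)))
B = -A * α * p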
12 changes: 10 additions & 2 deletions src/api/groupnorm.jl
@@ -26,8 +26,16 @@ The normalized array is returned.
 [1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference
 on computer vision (ECCV). 2018.
 """
-@stable default_mode="warn" function groupnorm(
-    x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
+@stable default_mode="warn" function groupnorm(args...)
+    return _groupnorm(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm), args...)
+    return CRC.rrule_via_ad(cfg, _groupnorm, args...)
+end
+
+function _groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
     bias::Optional{<:AbstractVector}, groups::Int,
     σ::F=identity, epsilon::Real=1.0f-5) where {F, N}
     _test_valid_groupnorm_arguments(x, scale, bias, groups)

12 changes: 10 additions & 2 deletions src/api/instancenorm.jl
@@ -26,8 +26,16 @@ mean and variance.
 [1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The
 missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
 """
-@stable default_mode="warn" function instancenorm(
-    x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
+@stable default_mode="warn" function instancenorm(args...)
+    return _instancenorm(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(instancenorm), args...)
+    return CRC.rrule_via_ad(cfg, _instancenorm, args...)
+end
+
+function _instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
     bias::Optional{<:AbstractVector}, training::Val,
     σ::F=identity, epsilon::Real=1.0f-5) where {N, F}
     _test_valid_instancenorm_arguments(x)

11 changes: 10 additions & 1 deletion src/api/layernorm.jl
@@ -29,7 +29,16 @@ Normalized Array of same size as `x`.
 [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv
 preprint arXiv:1607.06450 (2016).
 """
-@stable default_mode="warn" function layernorm(
+@stable default_mode="warn" function layernorm(args...)
+    return _layernorm(args...)
+end
+
+# Needed for Zygote type-stability
+function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(layernorm), args...)
+    return CRC.rrule_via_ad(cfg, _layernorm, args...)
+end
+
+function _layernorm(
     x::AbstractArray{<:Number, N}, scale::Optional{<:AbstractArray{<:Number, N}},
     bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity,
     dims=Colon(), epsilon::Real=1.0f-5) where {N, F}

6 changes: 0 additions & 6 deletions src/impl/broadcast.jl
@@ -103,9 +103,3 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!)
 
     return CRC.rrule_via_ad(cfg, broadcast, f, x)
 end
-
-## bypass a type instability
-function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!),
-    σ::F, x::AbstractArray{T}) where {F, T}
-    return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x)
-end
7 changes: 2 additions & 5 deletions test/common_ops/conv_tests.jl
@@ -23,11 +23,8 @@
 # CI timings under check
 # Most of the actual tests happen upstream in Lux
 @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [
-    (Float16, Float16),
-    # (Float32, Float16),
-    (Float32, Float32),
-    # (Float32, Float64),
-    (Float64, Float64)],
+    (Float16, Float16), (Float32, Float16), (Float32, Float32),
+    (Float32, Float64), (Float64, Float64)],
     hasbias in (true, false),
     activation in (identity, tanh, tanh_fast, sigmoid,
         sigmoid_fast, relu, gelu, anonact, swish),

7 changes: 2 additions & 5 deletions test/common_ops/dense_tests.jl
@@ -5,11 +5,8 @@
 # These are not all possible combinations but rather a representative set to keep
 # CI timings under check
 @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [
-    (Float16, Float16),
-    # (Float32, Float16),
-    (Float32, Float32),
-    # (Float32, Float64),
-    (Float64, Float64)]
+    (Float16, Float16), (Float32, Float16), (Float32, Float32),
+    (Float32, Float64), (Float64, Float64)]
     for M in (4, 8),
         N in (4, 8),
         hasbias in (true, false),
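
Outside the test suite, the commit's stated goal, a type-stable Zygote reverse pass, can be spot-checked by hand. A rough sketch, assuming `LuxLib`, `Zygote`, and `Test` are available and that `fused_dense_bias_activation` is exported as in the diff above:

using LuxLib, Zygote, Test

w, x, b = randn(Float32, 4, 3), randn(Float32, 3, 2), randn(Float32, 4)

# The forward pass still goes through the @stable wrapper; the new rrule routes
# the reverse pass to the un-instrumented implementation.
y, back = Zygote.pullback(fused_dense_bias_activation, tanh, w, x, b)
@inferred back(one.(y))  # expected to infer once the dispatch doctor is bypassed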
