Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
chore: format suggestion
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] committed Jul 15, 2024
1 parent e71d55f commit 5c287fd
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 20 deletions.
6 changes: 2 additions & 4 deletions src/api/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!
return CRC.rrule_via_ad(cfg, fast_broadcast!!, σ, x)
end


"""
fast_broadcast!!(f::F, x::AbstractArray, args...) where {F}
Expand All @@ -50,12 +49,11 @@ called inside a differentiated function.
end

# Needed for Zygote type-stability
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!), args...)
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!), args...)
return CRC.rrule_via_ad(cfg, _fast_broadcast!!, args...)
end

function _fast_broadcast!!(
f::F, x::AbstractArray, args...) where {F <: Function}
function _fast_broadcast!!(f::F, x::AbstractArray, args...) where {F <: Function}
return _fast_broadcast!!(Val(ArrayInterface.can_setindex(typeof(x))), f, x, args...)
end

Expand Down
13 changes: 6 additions & 7 deletions src/api/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,19 @@ reallocations by reusing the output buffer for multiple operations.
return _fused_conv_bias_activation(args...)
end

# Needed for Zygote type-stability
function CRC.rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv_bias_activation), args...)
return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, args...)
end

function _fused_conv_bias_activation(
σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N},
b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N}
return _fused_conv_bias_activation(
σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b, cdims)
end

# Needed for Zygote type-stability
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fused_conv_bias_activation),
σ::F, weight::AbstractArray{<:Number, N}, x::AbstractArray{<:Number, N},
b::Optional{<:AbstractArray{<:Number, N}}, cdims::ConvDims) where {F, N}
return CRC.rrule_via_ad(cfg, _fused_conv_bias_activation, σ, weight, x, b, cdims)
end

for (check, fop) in (
(false, :_fused_conv_bias_activation_impl), (true, :_generic_conv_bias_activation))
@eval function _fused_conv_bias_activation(
Expand Down
13 changes: 6 additions & 7 deletions src/api/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,18 @@ multiple operations.
return _fused_dense_bias_activation(args...)
end

# Needed for Zygote type-stability
function CRC.rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), args...)
return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, args...)
end

function _fused_dense_bias_activation::F, weight::AbstractMatrix, x::AbstractMatrix,
b::Optional{<:AbstractVector}) where {F}
return _fused_dense_bias_activation(
σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b)
end

# Needed for Zygote type-stability
function CRC.rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_dense_bias_activation), σ::F,
weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F}
return CRC.rrule_via_ad(cfg, _fused_dense_bias_activation, σ, weight, x, b)
end

for (check, fop) in (
(false, :__fused_dense_bias_activation_impl), (true, :__generic_dense_bias_activation))
@eval function _fused_dense_bias_activation(
Expand Down
2 changes: 1 addition & 1 deletion src/impl/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end

# Special Cases where we don't need to go down the generic path
## rrule for activation functions -- we need to define this on `fast_broadcast!!`
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_broadcast!!),
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!!),
f::F, x::AbstractArray{T}) where {F, T}
σ === identity && return x, @closure->(NoTangent(), NoTangent(), Δ))

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ _reshape_into_proper_shape(x, y) = reshape(x, _get_reshape_dims(size(y), length(

## Maybe typecast the array
_ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x
_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x)
_ofeltype_array(::Type{T}, x::AbstractArray) where {T} = convert(AbstractArray{T}, x)
_ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing

__materialize_subarray(x::AbstractArray) = x
Expand Down

0 comments on commit 5c287fd

Please sign in to comment.