Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: group norm kernel implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 20, 2024
1 parent ed5b6d7 commit a85fd65
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/impl/affine_normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,19 @@ function _affine_normalize_gn_impl(opmode::AbstractInternalArrayOpMode, f::F,
return y
end

function __affine_normalize_gn_impl!(
        ::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F, x::AbstractArray{<:Number, 4},
        μ, σ², scale::Optional{<:AbstractArray{<:Number, 4}},
        bias::Optional{<:AbstractArray{<:Number, 4}}, ϵ::Real) where {F}
    # Group-norm affine normalization kernel (CPU looped fallback).
    # Writes y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc) in place, where
    # `_sc`/`_bc` fold the normalization (subtract μ, divide by sqrt(σ² + ϵ))
    # together with the optional affine transform (scale/bias) into a single
    # multiply-add for the innermost loop.
    #
    # Broadcast layout implied by the indexing below:
    #   * μ, σ²       vary over dims 3–4 only (indexed [1, 1, K, L])
    #   * scale, bias vary over dims 2–3 only (indexed [1, J, K, 1])
    # NOTE(review): `scale` and `bias` are assumed to be `nothing` together or
    # arrays together — only `scale` is checked; confirm against callers.
    @fastmath @inbounds @simd ivdep for J in axes(y, 2)
        for K in axes(y, 3), L in axes(y, 4)
            if scale !== nothing
                # Affine path: fold scale into the inverse std-dev, fold bias
                # and the mean shift into the additive term.
                _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ)
                _bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc
            else
                # Plain normalization: scale = 1, bias = 0.
                _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
                _bc = -μ[1, 1, K, L] * _sc
            end
            for I in axes(y, 1)
                y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc)
            end
        end
    end
    return nothing
end

function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F,
x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4},
bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F}
@fastmath @inbounds @simd ivdep for J in axes(y, 2)
for K in axes(y, 3), L in axes(y, 4)
_sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ)
_bc = bias[1, J, K, 1] - μ[1, 1, K, L] * _sc
for I in axes(y, 1)
y[I, J, K, L] = f(x[I, J, K, L] * _sc + _bc)
end
Expand Down Expand Up @@ -167,11 +160,11 @@ end

@inbounds ∂x[i, j, k, l] = ∂y[i, j, k, l] * _sc
@inbounds ∂μ[i, j, k, l] = -∂x[i, j, k, l]
@inbounds ∂σ²[i, j, k, l] -= ∂x[i, j, k, l] */ (2 * denom²)
@inbounds ∂σ²[i, j, k, l] = -∂x[i, j, k, l] */ (2 * denom²)

if scale !== nothing
@inbounds ∂sc[i, j, k, l] += ∂y[i, j, k, l] */ denom
@inbounds ∂b[i, j, k, l] += ∂y[i, j, k, l]
@inbounds ∂sc[i, j, k, l] = ∂y[i, j, k, l] */ denom
@inbounds ∂b[i, j, k, l] = ∂y[i, j, k, l]
end
end

Expand All @@ -182,6 +175,13 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale,
∂sc = scale === nothing ? ∂∅ : similar(scale)
∂b = bias === nothing ? ∂∅ : similar(bias)

fill!(∂μ, false)
fill!(∂σ², false)
if scale !== nothing
fill!(∂sc, false)
fill!(∂b, false)
end

@fastmath @inbounds @simd ivdep for J in axes(∂y, 2)
for K in axes(∂y, 3), L in axes(∂y, 4)
denom = sqrt(σ²[1, 1, K, L] + ϵ)
Expand Down
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function __added_bias_gradient(b::AbstractVector{<:Number}, Δ::AbstractArray{<:
end

# Operations that most AD won't be able to differentiate
# Degenerate AD case: a `nothing` argument paired with no incoming tangent
# contributes no gradient, so hand back `∂∅` (presumably the package's
# no-gradient sentinel — confirm its definition). Keeps `__reduce_sum` total
# over `Optional` arguments so rules need not special-case `nothing`.
function __reduce_sum(::Nothing, ::NoTangent)
    return ∂∅
end
function __reduce_sum(x::AbstractArray, y::AbstractArray)
z = similar(x, promote_type(eltype(x), eltype(y)))
sum!(z, y)
Expand Down

0 comments on commit a85fd65

Please sign in to comment.