Skip to content

Commit

Permalink
a type-stability fix via parent & re-wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Jan 4, 2021
1 parent df02563 commit 8ab56db
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,22 @@ end

function _normp_back_x(x, p, y, Δy)
c = real(Δy) / y
T = promote_type(eltype(x), typeof(c))
∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability.
map!(∂x, x) do xi
∂x = map(x) do xi
a = norm(xi)
∂xi = xi * ((a / y)^(p - 2) * c)
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
end
return ∂x
end
function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc.
c = real(Δy) / y
∂x_data = map(parent(x)) do xi
a = norm(xi)
∂xi = xi * ((a / y)^(p - 2) * c)
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
end
return withsomezeros_rewrap(x, ∂x_data)
end

function _normp_back_p(x, p, y, Δy)
y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p)
Expand Down Expand Up @@ -175,13 +182,13 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
end

function _norm1_back(x, y, Δy)
T = promote_type(eltype(x), real(eltype(Δy)))
∂x = similar(x, T)
# The reason not to let broadcast allocate ∂x is that NaN .* Diagonal(ones(3)) isa Matrix,
# while pi .* Diagonal(ones(3)) isa Diagonal, hence this would be type-unstable.
∂x .= sign.(x) .* real(Δy)
∂x = sign.(x) .* real(Δy)
return ∂x
end
function _norm1_back(x::WithSomeZeros, y, Δy)
∂x_data = sign.(parent(x)) .* real(Δy)
return withsomezeros_rewrap(x, ∂x_data)
end
function _norm1_back!(∂x, x, y, Δy)
∂x .+= sign.(x) .* real(Δy)
return ∂x
Expand Down Expand Up @@ -211,11 +218,14 @@ function _norm2_forward(x, Δx, y)
return ∂y
end
function _norm2_back(x, y, Δy)
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
∂x = similar(x, T) # same comment as _norm1_back about allocation and type-stability.
∂x .= x .* (real(Δy) * pinv(y))
∂x = x .* (real(Δy) * pinv(y))
return ∂x
end
function _norm2_back(x::WithSomeZeros, y, Δy)
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
∂x_data = parent(x) .* (real(Δy) * pinv(y))
return withsomezeros_rewrap(x, ∂x_data)
end
function _norm2_back!(∂x, x, y, Δy)
∂x .+= x .* (real(Δy) * pinv(y))
return ∂x # must return after mutating
Expand Down
33 changes: 33 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,36 @@ Symmetric
````
"""
_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper

"""
WithSomeZeros{T}
This is a union of LinearAlgebra types, all of which are partly structral zeros,
with a simple backing array given by `parent(x)`. All have methods of `_rewrap`
to re-create.
This exists to solve a type instability, as broadcasting for instance
`λ .* Diagonal(rand(3))` gives a dense matrix when `x==Inf`.
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
"""
WithSomeZeros{T} = Union{
Diagonal{T},
UpperTriangular{T},
UnitUpperTriangular{T},
UpperHessenberg{T},
LowerTriangular{T},
UnitLowerTriangular{T},
}
for S in [
:Diagonal,
:UpperTriangular,
:UnitUpperTriangular,
:UpperHessenberg,
:LowerTriangular,
:UnitLowerTriangular,
]
@eval withsomezeros_rewrap(::$S, x) = $S(x)
end

# Bidiagonal, Tridiagonal have more complicated storage.
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))

0 comments on commit 8ab56db

Please sign in to comment.