Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to rules for norm #337

Merged
merged 22 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.63"
version = "0.7.64"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
149 changes: 90 additions & 59 deletions src/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
y = norm(x)
return y, _norm2_forward(x, Δx, norm(x))
end

function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
y = norm(x, p)
∂y = if iszero(Δx) || iszero(p)
Expand All @@ -17,15 +18,12 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
return y, ∂y
end

function rrule(
::typeof(norm),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
p::Real,
)
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
y = LinearAlgebra.norm(x, p)
function norm_pullback(Δy)
∂x = Thunk() do
return if isempty(x) || p == 0
function norm_pullback_p(Δy)
∂x = InplaceableThunk(
# out-of-place versions
@thunk(if isempty(x) || p == 0
zero.(x) .* (zero(y) * zero(real(Δy)))
elseif p == 2
_norm2_back(x, y, Δy)
Expand All @@ -37,35 +35,52 @@ function rrule(
_normInf_back(x, y, Δy)
else
_normp_back_x(x, p, y, Δy)
end)
, # in-place versions -- can be fixed when actually useful?
dx -> if isempty(x) || p == 0
dx
elseif p == 2
_norm2_back!(dx, x, y, Δy)
elseif p == 1
_norm1_back!(dx, x, y, Δy)
elseif p == Inf
dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
elseif p == -Inf
dx .+= _normInf_back(x, y, Δy)
else
dx .+= _normp_back_x(x, p, y, Δy)
end
end
)
∂p = @thunk _normp_back_p(x, p, y, Δy)
return (NO_FIELDS, ∂x, ∂p)
end
norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, norm_pullback
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
return y, norm_pullback_p
end
function rrule(
::typeof(norm),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)

function rrule(::typeof(norm), x::AbstractArray{<:Number})
y = LinearAlgebra.norm(x)
function norm_pullback(Δy)
∂x = if isempty(x)
zero.(x) .* (zero(y) * zero(real(Δy)))
else
_norm2_back(x, y, Δy)
end
function norm_pullback_2(Δy)
∂x = InplaceableThunk(
@thunk(if isempty(x)
zero.(x) .* (zero(y) * zero(real(Δy)))
else
_norm2_back(x, y, Δy)
end)
,
dx -> if isempty(x)
dx
else
_norm2_back!(dx, x, y, Δy)
end
)
return (NO_FIELDS, ∂x)
end
norm_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm_pullback
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
return y, norm_pullback_2
end
function rrule(
::typeof(norm),
x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec},
p::Real,
)

function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real)
y, inner_pullback = rrule(norm, parent(x), p)
function norm_pullback(Δy)
(∂self, ∂x′, ∂p) = inner_pullback(Δy)
Expand All @@ -75,6 +90,7 @@ function rrule(
end
return y, norm_pullback
end

function rrule(::typeof(norm), x::Number, p::Real)
y = norm(x, p)
function norm_pullback(Δy)
Expand All @@ -94,11 +110,7 @@ end
##### `normp`
#####

function rrule(
::typeof(LinearAlgebra.normp),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
p,
)
function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p)
y = LinearAlgebra.normp(x, p)
function normp_pullback(Δy)
∂x = @thunk _normp_back_x(x, p, y, Δy)
Expand All @@ -111,15 +123,24 @@ end

function _normp_back_x(x, p, y, Δy)
c = real(Δy) / y
∂x = similar(x)
broadcast!(∂x, x) do xi
∂x = map(x) do xi
a = norm(xi)
∂xi = xi * ((a / y)^(p - 2) * c)
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
end
return ∂x
end

function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc.
c = real(Δy) / y
∂x_data = map(parent(x)) do xi
a = norm(xi)
∂xi = xi * ((a / y)^(p - 2) * c)
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
end
return withsomezeros_rewrap(x, ∂x_data)
end

function _normp_back_p(x, p, y, Δy)
y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p)
s = sum(x) do xi
Expand All @@ -135,20 +156,14 @@ end
##### `normMinusInf`/`normInf`
#####

function rrule(
::typeof(LinearAlgebra.normMinusInf),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normMinusInf(x)
normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
return y, normMinusInf_pullback
end

function rrule(
::typeof(LinearAlgebra.normInf),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normInf(x)
normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normInf_pullback(::Zero) = (NO_FIELDS, Zero())
Expand All @@ -172,19 +187,26 @@ end
##### `norm1`
#####

function rrule(
::typeof(LinearAlgebra.norm1),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number})
y = LinearAlgebra.norm1(x)
norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy))
norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm1_back(x, y, Δy)),
dx -> _norm1_back!(dx, x, y, Δy),
))
norm1_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm1_pullback
end

function _norm1_back(x, y, Δy)
∂x = similar(x)
∂x .= sign.(x) .* real(Δy)
∂x = sign.(x) .* real(Δy)
return ∂x
end
function _norm1_back(x::WithSomeZeros, y, Δy)
∂x_data = sign.(parent(x)) .* real(Δy)
return withsomezeros_rewrap(x, ∂x_data)
end
function _norm1_back!(∂x, x, y, Δy)
∂x .+= sign.(x) .* real(Δy)
return ∂x
end

Expand All @@ -197,12 +219,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
return y, _norm2_forward(x, Δx, y)
end

function rrule(
::typeof(LinearAlgebra.norm2),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
y = LinearAlgebra.norm2(x)
norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy))
norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm2_back(x, y, Δy)),
dx -> _norm2_back!(dx, x, y, Δy),
))
norm2_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm2_pullback
end
Expand All @@ -212,16 +234,24 @@ function _norm2_forward(x, Δx, y)
return ∂y
end
function _norm2_back(x, y, Δy)
∂x = similar(x)
∂x .= x .* (real(Δy) * pinv(y))
∂x = x .* (real(Δy) * pinv(y))
return ∂x
end
function _norm2_back(x::WithSomeZeros, y, Δy)
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
∂x_data = parent(x) .* (real(Δy) * pinv(y))
return withsomezeros_rewrap(x, ∂x_data)
end
function _norm2_back!(∂x, x, y, Δy)
∂x .+= x .* (real(Δy) * pinv(y))
return ∂x # must return after mutating
end

#####
##### `normalize`
#####

function rrule(::typeof(normalize), x::AbstractVector, p::Real)
function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
nrm, inner_pullback = rrule(norm, x, p)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
Expand All @@ -236,7 +266,8 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real)
normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, normalize_pullback
end
function rrule(::typeof(normalize), x::AbstractVector)

function rrule(::typeof(normalize), x::AbstractVector{<:Number})
nrm = LinearAlgebra.norm2(x)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
Expand Down
33 changes: 33 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,36 @@ Symmetric
````
"""
_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper

"""
WithSomeZeros{T}

This is a union of LinearAlgebra types, all of which are partly structral zeros,
with a simple backing array given by `parent(x)`. All have methods of `_rewrap`
to re-create.

This exists to solve a type instability, as broadcasting for instance
`λ .* Diagonal(rand(3))` gives a dense matrix when `x==Inf`.
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
"""
WithSomeZeros{T} = Union{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call these StructuredSparseArray

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You approve of the mechanism, #337 (comment)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am willing to give it a shot.
We can always change it later.
It's not going to lead to wrong behavour AFAICT.

It seems unfortunate not to take advantage of the fact that we know where the zeros are,
and we know that the pullback is going to map zeros to zeros, since linear.
So we should be able to skip some.
But idk that that is a generic API for our structurally sparse matrixes to know if an index will be zero.

Copy link
Member Author

@mcabbott mcabbott May 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstand you, but both λ .* Diagonal(rand(3)) and this function do know where the zeros are, and do O(N) work. That's the only really sparse one.

For UpperTriangular, I haven't tried to time this against broadcasting... there could be trade-offs, maybe broadcasting skips half, but if so it needs lots of if statements. Frankly I doubt that anyone has ever called norm(::UpperTriangular) outside a test, though. So perhaps thinking about that can wait until this finds wider use where someone does need to care.

Copy link
Member Author

@mcabbott mcabbott May 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also be good to fix this instability upstream. Can't we argue that the off-diagonal elements are a strong zero like false, and make NaN .* Diagonal(rand(3)) just work? Is there an issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like all structural zeros should be strong yes.
I was sure I had seem julia displaying that behavour on SparseCSC matrixes, but I can't reproduce it right now.

Diagonal{T},
UpperTriangular{T},
UnitUpperTriangular{T},
# UpperHessenberg{T}, # doesn't exist in Julia 1.0
LowerTriangular{T},
UnitLowerTriangular{T},
}
for S in [
:Diagonal,
:UpperTriangular,
:UnitUpperTriangular,
# :UpperHessenberg,
:LowerTriangular,
:UnitLowerTriangular,
]
@eval withsomezeros_rewrap(::$S, x) = $S(x)
end

# Bidiagonal, Tridiagonal have more complicated storage.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))
Loading