Skip to content

Commit

Permalink
restrict all arrays to eltype Number
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 29, 2021
1 parent b0da01f commit be61a4e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
return y, ∂y
end

function rrule(::typeof(norm), x::AbstractArray, p::Real)
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
y = LinearAlgebra.norm(x, p)
function norm_pullback_p(Δy)
∂x = if isempty(x) || p == 0
Expand Down Expand Up @@ -48,7 +48,7 @@ function rrule(::typeof(norm), x::AbstractArray, p::Real)
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, norm_pullback_p
end
function rrule(::typeof(norm), x::AbstractArray)
function rrule(::typeof(norm), x::AbstractArray{<:Number})
y = LinearAlgebra.norm(x)
function norm_pullback_2(Δy)
∂x = if isempty(x)
Expand All @@ -64,7 +64,7 @@ function rrule(::typeof(norm), x::AbstractArray)
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
return y, norm_pullback_2
end
function rrule(::typeof(norm), x::Union{LinearAlgebra.AdjOrTransAbsVec}, p::Real)
function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real)
y, inner_pullback = rrule(norm, parent(x), p)
function norm_pullback(Δy)
(∂self, ∂x′, ∂p) = inner_pullback(Δy)
Expand Down Expand Up @@ -93,7 +93,7 @@ end
##### `normp`
#####

function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray, p)
function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p)
y = LinearAlgebra.normp(x, p)
function normp_pullback(Δy)
∂x = @thunk _normp_back_x(x, p, y, Δy)
Expand Down Expand Up @@ -138,14 +138,14 @@ end
##### `normMinusInf`/`normInf`
#####

function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray)
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normMinusInf(x)
normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
return y, normMinusInf_pullback
end

function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray)
function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normInf(x)
normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normInf_pullback(::Zero) = (NO_FIELDS, Zero())
Expand All @@ -169,7 +169,7 @@ end
##### `norm1`
#####

function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number})
y = LinearAlgebra.norm1(x)
norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm1_back(x, y, Δy)),
Expand Down Expand Up @@ -201,7 +201,7 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
return y, _norm2_forward(x, Δx, y)
end

function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray)
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
y = LinearAlgebra.norm2(x)
norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm2_back(x, y, Δy)),
Expand Down Expand Up @@ -233,7 +233,7 @@ end
##### `normalize`
#####

function rrule(::typeof(normalize), x::AbstractVector, p::Real)
function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
nrm, inner_pullback = rrule(norm, x, p)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
Expand All @@ -248,7 +248,7 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real)
normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, normalize_pullback
end
function rrule(::typeof(normalize), x::AbstractVector)
function rrule(::typeof(normalize), x::AbstractVector{<:Number})
nrm = LinearAlgebra.norm2(x)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
Expand Down

0 comments on commit be61a4e

Please sign in to comment.