diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 910e37e1f..abd422387 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -17,7 +17,7 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real) return y, ∂y end -function rrule(::typeof(norm), x::AbstractArray, p::Real) +function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) ∂x = if isempty(x) || p == 0 @@ -48,7 +48,7 @@ function rrule(::typeof(norm), x::AbstractArray, p::Real) norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, norm_pullback_p end -function rrule(::typeof(norm), x::AbstractArray) +function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) ∂x = if isempty(x) @@ -64,7 +64,7 @@ function rrule(::typeof(norm), x::AbstractArray) norm_pullback_2(::Zero) = (NO_FIELDS, Zero()) return y, norm_pullback_2 end -function rrule(::typeof(norm), x::Union{LinearAlgebra.AdjOrTransAbsVec}, p::Real) +function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real) y, inner_pullback = rrule(norm, parent(x), p) function norm_pullback(Δy) (∂self, ∂x′, ∂p) = inner_pullback(Δy) @@ -93,7 +93,7 @@ end ##### `normp` ##### -function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray, p) +function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p) y = LinearAlgebra.normp(x, p) function normp_pullback(Δy) ∂x = @thunk _normp_back_x(x, p, y, Δy) @@ -138,14 +138,14 @@ end ##### `normMinusInf`/`normInf` ##### -function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normMinusInf(x) normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero()) return y, normMinusInf_pullback end -function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number}) y = LinearAlgebra.normInf(x) normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy)) normInf_pullback(::Zero) = (NO_FIELDS, Zero()) @@ -169,7 +169,7 @@ end ##### `norm1` ##### -function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number}) y = LinearAlgebra.norm1(x) norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk( @thunk(_norm1_back(x, y, Δy)), @@ -201,7 +201,7 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x) return y, _norm2_forward(x, Δx, y) end -function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray) +function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) y = LinearAlgebra.norm2(x) norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk( @thunk(_norm2_back(x, y, Δy)), @@ -233,7 +233,7 @@ end ##### `normalize` ##### -function rrule(::typeof(normalize), x::AbstractVector, p::Real) +function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) nrm, inner_pullback = rrule(norm, x, p) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) @@ -248,7 +248,7 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real) normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero()) return y, normalize_pullback end -function rrule(::typeof(normalize), x::AbstractVector) +function rrule(::typeof(normalize), x::AbstractVector{<:Number}) nrm = LinearAlgebra.norm2(x) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x)