Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make normalize work for Numbers #49342

Merged
merged 3 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1804,21 +1804,18 @@ function normalize!(a::AbstractArray, p::Real=2)
__normalize!(a, nrm)
end

@inline function __normalize!(a::AbstractArray, nrm::Real)
@inline function __normalize!(a::AbstractArray, nrm)
# The largest positive floating point number whose inverse is less than infinity
δ = inv(prevfloat(typemax(nrm)))

if nrm ≥ δ # Safe to multiply with inverse
invnrm = inv(nrm)
rmul!(a, invnrm)

else # scale elements to avoid overflow
εδ = eps(one(nrm))/δ
rmul!(a, εδ)
rmul!(a, inv(nrm*εδ))
end

a
return a
end

"""
Expand Down
28 changes: 4 additions & 24 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ using .Main.Quaternions
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
using .Main.OffsetArrays

isdefined(Main, :DualNumbers) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "DualNumbers.jl"))
using .Main.DualNumbers

Random.seed!(123)

Expand Down Expand Up @@ -78,30 +80,7 @@ n = 5 # should be odd
end

@testset "det with nonstandard Number type" begin
struct MyDual{T<:Real} <: Real
val::T
eps::T
end
Base.:+(x::MyDual, y::MyDual) = MyDual(x.val + y.val, x.eps + y.eps)
Base.:*(x::MyDual, y::MyDual) = MyDual(x.val * y.val, x.eps * y.val + y.eps * x.val)
Base.:/(x::MyDual, y::MyDual) = x.val / y.val
Base.:(==)(x::MyDual, y::MyDual) = x.val == y.val && x.eps == y.eps
Base.zero(::MyDual{T}) where {T} = MyDual(zero(T), zero(T))
Base.zero(::Type{MyDual{T}}) where {T} = MyDual(zero(T), zero(T))
Base.one(::MyDual{T}) where {T} = MyDual(one(T), zero(T))
Base.one(::Type{MyDual{T}}) where {T} = MyDual(one(T), zero(T))
# the following line is required for BigFloat, IDK why it doesn't work via
# promote_rule like for all other types
Base.promote_type(::Type{MyDual{BigFloat}}, ::Type{BigFloat}) = MyDual{BigFloat}
Base.promote_rule(::Type{MyDual{T}}, ::Type{S}) where {T,S<:Real} =
MyDual{promote_type(T, S)}
Base.promote_rule(::Type{MyDual{T}}, ::Type{MyDual{S}}) where {T,S} =
MyDual{promote_type(T, S)}
Base.convert(::Type{MyDual{T}}, x::MyDual) where {T} =
MyDual(convert(T, x.val), convert(T, x.eps))
if elty <: Real
@test det(triu(MyDual.(A, zero(A)))) isa MyDual
end
elty <: Real && @test det(Dual.(triu(A), zero(A))) isa Dual
end
end

Expand Down Expand Up @@ -390,6 +369,7 @@ end
[1.0 2.0 3.0; 4.0 5.0 6.0], # 2-dim
rand(1,2,3), # higher dims
rand(1,2,3,4),
Dual.(randn(2,3), randn(2,3)),
OffsetArray([-1,0], (-2,)) # no index 1
)
@test normalize(arr) == normalize!(copy(arr))
Expand Down
46 changes: 46 additions & 0 deletions test/testhelpers/DualNumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module DualNumbers

export Dual

# Dual numbers type with minimal interface
# example of a (real) number type that subtypes Number, but not Real.
# Can be used to test generic linear algebra functions.

struct Dual{T<:Real} <: Number
val::T
eps::T
end
Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.eps + y.eps)
Base.:-(x::Dual, y::Dual) = Dual(x.val - y.val, x.eps - y.eps)
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.eps * y.val + y.eps * x.val)
Base.:*(x::Number, y::Dual) = Dual(x*y.val, x*y.eps)
Base.:*(x::Dual, y::Number) = Dual(x.val*y, x.eps*y)
Base.:/(x::Dual, y::Dual) = Dual(x.val / y.val, (x.eps*y.val - x.val*y.eps)/(y.val*y.val))

Base.:(==)(x::Dual, y::Dual) = x.val == y.val && x.eps == y.eps

Base.promote_rule(::Type{Dual{T}}, ::Type{T}) where {T} = Dual{T}
Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T,S<:Real} = Dual{promote_type(T, S)}
Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T,S} = Dual{promote_type(T, S)}

Base.convert(::Type{Dual{T}}, x::Dual{T}) where {T} = x
Base.convert(::Type{Dual{T}}, x::Dual) where {T} = Dual(convert(T, x.val), convert(T, x.eps))
Base.convert(::Type{Dual{T}}, x::Real) where {T} = Dual(convert(T, x), zero(T))

Base.float(x::Dual) = Dual(float(x.val), float(x.eps))
# the following two methods are needed for normalize (to check for potential overflow)
Base.typemax(x::Dual) = Dual(typemax(x.val), zero(x.eps))
Base.prevfloat(x::Dual{<:AbstractFloat}) = prevfloat(x.val)
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved

Base.abs2(x::Dual) = x*x
Base.abs(x::Dual) = sqrt(abs2(x))
Base.sqrt(x::Dual) = Dual(sqrt(x.val), x.eps/(2sqrt(x.val)))

Base.isless(x::Dual, y::Dual) = x.val < y.val
Base.isless(x::Real, y::Dual) = x < y.val
Base.isinf(x::Dual) = isinf(x.val) & isfinite(x.eps)
Base.real(x::Dual) = x # since we curently only consider Dual{<:Real}

end # module