Skip to content

Commit

Permalink
Extend conversion (JuliaIntervals#686)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierHnt authored Oct 19, 2024
1 parent 429ad16 commit d9735c9
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
6 changes: 5 additions & 1 deletion ext/IntervalArithmeticForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module IntervalArithmeticForwardDiffExt

using IntervalArithmetic, ForwardDiff
using ForwardDiff: Dual, , value, partials
using ForwardDiff: Dual, Partials, , value, partials

#

Expand Down Expand Up @@ -70,4 +70,8 @@ function Base.:(^)(x::ExactReal, y::Dual{<:Ty}) where {Ty}
end
end

# resolve ambiguity

Base.convert(::Type{Dual{T,V,N}}, x::ExactReal) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V}))

end
8 changes: 4 additions & 4 deletions src/intervals/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,6 @@ isguaranteed(x::Complex{<:Interval}) = isguaranteed(real(x)) & isguaranteed(imag

isguaranteed(::Number) = false

Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity
# Interval{T}(x) where {T<:NumTypes} = convert(Interval{T}, x)
# Interval{T}(x::Interval{T}) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity

#

"""
Expand Down Expand Up @@ -563,6 +559,10 @@ Base.promote_rule(::Type{T}, ::Type{Interval{S}}) where {T<:AbstractIrrational,S

# conversion

Interval{T}(x::Real) where {T<:NumTypes} = convert(Interval{T}, x)
Interval(x::Real) = Interval{promote_numtype(numtype(x), numtype(x))}(x)
Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity

Base.convert(::Type{Interval{T}}, x::Interval) where {T<:NumTypes} = interval(T, x)

function Base.convert(::Type{Interval{T}}, x::Complex{<:Interval}) where {T<:NumTypes}
Expand Down
6 changes: 5 additions & 1 deletion src/intervals/exact_literals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{ExactReal{S}}) where {T<:Real,S<:

# to BareInterval

BareInterval{T}(x::ExactReal) where {T<:NumTypes} = convert(BareInterval{T}, x)
BareInterval(x::ExactReal) = BareInterval{promote_numtype(numtype(x.value), numtype(x.value))}(x)

Base.convert(::Type{BareInterval{T}}, x::ExactReal) where {T<:NumTypes} = bareinterval(T, x.value)

Base.promote_rule(::Type{BareInterval{T}}, ::Type{ExactReal{S}}) where {T<:NumTypes,S<:Real} =
Expand All @@ -105,8 +108,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{Interval{S}}) where {T<:Real,S<:N

# to Real

# allows Interval{<:NumTypes}(::ExactReal)
(::Type{T})(x::ExactReal) where {T<:Real} = convert(T, x)
Interval{T}(x::ExactReal) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve ambiguity
Interval(x::ExactReal) = Interval{promote_numtype(numtype(x.value), numtype(x.value))}(x) # needed to resolve ambiguity

Base.convert(::Type{T}, x::ExactReal) where {T<:Real} = convert(T, x.value)

Expand Down
6 changes: 3 additions & 3 deletions test/interval_tests/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ end
@test_throws MethodError BareInterval(1, 2)
@test_throws MethodError BareInterval{Float64}(1, 2)

@test_throws MethodError Interval(1)
@test_throws MethodError Interval{Float64}(1)
@test !isguaranteed(Interval(1))
@test !isguaranteed(Interval{Float64}(1))
@test_throws MethodError Interval(1, 2)
@test_throws MethodError Interval{Float64}(1, 2)

Expand Down Expand Up @@ -163,7 +163,7 @@ end
i = interval(IS.Interval(0.1, 2))
@test isequal_interval(i, interval(0.1, 2.)) && !isguaranteed(i)
@test interval(Float64, IS.Interval(0.1, 2)) === i

i = interval(IS.iv"[0.1, Inf)")
@test isequal_interval(i, interval(0.1, Inf)) && !isguaranteed(i)
@test interval(IS.iv"[0.1, Inf]") === nai(Float64)
Expand Down
12 changes: 6 additions & 6 deletions test/interval_tests/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ end
(t) = ForwardDiff.derivative(ψ, t)
ddψ(t) = ForwardDiff.derivative(dψ, t)
dddψ(t) = ForwardDiff.derivative(ddψ, t)
@test ψ′(0) === (0) && !isguaranteed(ψ′(0))
@test_broken ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0)) # rely on `Interval{T}(::Real)` being defined
@test_broken ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0)) # rely on `Interval{T}(::Real)` being defined
@test ψ′(0) === (0) && !isguaranteed(ψ′(0))
@test ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0))
@test ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0))
t₀ = interval(0)
@test ψ′(t₀) === (t₀) && isguaranteed(ψ′(t₀))
@test ψ′′(t₀) === ddψ(t₀) && isguaranteed(ψ′′(t₀))
Expand All @@ -73,14 +73,14 @@ end
@test isguaranteed(dfdy)
@test isguaranteed(grad[1])
@test isguaranteed(grad[2])

if iszero(x) && y < 0
@test decoration(dfdx) == trv
else
@test in_interval(ForwardDiff.derivative(fx, x), dfdx)
end

if iszero(x) && y <= 0
if iszero(x) && y <= 0
@test decoration(dfdy) == trv
else
@test in_interval(ForwardDiff.derivative(fy, y), dfdy)
Expand All @@ -104,4 +104,4 @@ end
@exact g(x) = 2^x + 6sin(x^3) - 33
@test isguaranteed(ForwardDiff.derivative(f, interval(1)))
end
end
end

0 comments on commit d9735c9

Please sign in to comment.