From d9735c99ee2ac7558c3fa0af90a62eff798b9af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20H=C3=A9not?= <38465572+OlivierHnt@users.noreply.github.com> Date: Sat, 19 Oct 2024 10:05:56 -0400 Subject: [PATCH] Extend conversion (#686) --- ext/IntervalArithmeticForwardDiffExt.jl | 6 +++++- src/intervals/construction.jl | 8 ++++---- src/intervals/exact_literals.jl | 6 +++++- test/interval_tests/construction.jl | 6 +++--- test/interval_tests/forwarddiff.jl | 12 ++++++------ 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/ext/IntervalArithmeticForwardDiffExt.jl b/ext/IntervalArithmeticForwardDiffExt.jl index 3ab6c7123..daf9bfbd5 100644 --- a/ext/IntervalArithmeticForwardDiffExt.jl +++ b/ext/IntervalArithmeticForwardDiffExt.jl @@ -1,7 +1,7 @@ module IntervalArithmeticForwardDiffExt using IntervalArithmetic, ForwardDiff -using ForwardDiff: Dual, ≺, value, partials +using ForwardDiff: Dual, Partials, ≺, value, partials # @@ -70,4 +70,8 @@ function Base.:(^)(x::ExactReal, y::Dual{<:Ty}) where {Ty} end end +# resolve ambiguity + +Base.convert(::Type{Dual{T,V,N}}, x::ExactReal) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V})) + end diff --git a/src/intervals/construction.jl b/src/intervals/construction.jl index 98457093a..7f19cff24 100644 --- a/src/intervals/construction.jl +++ b/src/intervals/construction.jl @@ -367,10 +367,6 @@ isguaranteed(x::Complex{<:Interval}) = isguaranteed(real(x)) & isguaranteed(imag isguaranteed(::Number) = false -Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity -# Interval{T}(x) where {T<:NumTypes} = convert(Interval{T}, x) -# Interval{T}(x::Interval{T}) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity - # """ @@ -563,6 +559,10 @@ Base.promote_rule(::Type{T}, ::Type{Interval{S}}) where {T<:AbstractIrrational,S # conversion +Interval{T}(x::Real) where {T<:NumTypes} = convert(Interval{T}, x) +Interval(x::Real) = Interval{promote_numtype(numtype(x), numtype(x))}(x) +Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity + Base.convert(::Type{Interval{T}}, x::Interval) where {T<:NumTypes} = interval(T, x) function Base.convert(::Type{Interval{T}}, x::Complex{<:Interval}) where {T<:NumTypes} diff --git a/src/intervals/exact_literals.jl b/src/intervals/exact_literals.jl index 08c5dea01..5f7d939db 100644 --- a/src/intervals/exact_literals.jl +++ b/src/intervals/exact_literals.jl @@ -87,6 +87,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{ExactReal{S}}) where {T<:Real,S<: # to BareInterval +BareInterval{T}(x::ExactReal) where {T<:NumTypes} = convert(BareInterval{T}, x) +BareInterval(x::ExactReal) = BareInterval{promote_numtype(numtype(x.value), numtype(x.value))}(x) + Base.convert(::Type{BareInterval{T}}, x::ExactReal) where {T<:NumTypes} = bareinterval(T, x.value) Base.promote_rule(::Type{BareInterval{T}}, ::Type{ExactReal{S}}) where {T<:NumTypes,S<:Real} = @@ -105,8 +108,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{Interval{S}}) where {T<:Real,S<:N # to Real -# allows Interval{<:NumTypes}(::ExactReal) (::Type{T})(x::ExactReal) where {T<:Real} = convert(T, x) +Interval{T}(x::ExactReal) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve ambiguity +Interval(x::ExactReal) = Interval{promote_numtype(numtype(x.value), numtype(x.value))}(x) # needed to resolve ambiguity Base.convert(::Type{T}, x::ExactReal) where {T<:Real} = convert(T, x.value) diff --git a/test/interval_tests/construction.jl b/test/interval_tests/construction.jl index 81e882302..e53bcc60c 100644 --- a/test/interval_tests/construction.jl +++ b/test/interval_tests/construction.jl @@ -43,8 +43,8 @@ end @test_throws MethodError BareInterval(1, 2) @test_throws MethodError BareInterval{Float64}(1, 2) - @test_throws MethodError Interval(1) - @test_throws MethodError Interval{Float64}(1) + @test !isguaranteed(Interval(1)) + @test !isguaranteed(Interval{Float64}(1)) @test_throws MethodError Interval(1, 2) @test_throws MethodError Interval{Float64}(1, 2) @@ -163,7 +163,7 @@ end i = interval(IS.Interval(0.1, 2)) @test isequal_interval(i, interval(0.1, 2.)) && !isguaranteed(i) @test interval(Float64, IS.Interval(0.1, 2)) === i - + i = interval(IS.iv"[0.1, Inf)") @test isequal_interval(i, interval(0.1, Inf)) && !isguaranteed(i) @test interval(IS.iv"[0.1, Inf]") === nai(Float64) diff --git a/test/interval_tests/forwarddiff.jl b/test/interval_tests/forwarddiff.jl index ae5eff1c4..7fd51ed3f 100644 --- a/test/interval_tests/forwarddiff.jl +++ b/test/interval_tests/forwarddiff.jl @@ -46,9 +46,9 @@ end dψ(t) = ForwardDiff.derivative(ψ, t) ddψ(t) = ForwardDiff.derivative(dψ, t) dddψ(t) = ForwardDiff.derivative(ddψ, t) - @test ψ′(0) === dψ(0) && !isguaranteed(ψ′(0)) - @test_broken ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0)) # rely on `Interval{T}(::Real)` being defined - @test_broken ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0)) # rely on `Interval{T}(::Real)` being defined + @test ψ′(0) === dψ(0) && !isguaranteed(ψ′(0)) + @test ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0)) + @test ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0)) t₀ = interval(0) @test ψ′(t₀) === dψ(t₀) && isguaranteed(ψ′(t₀)) @test ψ′′(t₀) === ddψ(t₀) && isguaranteed(ψ′′(t₀)) @@ -73,14 +73,14 @@ end @test isguaranteed(dfdy) @test isguaranteed(grad[1]) @test isguaranteed(grad[2]) - + if iszero(x) && y < 0 @test decoration(dfdx) == trv else @test in_interval(ForwardDiff.derivative(fx, x), dfdx) end - if iszero(x) && y <= 0 + if iszero(x) && y <= 0 @test decoration(dfdy) == trv else @test in_interval(ForwardDiff.derivative(fy, y), dfdy) @@ -104,4 +104,4 @@ end @exact g(x) = 2^x + 6sin(x^3) - 33 @test isguaranteed(ForwardDiff.derivative(f, interval(1))) end -end \ No newline at end of file +end