diff --git a/src/intervals/arithmetic.jl b/src/intervals/arithmetic.jl index 00dc2d092..bdacb7f0b 100644 --- a/src/intervals/arithmetic.jl +++ b/src/intervals/arithmetic.jl @@ -162,8 +162,8 @@ function inv(a::Interval{T}) where T<:Real isempty(a) && return emptyinterval(a) if zero(T) ∈ a - a.lo < zero(T) == a.hi && return @round(-Inf, inv(a.lo)) - a.lo == zero(T) < a.hi && return @round(inv(a.hi), Inf) + a.lo < zero(T) == a.hi && return @round(T(-Inf), inv(a.lo)) + a.lo == zero(T) < a.hi && return @round(inv(a.hi), T(Inf)) a.lo < zero(T) < a.hi && return entireinterval(T) a == zero(a) && return emptyinterval(T) end @@ -195,14 +195,14 @@ function /(a::Interval{T}, b::Interval{T}) where T<:Real if iszero(b.lo) - a.lo >= zero(T) && return @round(a.lo/b.hi, Inf) - a.hi <= zero(T) && return @round(-Inf, a.hi/b.hi) + a.lo >= zero(T) && return @round(a.lo/b.hi, T(Inf)) + a.hi <= zero(T) && return @round(T(-Inf), a.hi/b.hi) return entireinterval(S) elseif iszero(b.hi) - a.lo >= zero(T) && return @round(-Inf, a.lo/b.lo) - a.hi <= zero(T) && return @round(a.hi/b.lo, Inf) + a.lo >= zero(T) && return @round(T(-Inf), a.lo/b.lo) + a.hi <= zero(T) && return @round(a.hi/b.lo, T(Inf)) return entireinterval(S) else @@ -218,10 +218,10 @@ function extended_div(a::Interval{T}, b::Interval{T}) where T<:Real S = typeof(a.lo / b.lo) if 0 < b.hi && 0 > b.lo && 0 ∉ a if a.hi < 0 - return (Interval(-Inf, a.hi / b.hi), Interval(a.hi / b.lo, Inf)) + return (Interval(T(-Inf), a.hi / b.hi), Interval(a.hi / b.lo, T(Inf))) elseif a.lo > 0 - return (Interval(-Inf, a.lo / b.lo), Interval(a.lo / b.hi, Inf)) + return (Interval(T(-Inf), a.lo / b.lo), Interval(a.lo / b.hi, T(Inf))) end elseif 0 ∈ a && 0 ∈ b @@ -394,7 +394,7 @@ The default is the true midpoint at `α = 0.5`. Assumes 0 ≤ α ≤ 1. -Warning: if the parameter `α = 0.5` is explicitely set, the behavior differs +Warning: if the parameter `α = 0.5` is explicitly set, the behavior differs from the default case if the provided `Interval` is not finite, since when `α` is provided `mid` simply replaces `+∞` (respectively `-∞`) by `prevfloat(+∞)` (respecively `nextfloat(-∞)`) for the computation of the intermediate point. @@ -403,16 +403,18 @@ function mid(a::Interval{T}, α) where T isempty(a) && return convert(T, NaN) - lo = (a.lo == -∞ ? nextfloat(-∞) : a.lo) - hi = (a.hi == +∞ ? prevfloat(+∞) : a.hi) + lo = (a.lo == -∞ ? nextfloat(T(-∞)) : a.lo) + hi = (a.hi == +∞ ? prevfloat(T(+∞)) : a.hi) - midpoint = α * (hi - lo) + lo + β = convert(T, α) + + midpoint = β * (hi - lo) + lo isfinite(midpoint) && return midpoint #= Fallback in case of overflow: hi - lo == +∞. This case can not be the default one as it does not pass several IEEE1788-2015 tests for small floats. =# - return (1-α) * lo + α * hi + return (1 - β) * lo + β * hi end """ @@ -432,13 +434,13 @@ function mid(a::Interval{T}) where T a.lo == -∞ && return nextfloat(a.lo) a.hi == +∞ && return prevfloat(a.hi) - midpoint = 0.5 * (a.lo + a.hi) + midpoint = (a.lo + a.hi) / 2 isfinite(midpoint) && return midpoint #= Fallback in case of overflow: a.hi + a.lo == +∞ or a.hi + a.lo == -∞. This case can not be the default one as it does not pass several IEEE1788-2015 tests for small floats. =# - return 0.5 * a.lo + 0.5 * a.hi + return a.lo / 2 + a.hi / 2 end mid(a::Interval{Rational{T}}) where T = (1//2) * (a.lo + a.hi) diff --git a/src/intervals/rounding_macros.jl b/src/intervals/rounding_macros.jl index 751a934de..ddd079740 100644 --- a/src/intervals/rounding_macros.jl +++ b/src/intervals/rounding_macros.jl @@ -22,7 +22,7 @@ function round_expr(ex::Expr, rounding_mode::RoundingMode) return :( $(esc(op))( $(esc(ex.args[2])), $(esc(ex.args[3])), $rounding_mode) ) else # unary operator - return :( $op($(esc(ex.args[2])), $rounding_mode ) ) + return :( $(esc(op))($(esc(ex.args[2])), $rounding_mode ) ) end else return ex diff --git a/test/interval_tests/consistency.jl b/test/interval_tests/consistency.jl index 1052e7586..15d3c7280 100644 --- a/test/interval_tests/consistency.jl +++ b/test/interval_tests/consistency.jl @@ -388,4 +388,26 @@ setprecision(Interval, Float64) end + @testset "Type stability" begin + for T in (Float32, Float64, BigFloat) + + xs = [3..4, 0..4, 0..0, -4..0, -4..4, -Inf..4, 4..Inf, -Inf..Inf] + + for x in xs + for y in xs + xx = Interval{T}(x) + yy = Interval{T}(y) + + for op in (+, -, *, /, atan) + @inferred op(x, y) + end + end + + for op in (sin, cos, exp, log, tan, abs, mid, diam) + @inferred op(x) + end + end + end + end + end