From 6141c47aa73318d7734abc478fd846d92b89c2f7 Mon Sep 17 00:00:00 2001 From: OlivierHnt <38465572+OlivierHnt@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:00:39 -0500 Subject: [PATCH] Improve complex power --- src/intervals/arithmetic/power.jl | 209 +++++++++++++----------------- 1 file changed, 92 insertions(+), 117 deletions(-) diff --git a/src/intervals/arithmetic/power.jl b/src/intervals/arithmetic/power.jl index 8a1addb4..cb719ad4 100644 --- a/src/intervals/arithmetic/power.jl +++ b/src/intervals/arithmetic/power.jl @@ -26,11 +26,12 @@ power_mode() = PowerMode{:fast}() ^(x::BareInterval, y::BareInterval) ^(x::Interval, y::Interval) -Compute the power of the positive real part of `x` by `y`. This function is not -in the IEEE Standard 1788-2015. Its behaviour depend on the current -[`PowerMode`](@ref). +Compute the power of `x` by `y`. Unless `y` is an integer, the positive real +part of `x^y` is returned. This function is not in the IEEE Standard 1788-2015. +Its behaviour depends on the current [`PowerMode`](@ref). -See also: [`pow`](@ref) and [`pown`](@ref). +See also: [`pow`](@ref), [`pown`](@ref), [`fastpow`](@ref) and +[`fastpown`](@ref). # Examples @@ -60,40 +61,43 @@ function Base.:^(x::Interval, y::Interval) return _unsafe_interval(bareinterval(r), d, t) end -Base.:^(n::Integer, y::Interval) = ^(n//one(n), y) Base.:^(x::Interval, n::Integer) = ^(x, n//one(n)) -Base.:^(x::Rational, y::Interval) = ^(convert(Interval{typeof(x)}, x), y) Base.:^(x::Interval, y::Rational) = ^(x, convert(Interval{typeof(y)}, y)) -function Base.:^(x::Complex{Interval{T}}, y::Complex{Interval{T}}) where {T<:NumTypes} - !isthinzero(x) && return exp(y * log(x)) - d = min(decoration(x), decoration(y)) - t = isguaranteed(x) & isguaranteed(y) - isthinzero(y) && return complex(_unsafe_interval(one(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t)) - (inf(real(y)) > 0) & !isempty_interval(bareinterval(real(y))) && return complex(_unsafe_interval(zero(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t)) - d = min(d, trv) - return complex(_unsafe_interval(emptyinterval(BareInterval{T}), d, t), _unsafe_interval(emptyinterval(BareInterval{T}), d, t)) +function Base.:^(x::Complex{<:Interval}, y::Complex{<:Interval}) + if isreal(x) && isthininteger(y) + r = real(x) ^ real(y) + d = min(decoration(x), decoration(y), decoration(r)) + t = isguaranteed(x) & isguaranteed(y) + return complex(_unsafe_interval(bareinterval(real(r)), d, t), _unsafe_interval(bareinterval(imag(r)), d, t)) + else + isthininteger(y) && return exp(y * _log_no_branch_cut(x)) + return exp(y * log(x)) + end end -Base.:^(x::Complex{<:Interval}, y::Complex{<:Interval}) = ^(promote(x, y)...) -Base.:^(x::Complex{<:Interval}, y::Real) = ^(promote(x, y)...) -Base.:^(x::Real, y::Complex{<:Interval}) = ^(promote(x, y)...) -# needed to avoid method ambiguities -Base.:^(x::Complex{<:Interval}, n::Bool) = ^(promote(x, n)...) -Base.:^(x::Complex{<:Interval}, n::Integer) = ^(promote(x, n)...) -Base.:^(x::Complex{<:Interval}, n::Rational) = ^(promote(x, n)...) + +function _log_no_branch_cut(z::Complex{<:Interval}) + x, y = reim(z) + by = bareinterval(y) + bx = bareinterval(x) + r = atan(by, bx) + d = min(decoration(y), decoration(x), decoration(r)) + d = min(d, + ifelse(in_interval(0, by), + ifelse(in_interval(0, bx), trv, d), + d)) + t = isguaranteed(y) & isguaranteed(x) + angle = _unsafe_interval(r, d, t) + return complex(log(abs(z)), angle) +end + +# needed to avoid method errors +Base.:^(x::Complex{<:Interval}, y::Interval) = ^(promote(x, y)...) +Base.:^(x::Interval, y::Complex{<:Interval}) = ^(promote(x, y)...) # overwrite behaviour for small integer powers from https://github.com/JuliaLang/julia/pull/24240 -# Base.literal_pow(::typeof(^), x::Interval, ::Val{n}) where {n} = x^n Base.literal_pow(::typeof(^), x::Interval, ::Val{n}) where {n} = _select_pown(power_mode(), x, n) -function Base.literal_pow(::typeof(^), x::Complex{Interval{T}}, ::Val{n}) where {T<:NumTypes,n} - !isthinzero(x) && return exp(interval(T, n) * log(x)) - d = decoration(x) - t = isguaranteed(x) - n == 0 && return complex(_unsafe_interval(one(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t)) - n > 0 && return complex(_unsafe_interval(zero(BareInterval{T}), d, t), _unsafe_interval(zero(BareInterval{T}), d, t)) - d = min(d, trv) - return complex(_unsafe_interval(emptyinterval(BareInterval{T}), d, t), _unsafe_interval(emptyinterval(BareInterval{T}), d, t)) -end +Base.literal_pow(::typeof(^), x::Complex{<:Interval}, ::Val{n}) where {n} = ^(x, interval(n)) # helper functions for power @@ -110,7 +114,7 @@ if `y` is a thin integer, this is not equivalent to `pown(x, sup(y))`. Implement the `pow` function of the IEEE Standard 1788-2015 (Table 9.1). -See also: [`pown`](@ref). +See also: [`fastpow`](@ref), [`pown`](@ref) and [`fastpown`](@ref). # Examples @@ -127,25 +131,15 @@ julia> pow(interval(-1, 1), interval(-3)) Interval{Float64}(1.0, Inf, trv) ``` """ -function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:AbstractFloat} +function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:NumTypes} isempty_interval(y) && return y domain = _unsafe_bareinterval(T, zero(T), typemax(T)) x = intersect_interval(x, domain) isempty_interval(x) && return x - isthin(y) && return _pow(x, sup(y)) - return hull(_pow(x, inf(y)), _pow(x, sup(y))) + isthin(y) && return _thin_pow(x, sup(y)) + return hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y))) end -function pow(x::BareInterval{T}, y::BareInterval{T}) where {T<:Rational} - isempty_interval(y) && return y - domain = _unsafe_bareinterval(T, zero(T), typemax(T)) - x = intersect_interval(x, domain) - isempty_interval(x) && return x - isthin(y) && return _pow(x, sup(y)) - return hull(_pow(x, inf(y)), _pow(x, sup(y))) -end -pow(x::BareInterval{<:AbstractFloat}, y::BareInterval{<:AbstractFloat}) = pow(promote(x, y)...) -pow(x::BareInterval{<:Rational}, y::BareInterval{<:Rational}) = pow(promote(x, y)...) -pow(x::BareInterval{<:Rational}, y::BareInterval{<:AbstractFloat}) = pow(promote(x, y)...) +pow(x::BareInterval, y::BareInterval) = pow(promote(x, y)...) # specialize on rational to improve exactness function pow(x::BareInterval{T}, y::BareInterval{S}) where {T<:NumTypes,S<:Rational} R = promote_numtype(T, S) @@ -153,14 +147,12 @@ function pow(x::BareInterval{T}, y::BareInterval{S}) where {T<:NumTypes,S<:Ratio domain = _unsafe_bareinterval(T, zero(T), typemax(T)) x = intersect_interval(x, domain) isempty_interval(x) && return emptyinterval(BareInterval{R}) - isthin(y) && return BareInterval{R}(_pow(x, sup(y))) - return BareInterval{R}(hull(_pow(x, inf(y)), _pow(x, sup(y)))) + isthin(y) && return BareInterval{R}(_thin_pow(x, sup(y))) + return BareInterval{R}(hull(_thin_pow(x, inf(y)), _thin_pow(x, sup(y)))) end -pow(n::Integer, y::BareInterval) = pow(n//one(n), y) -pow(x::BareInterval, n::Integer) = pow(x, n//one(n)) -pow(x::Real, y::BareInterval) = pow(bareinterval(x), y) pow(x::BareInterval, y::Real) = pow(x, bareinterval(y)) +pow(x::BareInterval, n::Integer) = pow(x, n//one(n)) function pow(x::Interval, y::Interval) bx = bareinterval(x) @@ -172,17 +164,15 @@ function pow(x::Interval, y::Interval) return _unsafe_interval(r, d, t) end -pow(n::Integer, y::Interval) = pow(n//one(n), y) -pow(x::Interval, n::Integer) = pow(x, n//one(n)) -pow(x::Real, y::Interval) = pow(interval(x), y) pow(x::Interval, y::Real) = pow(x, interval(y)) +pow(x::Interval, n::Integer) = pow(x, n//one(n)) -# helper functions for power +# helper function for `pow` -function _pow(x::BareInterval{T}, y::T) where {T<:NumTypes} +function _thin_pow(x::BareInterval{T}, y::T) where {T<:NumTypes} # assume `inf(x) ≥ 0` and `!isempty_interval(x)` - if sup(x) == 0 - y > 0 && return x # zero(x) + if sup(x) == 0 # isthinzero(x) + y > 0 && return x return emptyinterval(BareInterval{T}) else isinteger(y) && return pown(x, Integer(y)) @@ -193,15 +183,14 @@ function _pow(x::BareInterval{T}, y::T) where {T<:NumTypes} end end -function _pow(x::BareInterval{T}, y::Rational{S}) where {T<:NumTypes,S<:Integer} +function _thin_pow(x::BareInterval{T}, y::Rational{S}) where {T<:NumTypes,S<:Integer} # assume `inf(x) ≥ 0` and `!isempty_interval(x)` - if sup(x) == 0 - y > 0 && return x # zero(x) + if sup(x) == 0 # isthinzero(x) + y > 0 && return x return emptyinterval(BareInterval{T}) else isinteger(y) && return pown(x, S(y)) - y == (1//2) && return sqrt(x) - return pown(rootn(x, y.den), y.num) + return pown(rootn(x, denominator(y)), numerator(y)) end end @@ -210,6 +199,8 @@ end Implement the `pown` function of the IEEE Standard 1788-2015 (Table 9.1). +See also: [`fastpown`](@ref), [`pow`](@ref) and [`fastpow`](@ref). + # Examples ```jldoctest @@ -321,37 +312,29 @@ Base.hypot(x::Interval, y::Interval) = sqrt(_select_pown(power_mode(), x, 2) + _ """ fastpow(x, y) -A faster implementation of `pow(x, y)`, at the cost of maybe returning a -slightly larger interval. +A faster implementation of `pow(x, y)`, at the cost of maybe returning a larger +interval. + +See also: [`pow`](@ref), [`pown`](@ref) and [`fastpown`](@ref). """ function fastpow(x::BareInterval{T}, y::BareInterval{T}) where {T<:NumTypes} isempty_interval(y) && return y - isthininteger(y) && return fastpow(x, Integer(sup(y))) domain = _unsafe_bareinterval(T, zero(T), typemax(T)) x = intersect_interval(x, domain) isempty_interval(x) && return x - if sup(x) == 0 - sup(y) > 0 && return x # zero(x) + if sup(x) == 0 # isthinzero(x) + sup(y) > 0 && return x return emptyinterval(BareInterval{T}) + elseif isthininteger(y) + n = Integer(sup(y)) + n < 0 && return inv(_positive_power_by_squaring(x, -n)) + return _positive_power_by_squaring(x, n) else return exp(y * log(x)) end end - -function fastpow(x::BareInterval{T}, n::Integer) where {T<:NumTypes} - n < 0 && return inv(fastpow(x, -n)) - domain = _unsafe_bareinterval(T, zero(T), typemax(T)) - x = intersect_interval(x, domain) - isempty_interval(x) && return x - if sup(x) == 0 - n > 0 && return x # zero(x) - return emptyinterval(BareInterval{T}) # n == 0 - else - return _positive_power_by_squaring(x, n) - end -end - fastpow(x::BareInterval, y::BareInterval) = fastpow(promote(x, y)...) + fastpow(x::BareInterval, y::Real) = fastpow(x, bareinterval(y)) function fastpow(x::Interval, y::Interval) @@ -366,17 +349,35 @@ end fastpow(x::Interval, y::Real) = fastpow(x, interval(y)) -# helper function for fast power +""" + fastpown(x, n) + +A faster implementation of `pown(x, n)`, at the cost of maybe returning a larger +interval. + +See also: [`pown`](@ref), [`pow`](@ref) and [`fastpow`](@ref). +""" +function fastpown(x::BareInterval{T}, n::Integer) where {T<:NumTypes} + isempty_interval(x) && return x + n < 0 && return inv(fastpown(x, -n)) + range = _unsafe_bareinterval(T, ifelse(iseven(n), zero(T), typemin(T)), typemax(T)) + return intersect_interval(_positive_power_by_squaring(x, n), range) +end + +function fastpown(x::Interval, n::Integer) + r = fastpown(bareinterval(x), n) + d = min(decoration(x), decoration(r)) + d = min(d, ifelse((n < 0) & in_interval(0, x), trv, d)) + return _unsafe_interval(r, d, isguaranteed(x)) +end + +# helper function for `fastpow` and `fastpown` # code inspired by `power_by_squaring(::Any, ::Integer)` in base/intfuncs.jl -Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x::BareInterval, n::Integer) - if n == 1 - return x - elseif n == 0 - return one(x) - elseif n == 2 - return x*x - end +Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x, n) + n == 0 && return one(x) + n == 1 && return x + n == 2 && return x*x t = trailing_zeros(n) + 1 n >>= t while (t -= 1) > 0 @@ -394,32 +395,6 @@ Base.@assume_effects :terminates_locally function _positive_power_by_squaring(x: return y end -""" - fastpown(x, n) - -A faster implementation of `pown(x, y)`, at the cost of maybe returning a -slightly larger interval. -""" -function fastpown(x::BareInterval{T}, n::Integer) where {T<:NumTypes} - isempty_interval(x) && return x - n == 0 && return one(BareInterval{T}) - n == 1 && return x - if n < 0 - isthinzero(x) && return emptyinterval(BareInterval{T}) - return inv(fastpown(x, -n)) - else - range = _unsafe_bareinterval(T, ifelse(iseven(n), zero(T), typemin(T)), typemax(T)) - return intersect_interval(_positive_power_by_squaring(x, n), range) - end -end - -function fastpown(x::Interval, n::Integer) - r = fastpown(bareinterval(x), n) - d = min(decoration(x), decoration(r)) - d = min(d, ifelse((n < 0) & in_interval(0, x), trv, d)) - return _unsafe_interval(r, d, isguaranteed(x)) -end - # for f ∈ (:cbrt, :exp, :exp2, :exp10, :expm1)