diff --git a/ext/IntervalArithmeticForwardDiffExt.jl b/ext/IntervalArithmeticForwardDiffExt.jl index a3445d894..791b3cd5f 100644 --- a/ext/IntervalArithmeticForwardDiffExt.jl +++ b/ext/IntervalArithmeticForwardDiffExt.jl @@ -56,4 +56,64 @@ function Base.:(^)(x::ExactReal, y::Dual{<:Any, I}) where I<:Interval return convert(I, x)^y end +function Base.max(x::Dual{T,V,N}, y::AbstractFloat) where {T,V<:Interval,N} + if sup(value(x)) < y + return Dual{T,V,N}(interval(y,y), interval(0,0) * partials(x)) + elseif inf(value(x)) > y + return Dual{T,V,N}(value(x), interval(1,1) * partials(x)) + else + return Dual{T,V,N}(interval(y,sup(value(x))), interval(0,1) * partials(x)) + end +end +function Base.max(y::AbstractFloat, x::Dual{T,V,N}) where {T,V<:Interval,N} + return max(x, y) +end + +function Base.max(x::Dual{T,Dual{T2,V2,N2},N}, y::AbstractFloat) where {T,T2,V2<:Interval,N2,N} + if sup(value(value(x))) < y + return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(interval(y,y)), interval(0,0) * partials(x)) + elseif inf(value(value(x))) > y + return Dual{T,Dual{T2,V2,N2},N}(value(x), interval(1,1) * partials(x)) + else + return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(interval(y,sup(value(value(x)))), partials(value(x))), interval(0,1) * partials(x)) + end +end +function Base.max(y::AbstractFloat, x::Dual{T,Dual{T2,V2,N2},N}) where {T,T2,V2<:Interval,N2,N} + return max(x, y) +end + +function Base.min(x::Dual{T,V,N}, y::AbstractFloat) where {T,V<:Interval,N} + if inf(value(x)) > y + return Dual{T,V,N}(interval(y,y), interval(0,0) * partials(x)) + elseif sup(value(x)) < y + return Dual{T,V,N}(value(x), interval(1,1) * partials(x)) + else + return Dual{T,V,N}(interval(inf(value(x)),y), interval(0,1) * partials(x)) + end +end +function Base.min(y::AbstractFloat, x::Dual{T,V,N}) where {T,V<:Interval,N} + return min(x, y) +end + +function Base.min(x::Dual{T,Dual{T2,V2,N2},N}, y::AbstractFloat) where {T,T2,V2<:Interval,N2,N} + if inf(value(value(x))) > y + return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(interval(y,y)), interval(0,0) * partials(x)) + elseif sup(value(value(x))) < y + return Dual{T,Dual{T2,V2,N2},N}(value(x), interval(1,1) * partials(x)) + else + return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(interval(inf(value(value(x))),y), partials(value(x))), interval(0,1) * partials(x)) + end +end +function Base.min(y::AbstractFloat, x::Dual{T,Dual{T2,V2,N2},N}) where {T,T2,V2<:Interval,N2,N} + return min(x, y) +end + +function Base.clamp(i::Dual{T,V,N}, lo::AbstractFloat, hi::AbstractFloat) where {T,V<:Interval,N} + return min(max(i, lo), hi) +end + +function Base.clamp(i::Dual{T,Dual{T2,V2,N2},N}, lo::AbstractFloat, hi::AbstractFloat) where {T,T2,V2<:Interval,N2,N} + return min(max(i, lo), hi) +end + end diff --git a/src/intervals/arithmetic/absmax.jl b/src/intervals/arithmetic/absmax.jl index dcb3dfcb8..0d78138d8 100644 --- a/src/intervals/arithmetic/absmax.jl +++ b/src/intervals/arithmetic/absmax.jl @@ -50,7 +50,12 @@ for f ∈ (:min, :max) isempty_interval(y) && return y return _unsafe_bareinterval(T, $f(inf(x), inf(y)), $f(sup(x), sup(y))) end + function Base.$f(x::BareInterval{T}, y::Rational) where {T<:NumTypes} + isempty_interval(x) && return x + return _unsafe_bareinterval(T, $f(inf(x), y), $f(sup(x), y)) + end Base.$f(x::BareInterval, y::BareInterval) = $f(promote(x, y)...) + Base.$f(x::BareInterval, y::Rational) = $f(promote(x), y) function Base.$f(x::Interval, y::Interval) r = $f(bareinterval(x), bareinterval(y)) @@ -58,5 +63,14 @@ for f ∈ (:min, :max) t = isguaranteed(x) & isguaranteed(y) return _unsafe_interval(r, d, t) end + + function Base.$f(x::Interval, y::Rational) + r = $f(bareinterval(x), y) + d = decoration(x) + return _unsafe_interval(r, d, false) + end + Base.$f(x::Rational, y::Interval) = $f(y, x) end end + +Base.clamp(i::Interval, lo, hi) = min(max(i, lo), hi) diff --git a/test/interval_tests/consistency.jl b/test/interval_tests/consistency.jl index 70c5bf1e6..473abe5f9 100644 --- a/test/interval_tests/consistency.jl +++ b/test/interval_tests/consistency.jl @@ -268,7 +268,7 @@ @test radius(2.125) == 0 end - @testset "abs, min, max, sign" begin + @testset "abs, min, max, sign, clamp" begin @test isequal_interval(abs(entireinterval()), interval(0.0, Inf)) @test isequal_interval(abs(emptyinterval()), emptyinterval()) @test isequal_interval(abs(interval(-3.0,1.0)), interval(0.0, 3.0)) @@ -279,14 +279,31 @@ @test isequal_interval(min(emptyinterval(), interval(3.0,4.0)), emptyinterval()) @test isequal_interval(min(interval(-3.0,1.0), interval(3.0,4.0)), interval(-3.0, 1.0)) @test isequal_interval(min(interval(-3.0,-1.0), interval(3.0,4.0)), interval(-3.0, -1.0)) + @test isequal_interval(min(interval(1, 2), 1.5), interval(1, 1.5)) + @test isequal_interval(min(interval(1, 2), 0.5), interval(0.5, 0.5)) + @test isequal_interval(min(interval(1, 2), 2.5), interval(1, 2)) + @test !isguaranteed(min(interval(1, 2), 1.5)) + @test isequal_interval(min(1.5, interval(1,2)), interval(1, 1.5)) + @test !isguaranteed(min(1.5, interval(1,2))) @test isequal_interval(max(entireinterval(), interval(3.0,4.0)), interval(3.0, Inf)) @test isequal_interval(max(emptyinterval(), interval(3.0,4.0)), emptyinterval()) @test isequal_interval(max(interval(-3.0,1.0), interval(3.0,4.0)), interval(3.0, 4.0)) @test isequal_interval(max(interval(-3.0,-1.0), interval(3.0,4.0)), interval(3.0, 4.0)) + @test isequal_interval(max(interval(1, 2), 1.5), interval(1.5, 2)) + @test isequal_interval(max(interval(1, 2), 0.5), interval(1, 2)) + @test isequal_interval(max(interval(1, 2), 2.5), interval(2.5, 2.5)) + @test !isguaranteed(max(interval(1, 2), 1.5)) + @test isequal_interval(max(1.5, interval(1,2)), interval(1.5, 2)) + @test !isguaranteed(max(1.5, interval(1,2))) @test isequal_interval(sign(entireinterval()), interval(-1.0, 1.0)) @test isequal_interval(sign(emptyinterval()), emptyinterval()) @test isequal_interval(sign(interval(-3.0,1.0)), interval(-1.0, 1.0)) @test isequal_interval(sign(interval(-3.0,-1.0)), interval(-1.0, -1.0)) + @test isequal_interval(clamp(interval(1, 2), 1.5, 2.5), interval(1.5, 2)) + @test isequal_interval(clamp(interval(1, 2), 0.5, 1.5), interval(1, 1.5)) + @test isequal_interval(clamp(interval(1, 2), 2.5, 3.5), interval(2.5, 2.5)) + @test isequal_interval(clamp(interval(1, 2), 0.5, 2.5), interval(1, 2)) + @test !isguaranteed(clamp(interval(1, 2), 1.5, 2.5)) # Test putting functions in interval: @test issubset_interval(log(interval(-2, 5)), interval(-Inf, log(interval(5)))) diff --git a/test/interval_tests/forwarddiff.jl b/test/interval_tests/forwarddiff.jl index c417c6303..04d1a8f2d 100644 --- a/test/interval_tests/forwarddiff.jl +++ b/test/interval_tests/forwarddiff.jl @@ -96,4 +96,48 @@ end end end end + + @testset "min" begin + @test isequal_interval(ForwardDiff.derivative(x->min(x, 1.5), interval(1, 2)), interval(0, 1)) + @test isequal_interval(ForwardDiff.derivative(x->min(x, 1.5), interval(1.75, 2)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(x->min(x, 1.5), interval(0.5, 0.75)), interval(1)) + @test isequal_interval(ForwardDiff.derivative(x->min(x, 1.5), interval(1, 2)), ForwardDiff.derivative(x->min(1.5, x), interval(1, 2))) + + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 1.5), y), interval(1, 2)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 1.5)^2, y), interval(1, 2)), interval(0, 2)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 1.5)^3, y), interval(1, 2)), interval(0, 9)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 3.0)^3, y), interval(1, 2)), interval(6, 12)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 3.0)^3, y), interval(4, 5)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(x, 1.5)^3, y), interval(1, 2)), ForwardDiff.derivative(y->ForwardDiff.derivative(x->min(1.5, x)^3, y), interval(1, 2))) + end + + @testset "max" begin + @test isequal_interval(ForwardDiff.derivative(x->max(x, 1.5), interval(1, 2)), interval(0, 1)) + @test isequal_interval(ForwardDiff.derivative(x->max(x, 1.5), interval(1.75, 2)), interval(1)) + @test isequal_interval(ForwardDiff.derivative(x->max(x, 1.5), interval(0.5, 0.75)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(x->max(x, 1.5), interval(1, 2)), ForwardDiff.derivative(x->max(1.5, x), interval(1, 2))) + + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 1.5), y), interval(1, 2)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 1.5)^2, y), interval(1, 2)), interval(0, 2)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 1.5)^3, y), interval(1, 2)), interval(0, 12)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 3.0)^3, y), interval(1, 2)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 3.0)^3, y), interval(4, 5)), interval(24, 30)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(x, 1.5)^3, y), interval(1, 2)), ForwardDiff.derivative(y->ForwardDiff.derivative(x->max(1.5, x)^3, y), interval(1, 2))) + end + + @testset "clamp" begin + @test isequal_interval(ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), interval(1, 2)), interval(0, 1)) + @test isequal_interval(ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), interval(2, 3)), interval(0, 1)) + @test isequal_interval(ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), interval(1.75, 2)), interval(1)) + @test isequal_interval(ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), interval(2.75, 3)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), interval(0.75, 1)), interval(0)) + + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5), y), interval(1, 2)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^2, y), interval(1, 2)), interval(0, 2)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^3, y), interval(1, 2)), interval(0, 12)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^3, y), interval(2, 3)), interval(0, 15)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^3, y), interval(3, 4)), interval(0)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^3, y), interval(1.75, 2)), interval(6*1.75, 12)) + @test isequal_interval(ForwardDiff.derivative(y->ForwardDiff.derivative(x->clamp(x, 1.5, 2.5)^3, y), interval(0, 1)), interval(0)) + end end \ No newline at end of file