diff --git a/base/fastmath.jl b/base/fastmath.jl index 5f905b86554f4..a969bcaaa6ae0 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -84,7 +84,12 @@ const fast_op = :sinh => :sinh_fast, :sqrt => :sqrt_fast, :tan => :tan_fast, - :tanh => :tanh_fast) + :tanh => :tanh_fast, + # reductions + :maximum => :maximum_fast, + :minimum => :minimum_fast, + :maximum! => :maximum!_fast, + :minimum! => :minimum!_fast) const rewrite_op = Dict(:+= => :+, @@ -366,4 +371,27 @@ for f in (:^, :atan, :hypot, :log) end end +# Reductions + +maximum_fast(a; kw...) = Base.reduce(max_fast, a; kw...) +minimum_fast(a; kw...) = Base.reduce(min_fast, a; kw...) + +maximum_fast(f, a; kw...) = Base.mapreduce(f, max_fast, a; kw...) +minimum_fast(f, a; kw...) = Base.mapreduce(f, min_fast, a; kw...) + +Base.reducedim_init(f, ::typeof(max_fast), A::AbstractArray, region) = + Base.reducedim_init(f, max, A::AbstractArray, region) +Base.reducedim_init(f, ::typeof(min_fast), A::AbstractArray, region) = + Base.reducedim_init(f, min, A::AbstractArray, region) + +maximum!_fast(r::AbstractArray, A::AbstractArray; kw...) = + maximum!_fast(identity, r, A; kw...) +minimum!_fast(r::AbstractArray, A::AbstractArray; kw...) = + minimum!_fast(identity, r, A; kw...) + +maximum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = + Base.mapreducedim!(f, max_fast, Base.initarray!(r, f, max, init, A), A) +minimum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = + Base.mapreducedim!(f, min_fast, Base.initarray!(r, f, min, init, A), A) + end diff --git a/test/fastmath.jl b/test/fastmath.jl index e93fb93330b4f..8755e727db092 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -207,6 +207,31 @@ end @test @fastmath(cis(third)) ≈ cis(third) end end + +@testset "reductions" begin + @test @fastmath(maximum([1,2,3])) == 3 + @test @fastmath(minimum([1,2,3])) == 1 + @test @fastmath(maximum(abs2, [1,2,3+0im])) == 9 + @test @fastmath(minimum(sqrt, [1,2,3])) == 1 + @test @fastmath(maximum(Float32[4 5 6; 7 8 9])) == 9.0f0 + @test @fastmath(minimum(Float32[4 5 6; 7 8 9])) == 4.0f0 + + @test @fastmath(maximum(Float32[4 5 6; 7 8 9]; dims=1)) == Float32[7.0 8.0 9.0] + @test @fastmath(minimum(Float32[4 5 6; 7 8 9]; dims=2)) == Float32[4.0; 7.0;;] + @test @fastmath(maximum(abs, [4+im -5 6-im; -7 8 -9]; dims=1)) == [7.0 8.0 9.0] + @test @fastmath(minimum(cbrt, [4 -5 6; -7 8 -9]; dims=2)) == cbrt.([-5; -9;;]) + + x = randn(3,4,5) + x1 = sum(x; dims=1) + x23 = sum(x; dims=(2,3)) + @test @fastmath(maximum!(x1, x)) ≈ maximum(x; dims=1) + @test x1 ≈ maximum(x; dims=1) + @test @fastmath(minimum!(x23, x)) ≈ minimum(x; dims=(2,3)) + @test x23 ≈ minimum(x; dims=(2,3)) + @test @fastmath(maximum!(abs, x23, x .+ im)) ≈ maximum(abs, x .+ im; dims=(2,3)) + @test @fastmath(minimum!(abs2, x1, x .+ im)) ≈ minimum(abs2, x .+ im; dims=1) +end + @testset "issue #10544" begin a = fill(1.,2,2) b = fill(1.,2,2)