From f8f236b4bc63ffdd0dff7d4475810570965252fd Mon Sep 17 00:00:00 2001 From: matthias314 Date: Sat, 20 May 2023 21:45:56 -0400 Subject: [PATCH] `@fastmath` support for `sum`, `prod`, `extrema` and `extrema!` --- base/fastmath.jl | 43 ++++++++++++++++----- test/fastmath.jl | 99 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 114 insertions(+), 28 deletions(-) diff --git a/base/fastmath.jl b/base/fastmath.jl index 7865736f8a776..d63ca06f1c15d 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -86,10 +86,14 @@ const fast_op = :tan => :tan_fast, :tanh => :tanh_fast, # reductions + :sum => :sum_fast, + :prod => :prod_fast, :maximum => :maximum_fast, :minimum => :minimum_fast, :maximum! => :maximum!_fast, - :minimum! => :minimum!_fast) + :minimum! => :minimum!_fast, + :extrema => :extrema_fast, + :extrema! => :extrema!_fast) const rewrite_op = Dict(:+= => :+, @@ -377,16 +381,23 @@ end # Reductions -maximum_fast(a; kw...) = Base.reduce(max_fast, a; kw...) -minimum_fast(a; kw...) = Base.reduce(min_fast, a; kw...) +add_sum_fast(x, y) = Base.add_sum(x, y) +add_sum_fast(x::T, y::T) where {T<:FloatTypes} = @fastmath x+y -maximum_fast(f, a; kw...) = Base.mapreduce(f, max_fast, a; kw...) -minimum_fast(f, a; kw...) = Base.mapreduce(f, min_fast, a; kw...) +mul_prod_fast(x, y) = Base.mul_prod(x, y) +mul_prod_fast(x::T, y::T) where {T<:FloatTypes} = @fastmath x*y -Base.reducedim_init(f, ::typeof(max_fast), A::AbstractArray, region) = - Base.reducedim_init(f, max, A::AbstractArray, region) -Base.reducedim_init(f, ::typeof(min_fast), A::AbstractArray, region) = - Base.reducedim_init(f, min, A::AbstractArray, region) +for (fred_fast, f, f_fast) in ((:sum_fast, :add_sum, :add_sum_fast), + (:prod_fast, :mul_prod, :mul_prod_fast), + (:maximum_fast, :max, :max_fast), + (:minimum_fast, :min, :min_fast)) + @eval begin + Base.reducedim_init(f, ::typeof($f_fast), A::AbstractArray, region) = + Base.reducedim_init(f, Base.$f, A, region) + $fred_fast(a; kw...) = Base.reduce($f_fast, a; kw...) + $fred_fast(f, a; kw...) = Base.mapreduce(f, $f_fast, a; kw...) + end +end maximum!_fast(r::AbstractArray, A::AbstractArray; kw...) = maximum!_fast(identity, r, A; kw...) @@ -398,4 +409,18 @@ maximum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) minimum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = Base.mapreducedim!(f, min_fast, Base.initarray!(r, f, min, init, A), A) +_extrema_rf_fast((min1, max1), (min2, max2)) = (min_fast(min1, min2), max_fast(max1, max2)) + +Base.reducedim_init(f, ::typeof(_extrema_rf_fast), A::AbstractArray, region) = + Base.reducedim_init(f, Base._extrema_rf, A, region) + +extrema_fast(a; kw...) = extrema_fast(identity, a; kw...) +extrema_fast(f, a; kw...) = Base.mapreduce(Base.ExtremaMap(f), _extrema_rf_fast, a; kw...) + +extrema!_fast(r::AbstractArray, A::AbstractArray; kw...) = + extrema!_fast(identity, r, A; kw...) +extrema!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = + Base.mapreducedim!(Base.ExtremaMap(f), _extrema_rf_fast, + Base.initarray!(r, Base.ExtremaMap(f), Base._extrema_rf, init, A), A) + end diff --git a/test/fastmath.jl b/test/fastmath.jl index 8755e727db092..dd86728aba0a7 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -209,27 +209,88 @@ end end @testset "reductions" begin - @test @fastmath(maximum([1,2,3])) == 3 - @test @fastmath(minimum([1,2,3])) == 1 - @test @fastmath(maximum(abs2, [1,2,3+0im])) == 9 - @test @fastmath(minimum(sqrt, [1,2,3])) == 1 - @test @fastmath(maximum(Float32[4 5 6; 7 8 9])) == 9.0f0 - @test @fastmath(minimum(Float32[4 5 6; 7 8 9])) == 4.0f0 + for T in (Int, Float16, Float32, Float64) + S = T == Int ? Float64 : T + x = @fastmath(sum(T[1,2,3,4])) + @test x isa T && x == 10 + x = @fastmath(prod(T[1,2,3,4])) + @test x isa T && x == 24 + x = @fastmath(sum(abs2, T[1,2,3])) + @test x isa T && x == 14 + x = @fastmath(prod(sqrt, T[1,4,9])) + @test x isa S && x == 6 + x = @fastmath(sum(T[1 2 3; 4 5 6])) + @test x isa T && x == 21 + x = @fastmath(prod(T[1 2 3; 4 5 6])) + @test x isa T && x == 720 - @test @fastmath(maximum(Float32[4 5 6; 7 8 9]; dims=1)) == Float32[7.0 8.0 9.0] - @test @fastmath(minimum(Float32[4 5 6; 7 8 9]; dims=2)) == Float32[4.0; 7.0;;] - @test @fastmath(maximum(abs, [4+im -5 6-im; -7 8 -9]; dims=1)) == [7.0 8.0 9.0] - @test @fastmath(minimum(cbrt, [4 -5 6; -7 8 -9]; dims=2)) == cbrt.([-5; -9;;]) + x = @fastmath(sum(T[4 5 6; 7 8 9]; dims=1)) + @test x isa Matrix{T} && x == [11 13 15] + x = @fastmath(prod(T[4 5 6; 7 8 9]; dims=2)) + @test x isa Matrix{T} && x == [120; 504;;] + x = @fastmath(sum(abs, Complex{T}[2im -2im 2im; -1 2 -3]; dims=1)) + @test x isa Matrix{S} && x == [3 4 5] + x = @fastmath(prod(cbrt, T[1 -8 1; 8 1 -8]; dims=2)) + @test x isa Matrix{S} && x == [-2; -4;;] + end + + for T in (Int, Float16, Float32, Float64) + if T == Float16 ; continue end # necessary until #49907 is fixed + S = T == Int ? Float64 : T + y = @fastmath(maximum(T[1,2,3])) + @test y isa T && y == 3 + y = @fastmath(minimum(T[1,2,3])) + @test y isa T && y == 1 + y = @fastmath(extrema(T[1,2,3])) + @test y isa Tuple{T,T} && y == (1,3) + y = @fastmath(maximum(abs2, Complex{T}[1,2,3im])) + @test y isa T && y == 9 + y = @fastmath(minimum(sqrt, T[1,2,3])) + @test y isa S && y == 1 + y = @fastmath(extrema(abs, T[-1,-2,3])) + @test y isa Tuple{T,T} && y == (1,3) + y = @fastmath(maximum(T[4 5 6; 7 8 9])) + @test y isa T && y == 9 + y = @fastmath(minimum(T[4 5 6; 7 8 9])) + @test y isa T && y == 4 + y = @fastmath(extrema(T[4 5 6; 7 8 9])) + @test y isa Tuple{T,T} && y == (4,9) - x = randn(3,4,5) - x1 = sum(x; dims=1) - x23 = sum(x; dims=(2,3)) - @test @fastmath(maximum!(x1, x)) ≈ maximum(x; dims=1) - @test x1 ≈ maximum(x; dims=1) - @test @fastmath(minimum!(x23, x)) ≈ minimum(x; dims=(2,3)) - @test x23 ≈ minimum(x; dims=(2,3)) - @test @fastmath(maximum!(abs, x23, x .+ im)) ≈ maximum(abs, x .+ im; dims=(2,3)) - @test @fastmath(minimum!(abs2, x1, x .+ im)) ≈ minimum(abs2, x .+ im; dims=1) + y = @fastmath(maximum(T[4 5 6; 7 8 9]; dims=1)) + @test y isa Matrix{T} && y == [7 8 9] + y = @fastmath(minimum(T[4 5 6; 7 8 9]; dims=2)) + @test y isa Matrix{T} && y == [4; 7;;] + y = @fastmath(extrema(T[4 5 6; 7 8 9]; dims=2)) + @test y isa Matrix{Tuple{T,T}} && y == [(4,6);(7,9);;] + y = @fastmath(maximum(abs, Complex{T}[4+im -5 6-im; -7 8 -9]; dims=1)) + @test y isa Matrix{S} && y == T[7 8 9] + y = @fastmath(minimum(cbrt, T[4 -5 6; -7 8 -9]; dims=2)) + @test y isa Matrix{S} && y == cbrt.(T[-5; -9;;]) + y = @fastmath(extrema(abs, T[-4 5 6; 7 -8 -9]; dims=1)) + @test y isa Matrix{Tuple{T,T}} && y == [(4,7) (5,8) (6,9)] + + x = rand(T, 3,4,5) + x1 = sum(x; dims=1) + x23 = sum(x; dims=(2,3)) + y1 = map(z -> (z,z), x1) + y23 = map(z -> (z,z), x23) + y = @fastmath(maximum!(x1, x)) + @test y ≈ maximum(x; dims=1) + @test y === x1 + y = @fastmath(minimum!(x23, x)) + @test y ≈ minimum(x; dims=(2,3)) + @test y === x23 + @test @fastmath(maximum!(abs, x23, x .+ im)) ≈ maximum(abs, x .+ im; dims=(2,3)) + @test @fastmath(minimum!(abs2, x1, x .+ im)) ≈ minimum(abs2, x .+ im; dims=1) + y = @fastmath(extrema!(y1, x)) + @test map(z -> z[1], y) ≈ minimum(x; dims=1) + @test map(z -> z[2], y) ≈ maximum(x; dims=1) + @test y === y1 + y = @fastmath(extrema!(y23, x)) + @test map(z -> z[1], y) ≈ minimum(x; dims=(2,3)) + @test map(z -> z[2], y) ≈ maximum(x; dims=(2,3)) + @test y === y23 + end end @testset "issue #10544" begin