Performance issues with fusing 12+ broadcasts #22255

Closed
ChrisRackauckas opened this issue Jun 6, 2017 · 13 comments · Fixed by #26891
Labels: broadcast (Applying a function over a collection), performance (Must go faster)

Comments

@ChrisRackauckas
Member

ChrisRackauckas commented Jun 6, 2017

Updated OP

While working through the issue, @jebej identified that the problem is fusing 12 or more broadcasts; see this comment (#22255 (comment)), which contains an MWE.

Original OP

In OrdinaryDiffEq.jl, I see a 10x performance regression due to using broadcast. With the testing code:

const ζ = 0.5
const ω₀ = 10.0

using OrdinaryDiffEq, DiffEqBase

const y₀ = Float64[0., sqrt(1-ζ^2)*ω₀]
const A = 1
const ϕ = 0


function f(t::Float64)
    α = sqrt(1-ζ^2)*ω₀
    x = A*exp(-ζ*ω₀*t)*sin(α*t + ϕ)
    p = A*exp(-ζ*ω₀*t)*(-ζ*ω₀*sin(α*t + ϕ) + α*cos(α*t + ϕ))
    return [x,p]
end

function df(t::Float64, y::Vector{Float64}, dy::Vector{Float64})
    dy[1] = y[2]
    dy[2] = -2*ζ*ω₀*y[2] - ω₀^2*y[1]
    return nothing
end

const T = [0.,10.]
using BenchmarkTools
prob = ODEProblem(df,y₀,(T[1],T[2]))
@benchmark init($prob,Tsit5(),dense=false,dt=1/10)
@benchmark solve($prob,Tsit5(),dense=false,dt=1/10)

using ProfileView

@profile for i in 1:1000; solve(prob,Tsit5(),dense=false,dt=1/10); end
ProfileView.view()

I get a 10x regression by changing the inner loop from:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  uidx = eachindex(integrator.uprev)
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  for i in uidx
    tmp[i] = @muladd uprev[i]+a*k1[i]
  end
  f(@muladd(t+c1*dt),tmp,k2)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a31*k1[i]+a32*k2[i])
  end
  f(@muladd(t+c2*dt),tmp,k3)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a41*k1[i]+a42*k2[i]+a43*k3[i])
  end
  f(@muladd(t+c3*dt),tmp,k4)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a51*k1[i]+a52*k2[i]+a53*k3[i]+a54*k4[i])
  end
  f(@muladd(t+c4*dt),tmp,k5)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a61*k1[i]+a62*k2[i]+a63*k3[i]+a64*k4[i]+a65*k5[i])
  end
  f(t+dt,tmp,k6)
  for i in uidx
    u[i] = @muladd uprev[i]+dt*(a71*k1[i]+a72*k2[i]+a73*k3[i]+a74*k4[i]+a75*k5[i]+a76*k6[i])
  end
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    for i in uidx
      utilde[i] = @muladd uprev[i] + dt*(b1*k1[i] + b2*k2[i] + b3*k3[i] + b4*k4[i] + b5*k5[i] + b6*k6[i] + b7*k7[i])
      atmp[i] = ((utilde[i]-u[i])./@muladd(integrator.opts.abstol+max(abs(uprev[i]),abs(u[i])).*integrator.opts.reltol))
    end
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

to:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  tmp .= @muladd uprev+a*k1
  f(@muladd(t+c1*dt),tmp,k2)
  tmp .= @muladd uprev+dt*(a31*k1+a32*k2)
  f(@muladd(t+c2*dt),tmp,k3)
  tmp .= @muladd uprev+dt*(a41*k1+a42*k2+a43*k3)
  f(@muladd(t+c3*dt),tmp,k4)
  tmp .= @muladd uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4)
  f(@muladd(t+c4*dt),tmp,k5)
  tmp .= @muladd uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5)
  f(t+dt,tmp,k6)
  u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6)
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    utilde .= @muladd uprev + dt*(b1*k1 + b2*k2 + b3*k3 + b4*k4 + b5*k5 + b6*k6 + b7*k7)
    atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol))
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

I.e., the only change is replacing the loops with broadcasts. The input array is y₀, which has length 2. For reference, the @muladd macro acts like:

println(macroexpand(:(u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6))))
#u .= (muladd).(dt, (muladd).(a71, k1, (muladd).(a72, k2, (muladd).(a73, k3, (muladd).(a74, k4, (muladd).(a75, k5, a76 .* k6))))), uprev)

println(macroexpand(:(atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol)))))
#atmp .= (utilde .- u) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)
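To make the expansion above concrete, here is a minimal sketch with made-up coefficients and a shortened three-stage update (the names mirror the Tsit5 step, but all values are hypothetical). It checks that the fused nested-`muladd` broadcast, the equivalent scalar loop, and plain arithmetic agree. They are compared with `≈` because `muladd` is allowed, but not required, to compile to a single fused multiply-add.

```julia
# Hypothetical data: coefficients and stage vectors are made up; only the
# shape of the update mirrors the Tsit5 step above.
dt = 0.1
a71, a72, a73 = 0.5, 0.25, 0.125
uprev = [1.0, 2.0]                                # length 2, like y₀
k1, k2, k3 = [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]

# Fused broadcast in the nested-muladd form @muladd produces for
# uprev + dt*(a71*k1 + a72*k2 + a73*k3)
u_bc = similar(uprev)
u_bc .= muladd.(dt, muladd.(a71, k1, muladd.(a72, k2, a73 .* k3)), uprev)

# The same computation as an explicit scalar loop
u_loop = similar(uprev)
for i in eachindex(uprev)
    u_loop[i] = muladd(dt, muladd(a71, k1[i], muladd(a72, k2[i], a73 * k3[i])), uprev[i])
end

# Plain arithmetic without muladd (may differ in the last bits)
u_ref = uprev .+ dt .* (a71 .* k1 .+ a72 .* k2 .+ a73 .* k3)
```

Note that even this shortened update fuses four arrays and four scalars into one expression; the full seven-stage `u` update fuses seven arrays and seven scalars.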

The profile is here: https://ufile.io/2lu0f

The benchmark results are using loops:

BenchmarkTools.Trial:
  memory estimate:  87.43 KiB
  allocs estimate:  3757
  --------------
  minimum time:     281.032 μs (0.00% GC)
  median time:      526.934 μs (0.00% GC)
  mean time:        475.180 μs (2.88% GC)
  maximum time:     5.601 ms (87.13% GC)
  --------------
  samples:          10000
  evals/sample:     1

and using broadcast:

BenchmarkTools.Trial:
  memory estimate:  854.34 KiB
  allocs estimate:  37976
  --------------
  minimum time:     3.246 ms (0.00% GC)
  median time:      6.185 ms (0.00% GC)
  mean time:        5.762 ms (2.40% GC)
  maximum time:     13.919 ms (26.41% GC)
  --------------
  samples:          867
  evals/sample:     1

Am I hitting some broadcasting splatting penalty or something?

@nalimilan
Member

Any chance you could make the example simpler, e.g. by removing uses of @muladd and keeping only one computation line?

@yuyichao
Contributor

yuyichao commented Jun 6, 2017

AFAICT you are not doing broadcasting for the slower version. Also, f uses non-const global variables.

@yuyichao
Contributor

yuyichao commented Jun 6, 2017

Ah, the @muladd does the transformation. That is extremely confusing... It is indeed possible to hit the tuple length limit with broadcast. You should check code_warntype and also reduce your example.

@ChrisRackauckas
Member Author

Give me a little bit to get it simpler. I have to restart Atom and delete my precompilation cache every time I do something due to #21969, so this will take a while.

@ChrisRackauckas
Member Author

ChrisRackauckas commented Jun 6, 2017

Here's the form that expands all of the broadcasts:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  tmp .= (muladd).(a, k1, uprev)
  f(muladd(c1, dt, t),tmp,k2)
  tmp .= (muladd).(dt, (muladd).(a31, k1, a32 .* k2), uprev)
  f(muladd(c2, dt, t),tmp,k3)
  tmp .= (muladd).(dt, (muladd).(a41, k1, (muladd).(a42, k2, a43 .* k3)), uprev)
  f(muladd(c3, dt, t),tmp,k4)
  tmp .= (muladd).(dt, (muladd).(a51, k1, (muladd).(a52, k2, (muladd).(a53, k3, a54 .* k4))), uprev)
  f(muladd(c4, dt, t),tmp,k5)
  tmp .= (muladd).(dt, (muladd).(a61, k1, (muladd).(a62, k2, (muladd).(a63, k3, (muladd).(a64, k4, a65 .* k5)))), uprev)
  f(t+dt,tmp,k6)
  u .= (muladd).(dt, (muladd).(a71, k1, (muladd).(a72, k2, (muladd).(a73, k3, (muladd).(a74, k4, (muladd).(a75, k5, a76 .* k6))))), uprev)
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    utilde .= (muladd).(dt, (muladd).(b1, k1, (muladd).(b2, k2, (muladd).(b3, k3, (muladd).(b4, k4, (muladd).(b5, k5, (muladd).(b6, k6, b7 .* k7)))))), uprev)
    atmp .= ((utilde.-u)./(muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol))
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

Adding consts doesn't really change anything, because most of the time is spent not in the calls to f but in the broadcasts. @code_warntype produces a gigantic output:

integrator = init(prob,Tsit5(),dense=false,dt=1/10)
OrdinaryDiffEq.loopheader!(integrator)
@code_warntype OrdinaryDiffEq.perform_step!(integrator,integrator.cache,integrator.f)

https://gist.github.com/ChrisRackauckas/11ba482f80147dd744f5c2b87e288067

@ararslan ararslan added performance Must go faster broadcast Applying a function over a collection labels Jun 6, 2017
@ChrisRackauckas
Member Author

This might be an MWE:

function f(a,b,c)
  a.= b.*c
end

function g(a,b,c)
  @inbounds for ii in eachindex(a)
    a[ii] = b[ii].*c[ii]
  end
end

function f(a,b,c,d,e,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z)
  a.= b.*c .+ d .* e .+ h .* i .+ j .*k .+ l .* m .+ n .* o .+ p .* q .+ r .* s .+ t .* u .+ v .* w .+ x .* y .+ z
end

function g(a,b,c,d,e,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z)
  @inbounds for ii in eachindex(a)
    a[ii] = b[ii].*c[ii] .+ d[ii] .* e[ii] .+ h[ii] .* i[ii] .+ j[ii] .*k[ii] .+ l[ii] .* m[ii] .+ n[ii] .* o[ii] .+ p[ii] .* q[ii] .+ r[ii] .* s[ii] .+ t[ii] .* u[ii] .+ v[ii] .* w[ii] .+ x[ii] .* y[ii] .+ z[ii]
  end
end
using BenchmarkTools
a = rand(10)
b = rand(10)
c = rand(10)
d = rand(10)
e = rand(10)
h = rand(10)
i = rand(10)
j = rand(10)
k = rand(10)
l = rand(10)
m = rand(10)
n = rand(10)
o = rand(10)
p = rand(10)
q = rand(10)
r = rand(10)
s = rand(10)
t = rand(10)
u = rand(10)
v = rand(10)
w = rand(10)
x = rand(10)
y = rand(10)
z = rand(10)

@benchmark f($a,$b,$c)
@benchmark g($a,$b,$c)

@benchmark f($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q,$r,$s,$t,$u,$v,$w,$x,$y,$z)
@benchmark g($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q,$r,$s,$t,$u,$v,$w,$x,$y,$z)

Benchmarks on the small broadcasts:

@benchmark f($a,$b,$c)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     16.393 ns (0.00% GC)
  median time:      32.494 ns (0.00% GC)
  mean time:        28.710 ns (0.00% GC)
  maximum time:     225.996 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

@benchmark g($a,$b,$c)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.611 ns (0.00% GC)
  median time:      14.051 ns (0.00% GC)
  mean time:        12.511 ns (0.00% GC)
  maximum time:     192.038 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

and benchmarks on the long broadcasts:

@benchmark f($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q,$r,$s,$t,$u,$v,$w,$x,$y,$z)
BenchmarkTools.Trial: 
  memory estimate:  17.77 KiB
  allocs estimate:  650
  --------------
  minimum time:     51.229 μs (0.00% GC)
  median time:      101.874 μs (0.00% GC)
  mean time:        122.778 μs (3.60% GC)
  maximum time:     31.670 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

@benchmark g($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q,$r,$s,$t,$u,$v,$w,$x,$y,$z)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     222.590 ns (0.00% GC)
  median time:      413.482 ns (0.00% GC)
  mean time:        363.430 ns (0.00% GC)
  maximum time:     2.199 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     434

When it gets sufficiently long, it begins to allocate quite a bit!
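A minimal workaround sketch, assuming the slowdown is tied to the number of operands in a single fused expression (as the rest of this thread suggests): compute the same 23-input expression as the long `f` above in three shorter in-place passes, each fusing at most nine arrays. `f_split!` is a hypothetical name, and whether this actually avoids the allocations on Julia 0.6 is an assumption that would need benchmarking.

```julia
# Hypothetical workaround: the same expression as the long `f` above, split
# into three shorter fused passes that accumulate into `a`.
function f_split!(a,b,c,d,e,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z)
    a .= b.*c .+ d.*e .+ h.*i .+ j.*k      # 9 arrays in this fused expression
    a .+= l.*m .+ n.*o .+ p.*q .+ r.*s     # a .= a .+ (...): 9 arrays
    a .+= t.*u .+ v.*w .+ x.*y .+ z        # 8 arrays
    return a
end

xs = [rand(10) for _ in 1:23]   # the 23 inputs b, c, ..., z
out = zeros(10)
f_split!(out, xs...)
```

Because `a .+= X` means `a .= a .+ (X)`, the additions are associated differently than in the single fused expression, so the two agree only up to rounding.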

@ChrisRackauckas
Member Author

I indeed get the same kinds of differences when I use some of the broadcast expressions from the test code:

function f(a,b,c,d,e,h,i,j,k,l,m,n,o,p,q)
  a .= (muladd).(b, (muladd).(c, d, (muladd).(e, h, (muladd).(i, j, (muladd).(k, l, (muladd).(m, n, o .* p))))), q)
end

function g(a,b,c,d,e,h,i,j,k,l,m,n,o,p,q)
  @inbounds for ii in eachindex(a)
    a[ii] = (muladd)(b[ii], (muladd)(c[ii], d[ii], (muladd)(e[ii], h[ii], (muladd)(i[ii], j[ii], (muladd)(k[ii], l[ii], (muladd)(m[ii], n[ii], o[ii] .* p[ii]))))), q[ii])
  end
end

@benchmark f($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q)
@benchmark g($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q)
@benchmark f($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q)

BenchmarkTools.Trial: 
  memory estimate:  1.58 KiB
  allocs estimate:  44
  --------------
  minimum time:     6.382 μs (0.00% GC)
  median time:      12.471 μs (0.00% GC)
  mean time:        11.167 μs (2.64% GC)
  maximum time:     1.154 ms (94.98% GC)
  --------------
  samples:          10000
  evals/sample:     5
@benchmark g($a,$b,$c,$d,$e,$h,$i,$j,$k,$l,$m,$n,$o,$p,$q)

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     48.009 ns (0.00% GC)
  median time:      83.139 ns (0.00% GC)
  mean time:        74.315 ns (0.00% GC)
  maximum time:     1.335 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

Interestingly, that's only with 8 dots and 15 arrays... I would've thought the magic 16 would be the limit here.

@nalimilan
Member

Does it still happen when reducing the number of operands?

@ChrisRackauckas
Member Author

Does it still happen when reducing the number of operands?

No. See the "benchmarks on small broadcasts". For "sufficiently few" operands, it works just fine.

@nalimilan
Member

OK. So where's the cutoff?

@jebej
Contributor

jebej commented Jun 7, 2017

With the earlier example Chris had, it happens here with 12 variables, but not with 11:

function fun1!(a,b,c,d,e,f,g,h,i,j,k)
  a .= b.*c .+ d.*e .+ f.*g .+ h.*i .+ j.*k
end
function fun2!(a,b,c,d,e,f,g,h,i,j,k,l)
  a .= b.*c .+ d.*e .+ f.*g .+ h.*i .+ j.*k .+ l
end
a = rand(10); b = rand(10); c = rand(10); d = rand(10); e = rand(10);
f = rand(10); g = rand(10); h = rand(10); i = rand(10); j = rand(10);
k = rand(10); l = rand(10);
using BenchmarkTools
@benchmark fun1!($a,$b,$c,$d,$e,$f,$g,$h,$i,$j,$k)
@benchmark fun2!($a,$b,$c,$d,$e,$f,$g,$h,$i,$j,$k,$l)
julia> @benchmark fun1!($a,$b,$c,$d,$e,$f,$g,$h,$i,$j,$k)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     149.380 ns (0.00% GC)
  median time:      150.616 ns (0.00% GC)
  mean time:        153.680 ns (0.00% GC)
  maximum time:     382.402 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     947

julia> @benchmark fun2!($a,$b,$c,$d,$e,$f,$g,$h,$i,$j,$k,$l)
BenchmarkTools.Trial:
  memory estimate:  256 bytes
  allocs estimate:  4
  --------------
  minimum time:     509.969 ns (0.00% GC)
  median time:      514.536 ns (0.00% GC)
  mean time:        540.944 ns (3.26% GC)
  maximum time:     9.924 μs (91.70% GC)
  --------------
  samples:          10000
  evals/sample:     192
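To see where a given Julia version starts allocating, it helps to measure inside a function after a warm-up call, so that neither compilation nor access to non-constant globals is counted. A minimal sketch for the 12-array case, assuming Julia ≥ 1.0; `allocs_fun2` is a hypothetical helper, and the byte count is printed rather than asserted because it depends on the Julia version.

```julia
# jebej's 12-array kernel, returning the destination
fun2!(a,b,c,d,e,f,g,h,i,j,k,l) = (a .= b.*c .+ d.*e .+ f.*g .+ h.*i .+ j.*k .+ l; a)

# Measure inside a function so global-variable access is not counted
function allocs_fun2(a,b,c,d,e,f,g,h,i,j,k,l)
    fun2!(a,b,c,d,e,f,g,h,i,j,k,l)                    # warm-up call
    return @allocated fun2!(a,b,c,d,e,f,g,h,i,j,k,l)
end

arrs = [rand(10) for _ in 1:12]
nbytes = allocs_fun2(arrs...)
println("fun2! allocated $nbytes bytes")
```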

@ChrisRackauckas ChrisRackauckas changed the title 10x performance regression due to broadcast 10x performance regression due to fusing 12+ broadcasts Jun 7, 2017
@ChrisRackauckas ChrisRackauckas changed the title 10x performance regression due to fusing 12+ broadcasts Performance issues with fusing 12+ broadcasts Jun 7, 2017
@ChrisRackauckas
Member Author

Updated the title to reflect the new status of the issue. Thanks @jebej for finding that. Updating the OP with the new information.

@ChrisRackauckas
Member Author

Does anyone know which inference option would be involved here? I count 10 dots as fine but 11 as too much, yet none of the inference options has a cutoff at 10. I want to make sure this is fixed with #22545, though.

mbauman added a commit that referenced this issue Apr 23, 2018
This patch represents the combined efforts of four individuals, over 60
commits, and an iterated design over (at least) three pull requests that
spanned nearly an entire year (closes #22063, #23692, #25377 by superseding
them).

This introduces a pure Julia data structure that represents a fused broadcast
expression.  For example, the expression `2 .* (x .+ 1)` lowers to:

```julia
julia> Meta.@lower 2 .* (x .+ 1)
:($(Expr(:thunk, CodeInfo(:(begin
      Core.SSAValue(0) = (Base.getproperty)(Base.Broadcast, :materialize)
      Core.SSAValue(1) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(2) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(3) = (Core.SSAValue(2))(+, x, 1)
      Core.SSAValue(4) = (Core.SSAValue(1))(*, 2, Core.SSAValue(3))
      Core.SSAValue(5) = (Core.SSAValue(0))(Core.SSAValue(4))
      return Core.SSAValue(5)
  end)))))
```

Or, slightly more readably as:

```julia
using .Broadcast: materialize, make
materialize(make(*, 2, make(+, x, 1)))
```
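In released Julia 1.x, `Broadcast.make` ended up being named `Broadcast.broadcasted` (and `Meta.@lower` output looks different from the development-era output above). A minimal sketch, assuming Julia ≥ 1.0, that builds the same lazy tree by hand and then evaluates it:

```julia
using Base.Broadcast: broadcasted, materialize

x = [1.0, 2.0, 3.0]

# Lazy expression tree for 2 .* (x .+ 1); nothing is computed yet
bc = broadcasted(*, 2, broadcasted(+, x, 1))

# Allocate the result and evaluate the fused loop
y = materialize(bc)
```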

The `Broadcast.make` function serves two purposes. Its primary purpose is to
construct the `Broadcast.Broadcasted` objects that hold onto the function, the
tuple of arguments (potentially including nested `Broadcasted` arguments), and
sometimes a set of `axes` to include knowledge of the outer shape. The
secondary purpose, however, is to allow an "out" for objects that _don't_ want
to participate in fusion. For example, if `x` is a range in the above `2 .* (x
.+ 1)` expression, it needn't allocate an array and operate elementwise — it
can just compute and return a new range. Thus custom structures are able to
specialize `Broadcast.make(f, args...)` just as they'd specialize on `f`
normally to return an immediate result.
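The range "out" described above can be observed directly. A sketch assuming Julia ≥ 1.0, where these specializations are methods of `Broadcast.broadcasted` rather than `Broadcast.make`: each dotted operation on a range and a scalar returns a new range, so no array is allocated, and `materialize` then passes the range through unchanged.

```julia
# Assumes Julia ≥ 1.0: `(1:3) .+ 1` is computed eagerly as the range 2:4,
# and multiplying that by 2 again yields a range (elements 4, 6, 8).
r = 2 .* ((1:3) .+ 1)

# The inner call alone already returns a range, not a lazy tree
inner = Base.Broadcast.broadcasted(+, 1:3, 1)
```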

`Broadcast.materialize` is identity for everything _except_ `Broadcasted`
objects for which it allocates an appropriate result and computes the
broadcast. It does two things: it `instantiate`s the outermost `Broadcasted`
object to compute its axes and then `copy`s it.
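The two steps of `materialize` can be spelled out by hand. A sketch assuming Julia ≥ 1.0 (tree built with `broadcasted` instead of `make`): before instantiation the outer node's `axes` field is `nothing`; afterwards it holds the axes of the result, and `copy` allocates and fills the output.

```julia
using Base.Broadcast: broadcasted, instantiate, materialize

x = [1.0, 2.0, 3.0]
bc = broadcasted(*, 2, broadcasted(+, x, 1))   # lazy tree, axes not yet known

ibc = instantiate(bc)   # step 1: compute and store the outer node's axes
y = copy(ibc)           # step 2: allocate the result and evaluate into it
```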

Similarly, an in-place fused broadcast like `y .= 2 .* (x .+ 1)` uses the exact
same expression tree to compute the right-hand side of the expression as above,
and then uses `materialize!(y, make(*, 2, make(+, x, 1)))` to `instantiate` the
`Broadcasted` expression tree and then `copyto!` it into the given destination.
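The in-place path can be checked the same way. A sketch assuming Julia ≥ 1.0 (`broadcasted` in place of `make`), showing that calling `materialize!` on the hand-built tree fills the destination exactly as the dotted assignment does:

```julia
using Base.Broadcast: broadcasted, materialize!

x = [1.0, 2.0, 3.0]
bc = broadcasted(*, 2, broadcasted(+, x, 1))

y1 = zeros(3)
materialize!(y1, bc)        # what `y1 .= 2 .* (x .+ 1)` lowers to

y2 = zeros(3)
y2 .= 2 .* (x .+ 1)         # the dotted syntax, for comparison
```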

Altogether, this forms a complete API for custom types to extend and
customize the behavior of broadcast (fixes #22060). It uses the existing
`BroadcastStyle`s throughout to simplify dispatch on many arguments:

* Custom types can opt-out of broadcast fusion by specializing
  `Broadcast.make(f, args...)` or `Broadcast.make(::BroadcastStyle, f, args...)`.

* The `Broadcasted` object computes and stores the type of the combined
  `BroadcastStyle` of its arguments as its first type parameter, allowing for
  easy dispatch and specialization.

* Custom broadcast storage is still allocated via `broadcast_similar`; however,
  instead of passing just a function as the first argument, the entire
  `Broadcasted` object is passed as the final argument. This potentially allows
  for much more runtime specialization dependent upon the exact expression
  given.

* Custom broadcast implementations for a `CustomStyle` are defined by
  specializing `copy(bc::Broadcasted{CustomStyle})` or
  `copyto!(dest::AbstractArray, bc::Broadcasted{CustomStyle})`.

* Fallback broadcast specializations for a given output object of type `Dest`
  (for the `DefaultArrayStyle` or another such style that hasn't implemented
  assignments into such an object) are defined by specializing
  `copyto!(dest::Dest, bc::Broadcasted{Nothing})`.
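A minimal sketch of the custom-style workflow in released Julia 1.x, adapted from the "Customizing broadcasting" pattern in the Julia manual. In 1.x the storage hook is `Base.similar(bc::Broadcasted{Style}, ::Type{ElType})` rather than `broadcast_similar`, but the idea is the same: the whole `Broadcasted` tree is available when allocating the output. `Tagged` and `find_tagged` are hypothetical names for a wrapper array that should keep its label through a broadcast.

```julia
# Hypothetical array wrapper that carries a label
struct Tagged{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    tag::Symbol
end
Base.size(A::Tagged) = size(A.data)
Base.getindex(A::Tagged, I::Int...) = A.data[I...]
Base.setindex!(A::Tagged, v, I::Int...) = (A.data[I...] = v)

# Opt into a custom style; ArrayStyle wins over the default array style
Base.BroadcastStyle(::Type{<:Tagged}) = Broadcast.ArrayStyle{Tagged}()

# Allocate the output using information from the whole Broadcasted tree
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Tagged}},
                      ::Type{ElType}) where {ElType}
    A = find_tagged(bc)
    return Tagged(similar(Array{ElType}, axes(bc)), A.tag)
end

# Return the first Tagged argument found in a (possibly nested) tree
find_tagged(bc::Broadcast.Broadcasted) = find_tagged(bc.args)
find_tagged(args::Tuple) = find_tagged(find_tagged(args[1]), Base.tail(args))
find_tagged(::Tuple{}) = nothing
find_tagged(x) = x
find_tagged(a::Tagged, rest) = a
find_tagged(::Any, rest) = find_tagged(rest)

t = Tagged([1.0, 2.0, 3.0], :x)
r = t .* 2 .+ 1    # fused; the output is allocated by the `similar` method above
```

Only allocation is customized here; evaluation falls back to the generic `copyto!`, which writes through `setindex!`.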

As it fully supports range broadcasting, this now deprecates `(1:5) + 2` in
favor of `(1:5) .+ 2`, just as had been done for all `AbstractArray`s in general.

As a first-mover proof of concept, LinearAlgebra uses this new system to
improve broadcasting over structured arrays. Before, broadcasting over a
structured matrix would result in a sparse array. Now, broadcasting over a
structured matrix will _either_ return an appropriately structured matrix _or_
a dense array. This does incur a type instability (in the form of a
discriminated union) in some situations, but thanks to type-based introspection
of the `Broadcasted` wrapper commonly used functions can be special cased to be
type stable.  For example:

```julia
julia> f(d) = round.(Int, d)
f (generic function with 1 method)

julia> @inferred f(Diagonal(rand(3)))
3×3 Diagonal{Int64,Array{Int64,1}}:
 0  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1

julia> @inferred Diagonal(rand(3)) .* 3
ERROR: return type Diagonal{Float64,Array{Float64,1}} does not match inferred return type Union{Array{Float64,2}, Diagonal{Float64,Array{Float64,1}}}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope

julia> @inferred Diagonal(1:4) .+ Bidiagonal(rand(4), rand(3), 'U') .* Tridiagonal(1:3, 1:4, 1:3)
4×4 Tridiagonal{Float64,Array{Float64,1}}:
 1.30771  0.838589   ⋅          ⋅
 0.0      3.89109   0.0459757   ⋅
  ⋅       0.0       4.48033    2.51508
  ⋅        ⋅        0.0        6.23739
```

In addition to the issues referenced above, it fixes:

* Fixes #19313, #22053, #23445, and #24586: Literals are no longer treated
  specially in a fused broadcast; they're just arguments in a `Broadcasted`
  object like everything else.

* Fixes #21094: Since broadcasting is now represented by a pure Julia
  data structure, it can be created within `@generated` functions and serialized.

* Fixes #26097: The fallback destination-array specialization method of
  `copyto!` is specifically implemented as `Broadcasted{Nothing}` and will not
  be confused by `nothing` arguments.

* Fixes the broadcast-specific element of #25499: The default base broadcast
  implementation no longer depends upon `Base._return_type` to allocate its
  array (except in the empty or concretely-typed cases). Note that the sparse
  implementation (#19595) is still dependent upon inference and is _not_ fixed.

* Fixes #25340: Functions are treated like normal values just like arguments
  and only evaluated once.

* Fixes #22255, and is performant with 12+ fused broadcasts. Okay, that one was
  fixed on master already, but this fixes it now, too.

* Fixes #25521.

* The performance of this patch has been thoroughly tested through its
  iterative development process in #25377. There remain [two classes of
  performance regressions](#25377) that Nanosoldier flagged.

* #25691: Constant literals still lose their constant-ness when they go
  through the broadcast machinery. I believe quite a large number of
  functions would need to be marked as `@pure` to support this -- including
  functions that are intended to be specialized.
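Two of the points above (literals as ordinary arguments, and the broadcast function being evaluated only once) can be checked directly. A sketch assuming Julia ≥ 1.0 (`broadcasted` in place of `make`); `pick_function` is a hypothetical helper that counts how often it is called.

```julia
using Base.Broadcast: broadcasted

x = [1.0, 2.0, 3.0]

# A literal is stored as just another argument of the Broadcasted object
bc = broadcasted(+, x, 1)

# The callee expression is evaluated once, like any other argument
ncalls = Ref(0)
pick_function() = (ncalls[] += 1; sqrt)   # hypothetical counting helper
y = pick_function().(x)
```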

(For bookkeeping, this is the squashed version of the [teh-jn/lazydotfuse](#25377)
branch as of a1d4e7e. Squashed and separated
out to make it easier to review and commit)

Co-authored-by: Tim Holy
Co-authored-by: Jameson Nash
Co-authored-by: Andrew Keller
Keno pushed a commit that referenced this issue Apr 27, 2018