Skip to content

Commit

Permalink
Merge pull request #6107 from JuliaLang/cb/fasterbroadcast
Browse files Browse the repository at this point in the history
Faster cache lookup in `broadcast!` via nested Dicts and get! macro
  • Loading branch information
carlobaldassi committed Mar 12, 2014
2 parents 7e94cfb + 56aa9eb commit 6b88580
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 66 deletions.
91 changes: 25 additions & 66 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Broadcast

using ..Cartesian
import Base.promote_eltype
import Base.@get!
import Base.num_bit_chunks, Base.@_msk_end, Base.getindex_unchecked
import Base.(.+), Base.(.-), Base.(.*), Base.(./), Base.(.\)
import Base.(.==), Base.(.<), Base.(.!=), Base.(.<=)
Expand Down Expand Up @@ -203,73 +204,31 @@ function gen_broadcast_function_tobitarray(genbody::Function, nd::Int, narrays::
end
end

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B, As::Union(Array,BitArray)...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function(gen_broadcast_body_iter, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B::BitArray, As::Union(Array,BitArray)...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function_tobitarray(gen_broadcast_body_iter_tobitarray, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B, As...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function(gen_broadcast_body_cartesian, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B::BitArray, As...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function_tobitarray(gen_broadcast_body_cartesian_tobitarray, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
for (Bsig, Asig, gbf, gbb) in
((BitArray , Union(Array,BitArray) ,
:gen_broadcast_function_tobitarray, :gen_broadcast_body_iter_tobitarray ),
(Any , Union(Array,BitArray) ,
:gen_broadcast_function , :gen_broadcast_body_iter ),
(BitArray , Any ,
:gen_broadcast_function_tobitarray, :gen_broadcast_body_cartesian_tobitarray),
(Any , Any ,
:gen_broadcast_function , :gen_broadcast_body_cartesian ))

@eval let cache = Dict{Function,Dict{Int,Dict{Int,Function}}}()
global broadcast!
function broadcast!(f::Function, B::$Bsig, As::$Asig...)
nd = ndims(B)
narrays = length(As)

cache_f = @get! cache f Dict{Int,Dict{Int,Function}}()
cache_f_na = @get! cache_f narrays Dict{Int,Function}()
func = @get! cache_f_na nd $gbf($gbb, nd, narrays, f)

func(B, As...)
B
end
end # let broadcast_cache
end
end # let broadcast_cache


broadcast(f::Function, As...) = broadcast!(f, Array(promote_eltype(As...), broadcast_shape(As...)), As...)
Expand Down
20 changes: 20 additions & 0 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,26 @@ function get!{K,V}(default::Function, h::Dict{K,V}, key0)
return v
end

# NOTE: this macro is specific to Dict, not Associative, and should
# therefore not be exported as-is: it's for internal use only.
macro get!(h, key0, default)
quote
K, V = eltype($(esc(h)))
key = convert(K, $(esc(key0)))
isequal(key, $(esc(key0))) || error($(esc(key0)), " is not a valid key for type ", K)
idx = ht_keyindex2($(esc(h)), key)
if idx < 0
idx = -idx
v = convert(V, $(esc(default)))
_setindex!($(esc(h)), v, key, idx)
else
@inbounds v = $(esc(h)).vals[idx]
end
v
end
end


function getindex{K,V}(h::Dict{K,V}, key)
index = ht_keyindex(h, key)
return (index<0) ? throw(KeyError(key)) : h.vals[index]::V
Expand Down

0 comments on commit 6b88580

Please sign in to comment.