Skip to content

Commit

Permalink
[ci skip] WIP: Refactor pmap
Browse files Browse the repository at this point in the history
New functions:

 - head_and_tail  -- like take and rest but atomic
 - batchsplit     -- like split, but aware of nworkers
 - generate       -- shorthand for creating a Generator
 - asyncgenerate  -- generate using tasks
 - asyncmap       -- map using tasks
 - pgenerate      -- generate using tasks and workers.

Reimplement pmap:

    pmap(f, c...) = collect(pgenerate(f, c...))
  • Loading branch information
samoconnor committed Mar 22, 2016
1 parent 1600981 commit 0cb8943
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 78 deletions.
2 changes: 2 additions & 0 deletions base/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ function next(g::Generator, s)
g.f(v), s2
end

"""
    generate(f, c...)

Shorthand for constructing a `Generator` over `c` with function `f`.
"""
function generate(f, c...)
    return Generator(f, c...)
end

# Collecting a Generator eagerly applies the stored function to the
# stored iterator via `map`.
function collect(g::Generator)
    return map(g.f, g.iter)
end
19 changes: 19 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ done(i::Rest, st) = done(i.itr, st)

eltype{I}(::Type{Rest{I}}) = eltype(I)


"""
    head_and_tail(c, n) -> head, tail

Return `head`, a vector of the first `n` elements of `c` (fewer if `c` is
exhausted before `n` elements are produced), and `tail`, an iterator over
the remaining elements of `c`.
"""
function head_and_tail(c, n)
    buffer = Vector{eltype(c)}(n)
    state = start(c)
    taken = 0
    while taken < n && !done(c, state)
        taken += 1
        buffer[taken], state = next(c, state)
    end
    # Shrink the buffer in case iteration ended before `n` elements.
    resize!(buffer, taken)
    return buffer, rest(c, state)
end


# Count -- infinite counting

immutable Count{S<:Number}
Expand Down
23 changes: 22 additions & 1 deletion base/mapiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
Apply f to each element of c using at most 100 asynchronous tasks.
For multiple collection arguments, apply f elementwise.
The iterator returns results as they become available.
Results are returned by the iterator as they become available.
Note: `collect(StreamMapIterator(f, c...; ntasks=1))` is equivalent to
`map(f, c...)`.
"""
Expand Down Expand Up @@ -144,3 +144,24 @@ function next(itr::StreamMapIterator, state::StreamMapState)

return (r, state)
end



"""
    asyncgenerate(f, c...) -> iterator

Apply `@async f` to each element of `c`.
For multiple collection arguments, apply `f` elementwise.
Results are returned by the iterator as they become available.
"""
function asyncgenerate(f, c...)
    return StreamMapIterator(f, c...)
end



"""
    asyncmap(f, c...) -> collection

Transform collection `c` by applying `@async f` to each element.
For multiple collection arguments, apply `f` elementwise.
"""
function asyncmap(f, c...)
    return collect(asyncgenerate(f, c...))
end
146 changes: 69 additions & 77 deletions base/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1522,93 +1522,85 @@ end

# Zero-collection case: there is nothing to map over, so simply invoke `f`
# once on the calling process.
function pmap(f)
    return f()
end

# dynamic scheduling by creating a local task to feed work to each processor
# as it finishes.
# example unbalanced workload:
# rsym(n) = (a=rand(n,n);a*a')
# L = {rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000)};
# pmap(eig, L);
function pmap(f, lsts...; err_retry=true, err_stop=false, pids = workers())
len = length(lsts)

results = Dict{Int,Any}()

busy_workers = fill(false, length(pids))
busy_workers_ntfy = Condition()

retryqueue = []
task_in_err = false
is_task_in_error() = task_in_err
set_task_in_error() = (task_in_err = true)

nextidx = 0
getnextidx() = (nextidx += 1)

states = [start(lsts[idx]) for idx in 1:len]
function getnext_tasklet()
if is_task_in_error() && err_stop
return nothing
elseif !any(idx->done(lsts[idx],states[idx]), 1:len)
nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
for idx in 1:len; states[idx] = nxts[idx][2]; end
nxtvals = [x[1] for x in nxts]
return (getnextidx(), nxtvals)
elseif !isempty(retryqueue)
return shift!(retryqueue)
elseif err_retry
# Handles the condition where we have finished processing the requested lsts as well
# as any retryqueue entries, but there are still some jobs active that may result
# in an error and have to be retried.
while any(busy_workers)
wait(busy_workers_ntfy)
if !isempty(retryqueue)
return shift!(retryqueue)
end
end
return nothing
else
return nothing
end

"""
    pgenerate(f, c...) -> iterator

Apply `f` to each element of `c` in parallel using available workers and tasks.
For multiple collection arguments, apply `f` elementwise.
Results are returned by the iterator as they become available.

With only a single worker process, falls back to task-based `asyncgenerate`,
avoiding remote-call overhead.
"""
function pgenerate(f, c)
    # Single-worker fallback: tasks only, no remote calls.
    if nworkers() == 1
        return asyncgenerate(f, c)
    end
    # Split `c` into batches, farm each batch out to a worker which maps `f`
    # over its batch with tasks, then flatten the per-batch results.
    return flatten(asyncgenerate(remote(b -> asyncmap(f, b)), batchsplit(c)))
end

@sync begin
for (pididx, wpid) in enumerate(pids)
@async begin
tasklet = getnext_tasklet()
while (tasklet !== nothing)
(idx, fvals) = tasklet
busy_workers[pididx] = true
try
results[idx] = remotecall_fetch(f, wpid, fvals...)
catch ex
if err_retry
push!(retryqueue, (idx,fvals, ex))
else
results[idx] = ex
end
set_task_in_error()

busy_workers[pididx] = false
notify(busy_workers_ntfy; all=true)

break # remove this worker from accepting any more tasks
end
# Multi-collection form: zip the collections together and splat each
# resulting tuple into `f`, delegating to the single-collection method.
function pgenerate(f, c1, c...)
    return pgenerate(t -> f(t...), zip(c1, c...))
end

busy_workers[pididx] = false
notify(busy_workers_ntfy; all=true)

tasklet = getnext_tasklet()
end
end
"""
    pmap(f, c...) -> collection

Transform collections `c` by applying `f` to each element in parallel using
available workers and tasks (via `pgenerate`).

The keyword arguments are deprecated:

  - `err_retry=true`  -- use `pmap(retry(f), c...)` instead.
  - `err_stop=false`  -- use `pmap(@catch(f), c...)` instead.
  - `pids`            -- no longer has any effect.
"""
function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing)

    if err_retry !== nothing
        depwarn("`err_retry` is deprecated, use `pmap(retry(f), c...)`.", :pmap)
        if err_retry == true
            f = retry(f)
        end
    end

    if err_stop !== nothing
        depwarn("`err_stop` is deprecated, use `pmap(@catch(f), c...)`.", :pmap)
        if err_stop == false
            f = @catch(f)
        end
    end

    if pids !== nothing
        depwarn("`pids` is deprecated. It no longer has any effect.", :pmap)
    end

    return collect(pgenerate(f, c...))
end



"""
    batchsplit(c; min_batch_count=0, max_batch_size=100) -> iterator

Split a collection into at least `min_batch_count` batches.

If `min_batch_count` is not specified, `batchsplit` attempts to produce
enough batches (3 per worker) to allow work to be distributed amongst
available workers.

Equivalent to `split(c, max_batch_size)` when `length(c) >> max_batch_size`.

# Throws
- `ArgumentError` if `min_batch_count <= 0` or `max_batch_size <= 1`.
"""
function batchsplit(c; min_batch_count=0, max_batch_size=100)

    if min_batch_count == 0
        min_batch_count = nworkers() * 3
    end

    # Validate arguments explicitly; `@assert` may be disabled at higher
    # optimization levels, so it is unsuitable for input checking.
    min_batch_count > 0 ||
        throw(ArgumentError("min_batch_count must be > 0, got $min_batch_count"))
    max_batch_size > 1 ||
        throw(ArgumentError("max_batch_size must be > 1, got $max_batch_size"))

    # Split collection into batches, then peek at the first few batches...
    batches = split(c, max_batch_size)
    head, tail = head_and_tail(batches, min_batch_count)

    # If there are not enough batches, re-split with a smaller batch size
    # so the work can still be spread across `min_batch_count` batches...
    if length(head) < min_batch_count
        head = vcat(head...)
        batch_size = max(1, div(length(head), min_batch_count))
        return split(head, batch_size)
    end

    return flatten((head, tail))
end


# Statically split range [1,N] into equal sized chunks for np processors
function splitrange(N::Int, np::Int)
each = div(N,np)
Expand Down

0 comments on commit 0cb8943

Please sign in to comment.