Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Event; use it to fix race in Distributed setup #29623

Merged
merged 2 commits into from
Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions base/locks.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

import .Base: _uv_hook_close, unsafe_convert,
lock, trylock, unlock, islocked
lock, trylock, unlock, islocked, wait, notify

export SpinLock, RecursiveSpinLock, Mutex
export SpinLock, RecursiveSpinLock, Mutex, Event


##########################################
Expand Down Expand Up @@ -238,3 +238,57 @@ end
function islocked(m::Mutex)
return m.ownertid != 0
end

"""
Event()

Create a level-triggered event source. Tasks that call [`wait`](@ref) on an
`Event` are suspended and queued until `notify` is called on the `Event`.
After `notify` is called, the `Event` remains in a signaled state and
tasks will no longer block when waiting for it.
"""
mutable struct Event
lock::Mutex
q::Vector{Task}
set::Bool
# TODO: use a Condition with its paired lock
Event() = new(Mutex(), Task[], false)
end

function wait(e::Event)
e.set && return
lock(e.lock)
while !e.set
ct = current_task()
push!(e.q, ct)
unlock(e.lock)
try
wait()
catch
filter!(x->x!==ct, e.q)
rethrow()
end
lock(e.lock)
end
unlock(e.lock)
return nothing
end

function notify(e::Event)
lock(e.lock)
if !e.set
e.set = true
for t in e.q
schedule(t)
end
empty!(e.q)
end
unlock(e.lock)
return nothing
end

# TODO: decide what to call this
#function clear(e::Event)
# e.set = false
# return nothing
#end
2 changes: 2 additions & 0 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mutable struct Worker
manager::ClusterManager
config::WorkerConfig
version::Union{VersionNumber, Nothing} # Julia version of the remote process
initialized::Threads.Event

function Worker(id::Int, r_stream::IO, w_stream::IO, manager::ClusterManager;
version::Union{VersionNumber, Nothing}=nothing,
Expand All @@ -90,6 +91,7 @@ mutable struct Worker
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w.initialized = Threads.Event()
register_worker(w)
w
end
Expand Down
3 changes: 3 additions & 0 deletions stdlib/Distributed/src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ end

function send_msg_(w::Worker, header, msg, now::Bool)
check_worker_state(w)
if myid() != 1 && !isa(msg, IdentifySocketMsg) && !isa(msg, IdentifySocketAckMsg)
wait(w.initialized)
end
io = w.w_stream
lock(io.lock)
try
Expand Down
5 changes: 4 additions & 1 deletion stdlib/Distributed/src/process_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ end

function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version)
# register a new peer worker connection
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
w = Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
send_connection_hdr(w, false)
send_msg_now(w, MsgHeader(), IdentifySocketAckMsg())
notify(w.initialized)
end

function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version)
Expand All @@ -301,6 +302,7 @@ end
function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version)
LPROC.id = msg.self_pid
controller = Worker(1, r_stream, w_stream, cluster_manager; version=version)
notify(controller.initialized)
register_worker(LPROC)
topology(msg.topology)

Expand Down Expand Up @@ -340,6 +342,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf
process_messages(w.r_stream, w.w_stream, false)
send_connection_hdr(w, true)
send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid()))
notify(w.initialized)
catch e
@error "Error on $(myid()) while connecting to peer $rpid, exiting" exception=e,catch_backtrace()
exit(1)
Expand Down
13 changes: 13 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,16 @@ function test_thread_too_few_iters()
@test !(true in found[nthreads():end])
end
test_thread_too_few_iters()

let e = Event()
done = false
t = @async (wait(e); done = true)
sleep(0.1)
@test done == false
notify(e)
wait(t)
@test done == true
blocked = true
wait(@async (wait(e); blocked = false))
@test !blocked
end