diff --git a/base/locks.jl b/base/locks.jl index 94acb61b73e9f..425785b910fda 100644 --- a/base/locks.jl +++ b/base/locks.jl @@ -1,9 +1,9 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license import .Base: _uv_hook_close, unsafe_convert, - lock, trylock, unlock, islocked + lock, trylock, unlock, islocked, wait, notify -export SpinLock, RecursiveSpinLock, Mutex +export SpinLock, RecursiveSpinLock, Mutex, Event ########################################## @@ -238,3 +238,57 @@ end function islocked(m::Mutex) return m.ownertid != 0 end + +""" + Event() + +Create a level-triggered event source. Tasks that call [`wait`](@ref) on an +`Event` are suspended and queued until `notify` is called on the `Event`. +After `notify` is called, the `Event` remains in a signaled state and +tasks will no longer block when waiting for it. +""" +mutable struct Event + lock::Mutex + q::Vector{Task} + set::Bool + # TODO: use a Condition with its paired lock + Event() = new(Mutex(), Task[], false) +end + +function wait(e::Event) + e.set && return + lock(e.lock) + while !e.set + ct = current_task() + push!(e.q, ct) + unlock(e.lock) + try + wait() + catch + filter!(x->x!==ct, e.q) + rethrow() + end + lock(e.lock) + end + unlock(e.lock) + return nothing +end + +function notify(e::Event) + lock(e.lock) + if !e.set + e.set = true + for t in e.q + schedule(t) + end + empty!(e.q) + end + unlock(e.lock) + return nothing +end + +# TODO: decide what to call this +#function clear(e::Event) +# e.set = false +# return nothing +#end diff --git a/stdlib/Distributed/src/cluster.jl b/stdlib/Distributed/src/cluster.jl index e1d4bae511d91..b57f1bb8f19dd 100644 --- a/stdlib/Distributed/src/cluster.jl +++ b/stdlib/Distributed/src/cluster.jl @@ -67,6 +67,7 @@ mutable struct Worker manager::ClusterManager config::WorkerConfig version::Union{VersionNumber, Nothing} # Julia version of the remote process + initialized::Threads.Event function Worker(id::Int, r_stream::IO, w_stream::IO, manager::ClusterManager; version::Union{VersionNumber, Nothing}=nothing, @@ -90,6 +91,7 @@ mutable struct Worker return map_pid_wrkr[id] end w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func) + w.initialized = Threads.Event() register_worker(w) w end diff --git a/stdlib/Distributed/src/messages.jl b/stdlib/Distributed/src/messages.jl index 35b090f703868..bbfb13f276fa5 100644 --- a/stdlib/Distributed/src/messages.jl +++ b/stdlib/Distributed/src/messages.jl @@ -174,6 +174,9 @@ end function send_msg_(w::Worker, header, msg, now::Bool) check_worker_state(w) + if myid() != 1 && !isa(msg, IdentifySocketMsg) && !isa(msg, IdentifySocketAckMsg) + wait(w.initialized) + end io = w.w_stream lock(io.lock) try diff --git a/stdlib/Distributed/src/process_messages.jl b/stdlib/Distributed/src/process_messages.jl index 06a10c4081d73..1a1f2f5f8a770 100644 --- a/stdlib/Distributed/src/process_messages.jl +++ b/stdlib/Distributed/src/process_messages.jl @@ -288,9 +288,10 @@ end function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version) # register a new peer worker connection - w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version) + w = Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version) send_connection_hdr(w, false) send_msg_now(w, MsgHeader(), IdentifySocketAckMsg()) + notify(w.initialized) end function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version) @@ -301,6 +302,7 @@ end function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version) LPROC.id = msg.self_pid controller = Worker(1, r_stream, w_stream, cluster_manager; version=version) + notify(controller.initialized) register_worker(LPROC) topology(msg.topology) @@ -340,6 +342,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf process_messages(w.r_stream, w.w_stream, false) send_connection_hdr(w, true) send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid())) + notify(w.initialized) catch e @error "Error on $(myid()) while connecting to peer $rpid, exiting" exception=e,catch_backtrace() exit(1) diff --git a/test/threads.jl b/test/threads.jl index 960715761de8b..7b79b141b660b 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -503,3 +503,16 @@ function test_thread_too_few_iters() @test !(true in found[nthreads():end]) end test_thread_too_few_iters() + +let e = Event() + done = false + t = @async (wait(e); done = true) + sleep(0.1) + @test done == false + notify(e) + wait(t) + @test done == true + blocked = true + wait(@async (wait(e); blocked = false)) + @test !blocked +end