From 766a0065be8c36354147125f0bb2588a150af607 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 22 Oct 2020 13:31:27 -0400 Subject: [PATCH] Make Distributed.Worker threadsafe (#38134) Co-authored-by: Jonas Schulze --- stdlib/Distributed/src/cluster.jl | 8 +-- stdlib/Distributed/test/distributed_exec.jl | 1 + stdlib/Distributed/test/threads.jl | 61 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 stdlib/Distributed/test/threads.jl diff --git a/stdlib/Distributed/src/cluster.jl b/stdlib/Distributed/src/cluster.jl index 0522ea060492e..323771a6c2684 100644 --- a/stdlib/Distributed/src/cluster.jl +++ b/stdlib/Distributed/src/cluster.jl @@ -99,7 +99,7 @@ mutable struct Worker add_msgs::Array{Any,1} gcflag::Bool state::WorkerState - c_state::Condition # wait for state changes + c_state::Event # wait for state changes ct_time::Float64 # creation time conn_func::Any # used to setup connections lazily @@ -133,7 +133,7 @@ mutable struct Worker if haskey(map_pid_wrkr, id) return map_pid_wrkr[id] end - w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func) + w=new(id, [], [], false, W_CREATED, Event(), time(), conn_func) w.initialized = Event() register_worker(w) w @@ -144,7 +144,7 @@ end function set_worker_state(w, state) w.state = state - notify(w.c_state; all=true) + notify(w.c_state) end function check_worker_state(w::Worker) @@ -189,7 +189,7 @@ function wait_for_conn(w) timeout = worker_timeout() - (time() - w.ct_time) timeout <= 0 && error("peer $(w.id) has not connected to $(myid())") - @async (sleep(timeout); notify(w.c_state; all=true)) + @async (sleep(timeout); notify(w.c_state)) wait(w.c_state) w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds") end diff --git a/stdlib/Distributed/test/distributed_exec.jl b/stdlib/Distributed/test/distributed_exec.jl index 367ab3f3e4930..066f48e8c369c 100644 --- a/stdlib/Distributed/test/distributed_exec.jl +++ b/stdlib/Distributed/test/distributed_exec.jl @@ -1711,4 +1711,5 @@ include("splitrange.jl") # Run topology tests last after removing all workers, since a given # cluster at any time only supports a single topology. rmprocs(workers()) +include("threads.jl") include("topology.jl") diff --git a/stdlib/Distributed/test/threads.jl b/stdlib/Distributed/test/threads.jl new file mode 100644 index 0000000000000..2f4127be384ba --- /dev/null +++ b/stdlib/Distributed/test/threads.jl @@ -0,0 +1,61 @@ +using Test +using Distributed, Base.Threads +using Base.Iterators: product + +exeflags = ("--startup-file=no", + "--check-bounds=yes", + "--depwarn=error", + "--threads=2") + +function call_on(f, wid, tid) + remotecall(wid) do + t = Task(f) + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1) + schedule(t) + @assert threadid(t) == tid + t + end +end + +# Run function on process holding the data to only serialize the result of f. +# This becomes useful for things that cannot be serialized (e.g. running tasks) +# or that would be unnecessarily big if serialized. +fetch_from_owner(f, rr) = remotecall_fetch(f∘fetch, rr.where, rr) + +isdone(rr) = fetch_from_owner(istaskdone, rr) +isfailed(rr) = fetch_from_owner(istaskfailed, rr) + +@testset "RemoteChannel allows put!/take! from thread other than 1" begin + ws = ts = product(1:2, 1:2) + @testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws + @testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts + # We want (the default) lazyness, so that we wait for `Worker.c_state`! + procs_added = addprocs(2; exeflags, lazy=true) + @everywhere procs_added using Base.Threads + p1 = procs_added[w1] + p2 = procs_added[w2] + chan_id = first(procs_added) + chan = RemoteChannel(chan_id) + send = call_on(p1, t1) do + put!(chan, nothing) + end + recv = call_on(p2, t2) do + take!(chan) + end + + # Wait on the spawned tasks on the owner + @sync begin + @async fetch_from_owner(wait, recv) + @async fetch_from_owner(wait, send) + end + + # Check the tasks + @test isdone(send) + @test isdone(recv) + + @test !isfailed(send) + @test !isfailed(recv) + rmprocs(procs_added) + end + end +end