Skip to content

Commit

Permalink
keep a single message buffer per node, not per network edge
Browse files Browse the repository at this point in the history
  • Loading branch information
Krastanov committed Nov 26, 2023
1 parent 835b14d commit 08d152b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 69 deletions.
36 changes: 22 additions & 14 deletions src/messagebuffer.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
struct MessageBuffer{T}
queue::DelayQueue{T}
buffer::Vector{T}
signalreception::Resource
env::Simulation
buffer::Vector{NamedTuple{(:src,:tag), Tuple{Int,T}}}
waiters::IdDict{Resource,Resource}
end

@resumable function take_loop_mb(env, q, mb)
@resumable function take_loop_mb(env, channel, src, mb)
while true
@yield lock(mb.signalreception)
msg = @yield take!(q)
push!(mb.buffer, msg)
unlock(mb.signalreception)
tag = @yield take!(channel)
push!(mb.buffer, (;src,tag))
for waiter in keys(mb.waiters)
unlock(waiter)
end
end
end

function MessageBuffer(q::DelayQueue{T}) where {T}
mb = MessageBuffer{T}(q, T[], ConcurrentSim.Resource(q.store.env))
@process take_loop_mb(q.store.env, q, mb)
function MessageBuffer(qs::Vector{NamedTuple{(:src,:channel), Tuple{Int, DelayQueue{T}}}}) where {T}
env = qs[1].channel.store.env
signal = IdDict{Resource,Resource}()
mb = MessageBuffer{T}(env, Tuple{Int,T}[], signal)
for (;src, channel) in qs
@process take_loop_mb(env, channel, src, mb)
end
mb
end

@resumable function wait_process(env, mb::MessageBuffer)
@yield lock(mb.signalreception)
unlock(mb.signalreception)
waitresource = Resource(env)
lock(waitresource)
mb.waiters[waitresource] = waitresource
@yield lock(waitresource)
pop!(mb.waiters, waitresource)
end

function Base.wait(mb::MessageBuffer)
@process wait_process(mb.queue.store.env, mb)
@process wait_process(mb.env, mb)
end
42 changes: 17 additions & 25 deletions src/networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ struct RegisterNet
vertex_metadata::Vector{Dict{Symbol,Any}}
edge_metadata::Dict{Tuple{Int,Int},Dict{Symbol,Any}}
directed_edge_metadata::Dict{Pair{Int,Int},Dict{Symbol,Any}}
cchannels::Dict{Pair{Int,Int},Any}
cbuffers::Dict{Pair{Int,Int},Any}
qchannels::Dict{Pair{Int,Int},Any}
cchannels::Dict{Pair{Int,Int},DelayQueue{Tag}} # Dict{src=>dst, DelayQueue}
cbuffers::Dict{Int,MessageBuffer{Tag}} # Dict{dst, MessageBuffer}
qchannels::Dict{Pair{Int,Int},Any} # Dict{src=>dst, QuantumChannel}
reverse_lookup::IdDict{Register,Int}
end

Expand All @@ -30,17 +30,19 @@ function RegisterNet(graph::SimpleGraph, registers, vertex_metadata, edge_metada
end
end

cchannels = Dict{Pair{Int,Int},Any}()
cbuffers = Dict{Pair{Int,Int},Any}()
cchannels = Dict{Pair{Int,Int},DelayQueue{Tag}}()
qchannels = Dict{Pair{Int,Int},Any}()
for (;src,dst) in edges(graph)
cchannels[src=>dst] = DelayQueue{Tag}(env, 0)
cbuffers[src=>dst] = MessageBuffer(cchannels[src=>dst])
qchannels[src=>dst] = QuantumChannel(env, 0)
cchannels[dst=>src] = DelayQueue{Tag}(env, 0)
cbuffers[dst=>src] = MessageBuffer(cchannels[dst=>src])
qchannels[dst=>src] = QuantumChannel(env, 0)
end
cbuffers = Dict{Int,MessageBuffer{Tag}}()
for (v,r) in zip(vertices(graph), registers)
channels = [(;src=w, channel=cchannels[w=>v]) for w in neighbors(graph, v)]
cbuffers[v] = MessageBuffer(channels)
end
reverse_lookup = IdDict{Register,Int}()
for (v,r) in zip(vertices(graph), registers)
reverse_lookup[r] = v
Expand Down Expand Up @@ -119,6 +121,10 @@ end

"""Get a handle to a classical channel between two registers.
Usually used for sending classical messages between registers.
It can be used for receiving as well, but a more convenient choice is [`messagebuffer`](@ref),
which is a message buffer listening to **all** channels sending to a given destination register.
```jldoctest
julia> net = RegisterNet([Register(2), Register(2), Register(2)]) # defaults to a chain topology
A network of 3 registers in a graph of 2 edges
Expand All @@ -133,7 +139,7 @@ julia> channel(net, 1=>2) === channel(net, net[1]=>net[2])
true
```
See also: [`qchannel`](@ref)
See also: [`qchannel`](@ref), [`messagebuffer`](@ref)
"""
function channel(net::RegisterNet, args...)
return achannel(net, args..., Val{:C}())
Expand All @@ -158,33 +164,19 @@ function qchannel(net::RegisterNet, args...)
return achannel(net, args..., Val{:Q}())
end

"""Get a handle to a classical message buffer corresponding to a channel between two registers.
```jldoctest
julia> net = RegisterNet([Register(2), Register(2), Register(2)]) # defaults to a chain topology
A network of 3 registers in a graph of 2 edges
julia> qchannel(net, 1=>2)
QuantumChannel{Qubit}(Qubit(), ConcurrentSim.DelayQueue{Register}(ConcurrentSim.QueueStore{Register, Int64}, 0.0), nothing)
julia> qchannel(net, 1=>2) === qchannel(net, net[1]=>net[2])
true
```
"""Get a handle to a classical message buffer corresponding to all channels sending to a given destination register.
See also: [`channel`](@ref)
"""
function messagebuffer(net::RegisterNet, args...)
return achannel(net, args..., Val{:B}())
function messagebuffer(net::RegisterNet, dst::Int)
return net.cbuffers[dst]
end


function achannel(net::RegisterNet, src::Int, dst::Int, ::Val{Q}) where {Q}
if Q==:Q
return net.qchannels[src=>dst]
elseif Q==:C
return net.cchannels[src=>dst]
elseif Q==:B
return net.cbuffers[src=>dst]
end
end

Expand Down
55 changes: 26 additions & 29 deletions src/queries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ end

"""A [`query`](@ref) for classical message buffers.
You are advised actually use [`querypop!`](@ref), not `query` when working with classical message buffers."""
You are advised to actually use [`querypop!`](@ref), not `query` when working with classical message buffers."""
function query(mb::MessageBuffer, tag::Tag)
i = findfirst(==(tag), mb.buffer)
return isnothing(i) ? nothing : (;slot=i, tag=mb.buffer[i])
i = findfirst(t->t.tag==tag, mb.buffer)
return isnothing(i) ? nothing : (;depth=i, src=mb.buffer[i][1], tag=mb.buffer[i][2])
end

raw"""A [`query`](@ref) for classical message buffers that also pops the message out of the buffer.
Expand All @@ -118,28 +118,21 @@ raw"""A [`query`](@ref) for classical message buffers that also pops the message
julia> net = RegisterNet([Register(3), Register(2)])
A network of 2 registers in a graph of 1 edges
julia> put!(channel(net, 1=>2), Tag(:my_tag))
ConcurrentSim.Process 5
julia> net = RegisterNet([Register(3), Register(2)])
A network of 2 registers in a graph of 1 edges
julia> put!(channel(net, 1=>2), Tag(:my_tag));
julia> put!(channel(net, 1=>2), Tag(:another_tag, 123, 456));
julia> query(messagebuffer(net, 1=>2), :my_tag)
julia> query(messagebuffer(net, 2), :my_tag)
julia> run(get_time_tracker(net))
julia> query(messagebuffer(net, 1=>2), :my_tag)
(slot = 1, tag = Symbol(:my_tag)::Tag)
julia> query(messagebuffer(net, 2), :my_tag)
(depth = 1, src = 1, tag = Symbol(:my_tag)::Tag)
julia> querypop!(messagebuffer(net, 1=>2), :my_tag)
Symbol(:my_tag)::Tag
julia> querypop!(messagebuffer(net, 2), :my_tag)
(src = 1, tag = Symbol(:my_tag)::Tag)
julia> querypop!(messagebuffer(net, 1=>2), :my_tag) === nothing
julia> querypop!(messagebuffer(net, 2), :my_tag) === nothing
true
julia> querypop!(messagebuffer(net, 1=>2), :another_tag, ❓, ❓)
Expand All @@ -149,30 +142,34 @@ julia> querypop!(messagebuffer(net, 1=>2), :another_tag, ❓, ❓) === nothing
true
```
You can also wait on a message buffer for a message to arrive before runnign a query:
You can also wait on a message buffer for a message to arrive before running a query:
```jldoctes
julia> net = RegisterNet([Register(3), Register(2)])
A network of 2 registers in a graph of 1 edges
julia> net = RegisterNet([Register(3), Register(2), Register(3)])
A network of 3 registers in a graph of 2 edges
julia> env = get_time_tracker(net);
julia> @resumable function receive_tags(env)
while true
mb = messagebuffer(net, 1=>2)
mb = messagebuffer(net, 2)
@yield wait(mb)
msg = querypop!(mb, :second_tag, ❓, ❓)
println("t=$(now(env)): query returns $msg")
print("t=$(now(env)): query returns ")
if isnothing(msg)
println("nothing")
else
println("$(msg.tag) received from node $(msg.src)")
end
end
end
receive_tags (generic function with 1 method)
julia> @resumable function send_tags(env)
@yield timeout(env, 1)
@yield timeout(env, 1.0)
put!(channel(net, 1=>2), Tag(:my_tag))
@yield timeout(env, 2)
put!(channel(net, 1=>2), Tag(:second_tag, 123, 456))
@yield timeout(env, 2.0)
put!(channel(net, 3=>2), Tag(:second_tag, 123, 456))
end
send_tags (generic function with 1 method)
Expand All @@ -182,12 +179,12 @@ julia> @process receive_tags(env);
julia> run(env, 10)
t=1.0: query returns nothing
t=3.0: query returns SymbolIntInt(:second_tag, 123, 456)::Tag
t=3.0: query returns SymbolIntInt(:second_tag, 123, 456)::Tag received from node 3
```
"""
function querypop!(mb::MessageBuffer, args...)
r = query(mb, args...)
return isnothing(r) ? nothing : popat!(mb.buffer, r.slot)
return isnothing(r) ? nothing : popat!(mb.buffer, r.depth)
end

_nothingor(l,r) = isnothing(l) || l==r
Expand Down Expand Up @@ -240,10 +237,10 @@ for (tagsymbol, tagvariant) in pairs(tag_types)
allB ? res : nothing
end end
newmethod_mb = quote function query(mb::MessageBuffer, $(argssig_wild...))
for (slot, tag) in enumerate(mb.buffer)
for (depth, (src, tag)) in pairs(mb.buffer)
if isvariant(tag, ($(tagsymbol,))[1]) # a weird workaround for interpolating a symbol as a symbol
if _all($(nonwild_checks...)) && _all($(wild_checks...))
return (;slot, tag)
return (;depth, src, tag)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREA
@doset "observable"
@doset "noninstant_and_backgrounds_qubit"
@doset "noninstant_and_backgrounds_qumode"

@doset "messagebuffer"
@doset "tags_and_queries"

@doset "circuitzoo_api"
Expand Down
32 changes: 32 additions & 0 deletions test/test_messagebuffer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using QuantumSavory
using QuantumSavory: tag_types
using ResumableFunctions, ConcurrentSim
using Test

net = RegisterNet([Register(3), Register(2), Register(3)])
env = get_time_tracker(net);
@resumable function receive_tags(env)
while true
mb = messagebuffer(net, 2)
@yield wait(mb)
msg = querypop!(mb, :second_tag, ❓, ❓)
print("t=$(now(env)): query returns ")
if isnothing(msg)
#println("nothing")
else
#println("$(msg.tag) received from node $(msg.src)")
end
end
end
@resumable function send_tags(env)
@yield timeout(env, 1.0)
put!(channel(net, 1=>2), Tag(:my_tag))
@yield timeout(env, 2.0)
put!(channel(net, 3=>2), Tag(:second_tag, 123, 456))
end
@process send_tags(env);
@process receive_tags(env);
run(env, 10)

@test query(messagebuffer(net, 2), :second_tag, ❓, ❓) === nothing
@test query(messagebuffer(net, 2), :my_tag).tag == Tag(:my_tag)

0 comments on commit 08d152b

Please sign in to comment.