Skip to content

Commit

Permalink
CTask ==> TapedTask.
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Feb 20, 2022
1 parent 7891a3a commit 26c65b6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
26 changes: 13 additions & 13 deletions src/container.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
struct Trace{F}
f::F
ctask::Libtask.CTask
task::Libtask.TapedTask
end

const Particle = Trace

function Trace(f)
if hasfield(typeof(f), :evaluator) # Test whether f is a Turing.TracedModel
# println(f.evaluator)
ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...)
task = Libtask.TapedTask(f.evaluator[1], f.evaluator[2:end]...)
else # f is a Function, or AdavncedPS.Model
ctask = Libtask.CTask(f)
task = Libtask.TapedTask(f)
end

# add backward reference
newtrace = Trace(f, ctask)
addreference!(ctask.task, newtrace)
newtrace = Trace(f, task)
addreference!(task.task, newtrace)

return newtrace
end

Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.task))

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
advance!(t::Trace) = Libtask.consume(t.ctask)
advance!(t::Trace) = Libtask.consume(t.task)

# reset log probability
reset_logprob!(t::Trace) = nothing
Expand All @@ -38,7 +38,7 @@ function fork(trace::Trace, isref::Bool=false)
isref && delete_retained!(newtrace.f)

# add backward reference
addreference!(newtrace.ctask.task, newtrace)
addreference!(newtrace.task.task, newtrace)

return newtrace
end
Expand All @@ -47,16 +47,16 @@ end
# Create new task and copy randomness
function forkr(trace::Trace)
newf = reset_model(trace.f)
# ctask = Libtask.CTask(trace.ctask)
# task = Libtask.TapedTask(trace.task)
if hasfield(typeof(newf), :evaluator) # Test whether f is a Turing.TracedModel
ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...)
task = Libtask.TapedTask(newf.evaluator[1], newf.evaluator[2:end]...)
else # f is a Function, or AdavncedPS.Model
ctask = Libtask.CTask(newf)
task = Libtask.TapedTask(newf)
end

# add backward reference
newtrace = Trace(newf, ctask)
addreference!(ctask.task, newtrace)
newtrace = Trace(newf, task)
addreference!(task.task, newtrace)

return newtrace
end
Expand Down
12 changes: 6 additions & 6 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@
# Test task copy version of trace
tr = AdvancedPS.Trace(f2)

consume(tr.ctask)
consume(tr.ctask)
consume(tr.task)
consume(tr.task)

a = AdvancedPS.fork(tr)
consume(a.ctask)
consume(a.ctask)
consume(a.task)
consume(a.task)

@test consume(tr.ctask) == 2
@test consume(a.ctask) == 4
@test consume(tr.task) == 2
@test consume(a.task) == 4
end
end

0 comments on commit 26c65b6

Please sign in to comment.