diff --git a/src/container.jl b/src/container.jl index 8631938e..2b27a4fa 100644 --- a/src/container.jl +++ b/src/container.jl @@ -1,6 +1,6 @@ struct Trace{F} f::F - ctask::Libtask.CTask + task::Libtask.TapedTask end const Particle = Trace @@ -8,23 +8,23 @@ const Particle = Trace function Trace(f) if hasfield(typeof(f), :evaluator) # Test whether f is a Turing.TracedModel # println(f.evaluator) - ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...) + task = Libtask.TapedTask(f.evaluator[1], f.evaluator[2:end]...) else # f is a Function, or AdavncedPS.Model - ctask = Libtask.CTask(f) + task = Libtask.TapedTask(f) end # add backward reference - newtrace = Trace(f, ctask) - addreference!(ctask.task, newtrace) + newtrace = Trace(f, task) + addreference!(task.task, newtrace) return newtrace end -Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask)) +Base.copy(trace::Trace) = Trace(trace.f, copy(trace.task)) # step to the next observe statement and # return the log probability of the transition (or nothing if done) -advance!(t::Trace) = Libtask.consume(t.ctask) +advance!(t::Trace) = Libtask.consume(t.task) # reset log probability reset_logprob!(t::Trace) = nothing @@ -38,7 +38,7 @@ function fork(trace::Trace, isref::Bool=false) isref && delete_retained!(newtrace.f) # add backward reference - addreference!(newtrace.ctask.task, newtrace) + addreference!(newtrace.task.task, newtrace) return newtrace end @@ -47,16 +47,16 @@ end # Create new task and copy randomness function forkr(trace::Trace) newf = reset_model(trace.f) - # ctask = Libtask.CTask(trace.ctask) + # task = Libtask.TapedTask(trace.task) if hasfield(typeof(newf), :evaluator) # Test whether f is a Turing.TracedModel - ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) + task = Libtask.TapedTask(newf.evaluator[1], newf.evaluator[2:end]...) else # f is a Function, or AdavncedPS.Model - ctask = Libtask.CTask(newf) + task = Libtask.TapedTask(newf) end # add backward reference - newtrace = Trace(newf, ctask) - addreference!(ctask.task, newtrace) + newtrace = Trace(newf, task) + addreference!(task.task, newtrace) return newtrace end diff --git a/test/container.jl b/test/container.jl index ac0b2bf2..1e9e2d13 100644 --- a/test/container.jl +++ b/test/container.jl @@ -109,14 +109,14 @@ # Test task copy version of trace tr = AdvancedPS.Trace(f2) - consume(tr.ctask) - consume(tr.ctask) + consume(tr.task) + consume(tr.task) a = AdvancedPS.fork(tr) - consume(a.ctask) - consume(a.ctask) + consume(a.task) + consume(a.task) - @test consume(tr.ctask) == 2 - @test consume(a.ctask) == 4 + @test consume(tr.task) == 2 + @test consume(a.task) == 4 end end