diff --git a/Project.toml b/Project.toml index a1db18f3..77943aaa 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.6.9" +version = "0.10.0" [deps] IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" diff --git a/README.md b/README.md index ca819f1d..661476c4 100644 --- a/README.md +++ b/README.md @@ -19,17 +19,17 @@ function f() end end -ctask = CTask(f) +ttask = TapedTask(f) -@show consume(ctask) # 0 -@show consume(ctask) # 1 +@show consume(ttask) # 0 +@show consume(ttask) # 1 -a = copy(ctask) +a = copy(ttask) @show consume(a) # 2 @show consume(a) # 3 -@show consume(ctask) # 2 -@show consume(ctask) # 3 +@show consume(ttask) # 2 +@show consume(ttask) # 3 ``` Heap allocated objects are shallow copied: @@ -45,17 +45,17 @@ function f() end end -ctask = CTask(f) +ttask = TapedTask(f) -@show consume(ctask) # 0 -@show consume(ctask) # 1 +@show consume(ttask) # 0 +@show consume(ttask) # 1 a = copy(t) @show consume(a) # 2 @show consume(a) # 3 -@show consume(ctask) # 4 -@show consume(ctask) # 5 +@show consume(ttask) # 4 +@show consume(ttask) # 5 ``` In constrast to standard arrays, which are only shallow copied during @@ -74,17 +74,17 @@ function f() end end -ctask = CTask(f) +ttask = TapedTask(f) -@show consume(ctask) # 0 -@show consume(ctask) # 1 +@show consume(ttask) # 0 +@show consume(ttask) # 1 -a = copy(ctask) +a = copy(ttask) @show consume(a) # 2 @show consume(a) # 3 -@show consume(ctask) # 2 -@show consume(ctask) # 3 +@show consume(ttask) # 2 +@show consume(ttask) # 3 ``` Note: The [Turing](https://github.com/TuringLang/Turing.jl) diff --git a/perf/p0.jl b/perf/p0.jl index c1f1ef00..14f629ab 100644 --- a/perf/p0.jl +++ b/perf/p0.jl @@ -1,5 +1,3 @@ -# ]add Turing#hg/new-libtask2 - using Libtask using Turing, DynamicPPL, AdvancedPS using BenchmarkTools @@ -26,8 +24,8 @@ args = m.evaluator[2:end]; @btime f(args...) # (2.0, VarInfo (2 variables (μ, σ), dimension 2; logp: -6.162)) -@show "CTask construction..." -t = @btime Libtask.CTask(f, args...) +@show "TapedTask construction..." +t = @btime TapedTask(f, args...) # schedule(t.task) # work fine! # @show Libtask.result(t.tf) @show "Run a tape..." @@ -39,8 +37,8 @@ m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo()); @show "Directly call..." @btime m.evaluator[1](m.evaluator[2:end]...) -@show "CTask construction..." -t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...); +@show "TapedTask construction..." +t = @btime TapedTask(m.evaluator[1], m.evaluator[2:end]...); # schedule(t.task) # @show Libtask.result(t.tf.tape) @show "Run a tape..." diff --git a/perf/p2.jl b/perf/p2.jl index a5a4357d..a95b3d3c 100644 --- a/perf/p2.jl +++ b/perf/p2.jl @@ -56,7 +56,7 @@ m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo()) f = m.evaluator[1] args = m.evaluator[2:end] -t = Libtask.CTask(f, args...) +t = TapedTask(f, args...) t.tf(args...) diff --git a/src/Libtask.jl b/src/Libtask.jl index 59795b98..75421d3c 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -5,17 +5,13 @@ using MacroTools using LRUCache -export CTask, consume, produce +export TapedTask, consume, produce export TArray, tzeros, tfill, TRef -export TapedTask - include("tapedfunction.jl") include("tapedtask.jl") include("tarray.jl") include("tref.jl") -CTask = TapedTask - end diff --git a/src/tapedtask.jl b/src/tapedtask.jl index 2c8dee4c..013542b6 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -3,16 +3,20 @@ struct TapedTaskException backtrace::Vector{Any} end -struct TapedTask +struct TapedTask{F} task::Task - tf::TapedFunction + tf::TapedFunction{F} produce_ch::Channel{Any} consume_ch::Channel{Int} produced_val::Vector{Any} function TapedTask( - t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int}) - new(t, tf, pch, cch, Any[]) + t::Task, + tf::TapedFunction{F}, + produce_ch::Channel{Any}, + consume_ch::Channel{Int} + ) where {F} + new{F}(t, tf, produce_ch, consume_ch, Any[]) end end @@ -148,8 +152,8 @@ function Base.iterate(t::TapedTask, state=nothing) nothing end end -Base.IteratorSize(::Type{TapedTask}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{TapedTask}) = Base.EltypeUnknown() +Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() # copy the task diff --git a/test/benchmarks.jl b/test/benchmarks.jl index f71b271a..f742ae9b 100644 --- a/test/benchmarks.jl +++ b/test/benchmarks.jl @@ -1,7 +1,6 @@ using BenchmarkTools using Libtask - macro rep(cnt, exp) blk =:(begin end) for _ in 1:eval(cnt) @@ -47,10 +46,10 @@ function f() end @btime begin - ctask = CTask(f) - consume(ctask) - consume(ctask) - a = copy(ctask) + ttask = TapedTask(f) + consume(ttask) + consume(ttask) + a = copy(ttask) consume(a) consume(a) end diff --git a/test/runtests.jl b/test/runtests.jl index 24c28f6e..8cb0fef1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,10 +2,10 @@ using Libtask using Test include("tf.jl") -include("ctask.jl") +include("tapedtask.jl") include("tarray.jl") include("tref.jl") -if get(ENV, "BENCHMARK", nothing) != nothing +if haskey(ENV, "BENCHMARK") include("benchmarks.jl") end diff --git a/test/ctask.jl b/test/tapedtask.jl similarity index 69% rename from test/ctask.jl rename to test/tapedtask.jl index cb85908c..041431ab 100644 --- a/test/ctask.jl +++ b/test/tapedtask.jl @@ -1,4 +1,4 @@ -@testset "ctask" begin +@testset "tapedtask" begin # Test case 1: stack allocated objects are deep copied. @testset "stack allocated objects" begin function f() @@ -9,14 +9,14 @@ end end - ctask = CTask(f) - @test consume(ctask) == 0 - @test consume(ctask) == 1 - a = copy(ctask) + ttask = TapedTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) @test consume(a) == 2 @test consume(a) == 3 - @test consume(ctask) == 2 - @test consume(ctask) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 @inferred Libtask.TapedFunction(f) end @@ -31,16 +31,16 @@ end end - ctask = CTask(f) - @test consume(ctask) == 0 - @test consume(ctask) == 1 - a = copy(ctask) + ttask = TapedTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) @test consume(a) == 2 @test consume(a) == 3 - @test consume(ctask) == 4 - @test consume(ctask) == 5 - @test consume(ctask) == 6 - @test consume(ctask) == 7 + @test consume(ttask) == 4 + @test consume(ttask) == 5 + @test consume(ttask) == 6 + @test consume(ttask) == 7 end @testset "iteration" begin @@ -52,20 +52,20 @@ end end - ctask = CTask(f) + ttask = TapedTask(f) - next = iterate(ctask) + next = iterate(ttask) @test next === (1, nothing) val, state = next - next = iterate(ctask, state) + next = iterate(ttask, state) @test next === (2, nothing) val, state = next - next = iterate(ctask, state) + next = iterate(ttask, state) @test next === (3, nothing) - a = collect(Iterators.take(ctask, 7)) + a = collect(Iterators.take(ttask, 7)) @test eltype(a) === Int @test a == 4:10 end @@ -82,14 +82,14 @@ end end - ctask = CTask(f) + ttask = TapedTask(f) try - consume(ctask) + consume(ttask) catch ex @test ex isa MethodError end if VERSION >= v"1.5" - @test ctask.task.exception isa MethodError + @test ttask.task.exception isa MethodError end end @@ -103,14 +103,14 @@ end end - ctask = CTask(f) + ttask = TapedTask(f) try - consume(ctask) + consume(ttask) catch ex @test ex isa ErrorException end if VERSION >= v"1.5" - @test ctask.task.exception isa ErrorException + @test ttask.task.exception isa ErrorException end end @@ -125,14 +125,14 @@ end end - ctask = CTask(f) + ttask = TapedTask(f) try - consume(ctask) + consume(ttask) catch ex @test ex isa BoundsError end if VERSION >= v"1.5" - @test ctask.task.exception isa BoundsError + @test ttask.task.exception isa BoundsError end end @@ -147,15 +147,15 @@ end end - ctask = CTask(f) - @test consume(ctask) == 2 + ttask = TapedTask(f) + @test consume(ttask) == 2 try - consume(ctask) + consume(ttask) catch ex @test ex isa BoundsError end if VERSION >= v"1.5" - @test ctask.task.exception isa BoundsError + @test ttask.task.exception isa BoundsError end end @@ -170,17 +170,17 @@ end end - ctask = CTask(f) - @test consume(ctask) == 2 - ctask2 = copy(ctask) + ttask = TapedTask(f) + @test consume(ttask) == 2 + ttask2 = copy(ttask) try - consume(ctask2) + consume(ttask2) catch ex @test ex isa BoundsError end - @test ctask.task.exception === nothing + @test ttask.task.exception === nothing if VERSION >= v"1.5" - @test ctask2.task.exception isa BoundsError + @test ttask2.task.exception isa BoundsError end end end diff --git a/test/tarray.jl b/test/tarray.jl index 7f3eca30..ca99f46d 100644 --- a/test/tarray.jl +++ b/test/tarray.jl @@ -130,15 +130,15 @@ end end - ctask = CTask(f) + ttask = TapedTask(f) - consume(ctask) - consume(ctask) - a = copy(ctask) + consume(ttask) + consume(ttask) + a = copy(ttask) consume(a) consume(a) - @test consume(ctask) == 2 + @test consume(ttask) == 2 @test consume(a) == 4 DATA = Dict{Task, Array}() @@ -151,18 +151,18 @@ end end - ctask = CTask(g) - @test consume(ctask) == hash(ctask.task) # index = 1 - @test consume(ctask) == hash(ctask.task) # index = 2 + ttask = TapedTask(g) + @test consume(ttask) == hash(ttask.task) # index = 1 + @test consume(ttask) == hash(ttask.task) # index = 2 - a = copy(ctask) + a = copy(ttask) @test consume(a) == hash(a.task) # index = 3 @test consume(a) == hash(a.task) # index = 4 - @test consume(ctask) == hash(ctask.task) # index = 3 + @test consume(ttask) == hash(ttask.task) # index = 3 - @test DATA[ctask.task] == [hash(ctask.task), hash(ctask.task), hash(ctask.task), 0] - @test DATA[a.task] == [hash(ctask.task), hash(ctask.task), hash(a.task), hash(a.task)] + @test DATA[ttask.task] == [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] + @test DATA[a.task] == [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] end @testset "Issue: PR-86 (DynamicPPL.jl/pull/261)" begin @@ -177,13 +177,13 @@ end - ctask = CTask(f) + ttask = TapedTask(f) ex = try for _ in 1:999 - consume(ctask) - consume(ctask) - a = copy(ctask) + consume(ttask) + consume(ttask) + a = copy(ttask) consume(a) consume(a) end diff --git a/test/tref.jl b/test/tref.jl index e080bef4..b4f8fb1a 100644 --- a/test/tref.jl +++ b/test/tref.jl @@ -11,7 +11,7 @@ end end - ctask = CTask(f) + ctask = TapedTask(f) consume(ctask) consume(ctask) @@ -36,7 +36,7 @@ end end - ctask = CTask(f) + ctask = TapedTask(f) consume(ctask) consume(ctask)