Skip to content

Commit

Permalink
Remove CTask
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer committed Feb 20, 2022
1 parent c81dab8 commit 472e0de
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.6.9"
version = "0.10.0"

[deps]
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Heap allocated objects are shallow copied:
Expand All @@ -45,17 +45,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(t)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 4
@show consume(ctask) # 5
@show consume(ttask) # 4
@show consume(ttask) # 5
```

In constrast to standard arrays, which are only shallow copied during
Expand All @@ -74,17 +74,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Note: The [Turing](https://github.com/TuringLang/Turing.jl)
Expand Down
10 changes: 4 additions & 6 deletions perf/p0.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ]add Turing#hg/new-libtask2

using Libtask
using Turing, DynamicPPL, AdvancedPS
using BenchmarkTools
Expand All @@ -26,8 +24,8 @@ args = m.evaluator[2:end];
@btime f(args...)
# (2.0, VarInfo (2 variables (μ, σ), dimension 2; logp: -6.162))

@show "CTask construction..."
t = @btime Libtask.CTask(f, args...)
@show "TapedTask construction..."
t = @btime TapedTask(f, args...)
# schedule(t.task) # work fine!
# @show Libtask.result(t.tf)
@show "Run a tape..."
Expand All @@ -39,8 +37,8 @@ m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo());
@show "Directly call..."
@btime m.evaluator[1](m.evaluator[2:end]...)

@show "CTask construction..."
t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...);
@show "TapedTask construction..."
t = @btime TapedTask(m.evaluator[1], m.evaluator[2:end]...);
# schedule(t.task)
# @show Libtask.result(t.tf.tape)
@show "Run a tape..."
Expand Down
2 changes: 1 addition & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo())
f = m.evaluator[1]
args = m.evaluator[2:end]

t = Libtask.CTask(f, args...)
t = TapedTask(f, args...)

t.tf(args...)

Expand Down
6 changes: 1 addition & 5 deletions src/Libtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ using MacroTools

using LRUCache

export CTask, consume, produce
export TapedTask, consume, produce
export TArray, tzeros, tfill, TRef

export TapedTask

include("tapedfunction.jl")
include("tapedtask.jl")

include("tarray.jl")
include("tref.jl")

CTask = TapedTask

end
16 changes: 10 additions & 6 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@ struct TapedTaskException
backtrace::Vector{Any}
end

struct TapedTask
struct TapedTask{F}
task::Task
tf::TapedFunction
tf::TapedFunction{F}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, pch, cch, Any[])
t::Task,
tf::TapedFunction{F},
produce_ch::Channel{Any},
consume_ch::Channel{Int}
) where {F}
new{F}(t, tf, produce_ch, consume_ch, Any[])
end
end

Expand Down Expand Up @@ -148,8 +152,8 @@ function Base.iterate(t::TapedTask, state=nothing)
nothing
end
end
Base.IteratorSize(::Type{TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TapedTask}) = Base.EltypeUnknown()
Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()


# copy the task
Expand Down
9 changes: 4 additions & 5 deletions test/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using BenchmarkTools
using Libtask


macro rep(cnt, exp)
blk =:(begin end)
for _ in 1:eval(cnt)
Expand Down Expand Up @@ -47,10 +46,10 @@ function f()
end

@btime begin
ctask = CTask(f)
consume(ctask)
consume(ctask)
a = copy(ctask)
ttask = TapedTask(f)
consume(ttask)
consume(ttask)
a = copy(ttask)
consume(a)
consume(a)
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ using Libtask
using Test

include("tf.jl")
include("ctask.jl")
include("tapedtask.jl")
include("tarray.jl")
include("tref.jl")

if get(ENV, "BENCHMARK", nothing) != nothing
if haskey(ENV, "BENCHMARK")
include("benchmarks.jl")
end
78 changes: 39 additions & 39 deletions test/ctask.jl → test/tapedtask.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "ctask" begin
@testset "tapedtask" begin
# Test case 1: stack allocated objects are deep copied.
@testset "stack allocated objects" begin
function f()
Expand All @@ -9,14 +9,14 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 2
@test consume(ctask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 3

@inferred Libtask.TapedFunction(f)
end
Expand All @@ -31,16 +31,16 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 4
@test consume(ctask) == 5
@test consume(ctask) == 6
@test consume(ctask) == 7
@test consume(ttask) == 4
@test consume(ttask) == 5
@test consume(ttask) == 6
@test consume(ttask) == 7
end

@testset "iteration" begin
Expand All @@ -52,20 +52,20 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)

next = iterate(ctask)
next = iterate(ttask)
@test next === (1, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (2, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (3, nothing)

a = collect(Iterators.take(ctask, 7))
a = collect(Iterators.take(ttask, 7))
@test eltype(a) === Int
@test a == 4:10
end
Expand All @@ -82,14 +82,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa MethodError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa MethodError
@test ttask.task.exception isa MethodError
end
end

Expand All @@ -103,14 +103,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa ErrorException
end
if VERSION >= v"1.5"
@test ctask.task.exception isa ErrorException
@test ttask.task.exception isa ErrorException
end
end

Expand All @@ -125,14 +125,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -147,15 +147,15 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ttask = TapedTask(f)
@test consume(ttask) == 2
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -170,17 +170,17 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ctask2 = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 2
ttask2 = copy(ttask)
try
consume(ctask2)
consume(ttask2)
catch ex
@test ex isa BoundsError
end
@test ctask.task.exception === nothing
@test ttask.task.exception === nothing
if VERSION >= v"1.5"
@test ctask2.task.exception isa BoundsError
@test ttask2.task.exception isa BoundsError
end
end
end
Expand Down
Loading

0 comments on commit 472e0de

Please sign in to comment.