Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove CTask #123

Merged
merged 2 commits into from
Feb 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.6.10"
version = "0.7.0"

[deps]
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Heap allocated objects are shallow copied:
Expand All @@ -45,17 +45,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(t)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 4
@show consume(ctask) # 5
@show consume(ttask) # 4
@show consume(ttask) # 5
```

In constrast to standard arrays, which are only shallow copied during
Expand All @@ -74,17 +74,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Note: The [Turing](https://github.com/TuringLang/Turing.jl)
Expand Down
10 changes: 4 additions & 6 deletions perf/p0.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ]add Turing#hg/new-libtask2

using Libtask
using Turing, DynamicPPL, AdvancedPS
using BenchmarkTools
Expand All @@ -26,8 +24,8 @@ args = m.evaluator[2:end];
@btime f(args...)
# (2.0, VarInfo (2 variables (μ, σ), dimension 2; logp: -6.162))

@show "CTask construction..."
t = @btime Libtask.CTask(f, args...)
@show "TapedTask construction..."
t = @btime TapedTask(f, args...)
# schedule(t.task) # work fine!
# @show Libtask.result(t.tf)
@show "Run a tape..."
Expand All @@ -39,8 +37,8 @@ m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo());
@show "Directly call..."
@btime m.evaluator[1](m.evaluator[2:end]...)

@show "CTask construction..."
t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...);
@show "TapedTask construction..."
t = @btime TapedTask(m.evaluator[1], m.evaluator[2:end]...);
# schedule(t.task)
# @show Libtask.result(t.tf.tape)
@show "Run a tape..."
Expand Down
2 changes: 1 addition & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo())
f = m.evaluator[1]
args = m.evaluator[2:end]

t = Libtask.CTask(f, args...)
t = TapedTask(f, args...)

t.tf(args...)

Expand Down
6 changes: 1 addition & 5 deletions src/Libtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ using MacroTools

using LRUCache

export CTask, consume, produce
export TapedTask, consume, produce
export TArray, tzeros, tfill, TRef

export TapedTask

include("tapedfunction.jl")
include("tapedtask.jl")

include("tarray.jl")
include("tref.jl")

const CTask = TapedTask

end
16 changes: 10 additions & 6 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@ struct TapedTaskException
backtrace::Vector{Any}
end

struct TapedTask
struct TapedTask{F}
task::Task
tf::TapedFunction
tf::TapedFunction{F}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, pch, cch, Any[])
t::Task,
tf::TapedFunction{F},
produce_ch::Channel{Any},
consume_ch::Channel{Int}
) where {F}
new{F}(t, tf, produce_ch, consume_ch, Any[])
end
end

Expand Down Expand Up @@ -148,8 +152,8 @@ function Base.iterate(t::TapedTask, state=nothing)
nothing
end
end
Base.IteratorSize(::Type{TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TapedTask}) = Base.EltypeUnknown()
Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()


# copy the task
Expand Down
9 changes: 4 additions & 5 deletions test/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using BenchmarkTools
using Libtask


macro rep(cnt, exp)
blk =:(begin end)
for _ in 1:eval(cnt)
Expand Down Expand Up @@ -47,10 +46,10 @@ function f()
end

@btime begin
ctask = CTask(f)
consume(ctask)
consume(ctask)
a = copy(ctask)
ttask = TapedTask(f)
consume(ttask)
consume(ttask)
a = copy(ttask)
consume(a)
consume(a)
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ using Libtask
using Test

include("tf.jl")
include("ctask.jl")
include("tapedtask.jl")
include("tarray.jl")
include("tref.jl")

if get(ENV, "BENCHMARK", nothing) != nothing
if haskey(ENV, "BENCHMARK")
include("benchmarks.jl")
end
78 changes: 39 additions & 39 deletions test/ctask.jl → test/tapedtask.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "ctask" begin
@testset "tapedtask" begin
# Test case 1: stack allocated objects are deep copied.
@testset "stack allocated objects" begin
function f()
Expand All @@ -9,14 +9,14 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 2
@test consume(ctask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 3

@inferred Libtask.TapedFunction(f)
end
Expand All @@ -31,16 +31,16 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 4
@test consume(ctask) == 5
@test consume(ctask) == 6
@test consume(ctask) == 7
@test consume(ttask) == 4
@test consume(ttask) == 5
@test consume(ttask) == 6
@test consume(ttask) == 7
end

@testset "iteration" begin
Expand All @@ -52,20 +52,20 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)

next = iterate(ctask)
next = iterate(ttask)
@test next === (1, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (2, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (3, nothing)

a = collect(Iterators.take(ctask, 7))
a = collect(Iterators.take(ttask, 7))
@test eltype(a) === Int
@test a == 4:10
end
Expand All @@ -82,14 +82,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa MethodError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa MethodError
@test ttask.task.exception isa MethodError
end
end

Expand All @@ -103,14 +103,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa ErrorException
end
if VERSION >= v"1.5"
@test ctask.task.exception isa ErrorException
@test ttask.task.exception isa ErrorException
end
end

Expand All @@ -125,14 +125,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -147,15 +147,15 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ttask = TapedTask(f)
@test consume(ttask) == 2
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -170,17 +170,17 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ctask2 = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 2
ttask2 = copy(ttask)
try
consume(ctask2)
consume(ttask2)
catch ex
@test ex isa BoundsError
end
@test ctask.task.exception === nothing
@test ttask.task.exception === nothing
if VERSION >= v"1.5"
@test ctask2.task.exception isa BoundsError
@test ttask2.task.exception isa BoundsError
end
end
end
Expand Down
Loading