Skip to content

Commit

Permalink
update minimal
Browse files Browse the repository at this point in the history
  • Loading branch information
khosravipasha committed Mar 8, 2022
1 parent 733f037 commit 0e28de4
Showing 1 changed file with 106 additions and 32 deletions.
138 changes: 106 additions & 32 deletions examples/minimal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,60 +5,134 @@ abstract type InputDist end
struct Indicator{T} <: InputDist
value::T
end
const Literal = Indicator{Bool}

struct Categorical <: InputDist
categories::UInt32
end
struct IndicatorBool <:InputDist end
struct IndicatorUInt32 <:InputDist end

struct Categorical <: InputDist end
struct Binomial <: InputDist end
######
abstract type BitsNode end
struct BitsSum <: BitsNode end
struct BitsInput{D <: InputDist} <: BitsNode
variable::Int
dist::D
end
dist(n::BitsInput) = n.dist
dist(n::BitsNode) = nothing
########
ans_idx(n) = 1
ans_idx(n::Indicator{Bool}) = 2
ans_idx(n::Indicator{UInt32}) = 3
ans_idx(n::Indicator{Int64}) = 4
ans_idx(n::Categorical) = 5
ans_idx(n::Binomial) = 6
ans_idx(n::IndicatorBool) = 7
ans_idx(n::IndicatorUInt32) = 8

########
function kernel(ans, vec, ids)
for idx = 1:2
for idx = 1:size(ids, 1)
cur_id = ids[idx]
# @cuprintln(cur_id)
item = vec[cur_id]
d = dist(item)::InputDist
var = item.variable

CUDA.@atomic ans[ans_idx(d)] += 1
item = vec[cur_id]::BitsInput
d = dist(item)
i = ans_idx(d)::Int64
ans[i] += 1
end
return nothing
end

ans_idx(n::InputDist) = nothing
ans_idx(n::Indicator{Bool}) = 1
ans_idx(n::Indicator{UInt32}) = 2
ans_idx(n::Categorical) = 3

########

function run()
vec = Vector{Union{BitsSum,
BitsInput{Indicator{Bool}},
BitsInput{Indicator{UInt32}},
BitsInput{Categorical}
}}()
input_node_ids = [1, 2]
push!(vec, BitsInput(1, Indicator{Bool}(false)))
push!(vec, BitsInput(2, Indicator{UInt32}(1)))
push!(vec, BitsInput(2, Categorical(10)))
function vec_A()
vec = Vector{Union{
BitsInput{Indicator{Bool}},
BitsInput{Indicator{UInt32}},
BitsInput{Categorical},
# BitsSum
}}()

push!(vec, BitsInput(Indicator{Bool}(false)))
push!(vec, BitsInput(Indicator{UInt32}(1)))
push!(vec, BitsInput(Categorical()))
# push!(vec, BitsSum())
return vec
end

function vec_B()
vec = Vector{Union{
BitsInput{Categorical},
BitsInput{Binomial},
BitsSum
}}()
push!(vec, BitsInput(Categorical()))
push!(vec, BitsInput(Binomial()))
push!(vec, BitsSum())
return vec
end

function vec_C()
vec = Vector{Union{
BitsInput{Indicator{Bool}},
BitsInput{Indicator{UInt32}},
BitsInput{Indicator{Int64}},
# BitsInput{Categorical},
# BitsInput{Binomial},
# BitsSum
}}()

push!(vec, BitsInput(Indicator{Bool}(false)))
push!(vec, BitsInput(Indicator{UInt32}(UInt32(1))))
# push!(vec, BitsInput(Indicator{Int64}(1)))
# push!(vec, BitsSum())
return vec
end

function vec_D()
vec = Vector{Union{
BitsInput{Indicator{UInt32}},
BitsInput{Categorical},
BitsInput{Binomial},
BitsSum
}}()

push!(vec, BitsInput(Categorical()))
push!(vec, BitsInput(Binomial()))
push!(vec, BitsSum())
return vec
end

cuvec = cu(vec)
ids = cu(input_node_ids)
ans = CUDA.zeros(Int64, 3)
function vec_E()
vec = Vector{Union{
BitsInput{Indicator{UInt32}},
BitsInput{Categorical},
BitsInput{Binomial},
BitsSum
}}()

@device_code_warntype @cuda kernel(ans, cuvec, ids)
push!(vec, BitsInput(Categorical()))
push!(vec, BitsInput(Binomial()))
push!(vec, BitsSum())
return vec
end

function run(get_vec)
vec = get_vec()
input_node_ids = Vector{UInt32}()
for i = 1:length(vec)
if vec[i] isa BitsInput
append!(input_node_ids, UInt32(i))
end
end
ans = CUDA.zeros(Int64, 10)
@cuda kernel(ans, cu(vec), cu(input_node_ids))
println(ans)
end

run()


run(vec_A) # Runs fine
run(vec_B) # Runs fine
run(vec_C) # Runs fine
# run(vec_D) # Fails
run(vec_E) # Fails

0 comments on commit 0e28de4

Please sign in to comment.