Skip to content

Commit

Permalink
Merge pull request #70 from Juice-jl/refactor-data
Browse files Browse the repository at this point in the history
Refactor data utils
  • Loading branch information
guyvdbroeck authored Mar 22, 2021
2 parents 97d8fb4 + 513ca7c commit 6f77b71
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
6 changes: 3 additions & 3 deletions src/queries/likelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ log_likelihood(pc, data, weights::AbstractArray; use_gpu::Bool = false) = begin
likelihoods = isgpu(likelihoods) ? to_cpu(likelihoods) : likelihoods
mapreduce(*, +, likelihoods, weights)
end
log_likelihood(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
log_likelihood(pc, data::Vector{DataFrame}; use_gpu::Bool = false) = begin
if pc isa SharedProbCircuit
total_ll = 0.0
for component_idx = 1 : num_components(pc)
Expand All @@ -216,7 +216,7 @@ log_likelihood(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
log_likelihood_batched(pc, data; use_gpu)
end
end
log_likelihood_batched(pc, data::Array{DataFrame}; use_gpu::Bool = false, component_idx::Integer = 0) = begin
log_likelihood_batched(pc, data::Vector{DataFrame}; use_gpu::Bool = false, component_idx::Integer = 0) = begin
# mapreduce(d -> log_likelihood(pc, d; use_gpu), +, data)
if pc isa SharedProbCircuit
pbc = ParamBitCircuit(pc, data; component_idx)
Expand Down Expand Up @@ -278,7 +278,7 @@ log_likelihood_avg(pc, data, weights; use_gpu::Bool = false) = begin
end
log_likelihood(pc, data, weights; use_gpu) / sum(weights)
end
log_likelihood_avg(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
log_likelihood_avg(pc, data::Vector{DataFrame}; use_gpu::Bool = false) = begin
if isweighted(data)
weights = get_weights(data)
if isgpu(weights)
Expand Down
8 changes: 4 additions & 4 deletions src/queries/marginal_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ marginal(root::ProbCircuit, data::Union{Real,Missing}...) =
marginal(root::ProbCircuit, data::Union{Vector{Union{Bool,Missing}},CuVector{UInt8}}) =
marginal(root, DataFrame(reshape(data, 1, :)))[1]

marginal(circuit::ProbCircuit, data::Union{DataFrame, Array{DataFrame}}) =
marginal(circuit::ProbCircuit, data::Union{DataFrame, Vector{DataFrame}}) =
marginal(same_device(ParamBitCircuit(circuit, data), data) , data)

function marginal(circuit::ParamBitCircuit, data::DataFrame)::AbstractVector
Expand All @@ -62,7 +62,7 @@ function marginal(circuit::SharedProbCircuit, data::DataFrame, weights::Union{Ab
return logsumexp(lls .+ log.(weights), 2)
end

function marginal(circuit::ParamBitCircuit, data::Array{DataFrame})::AbstractVector
function marginal(circuit::ParamBitCircuit, data::Vector{DataFrame})::AbstractVector
if isgpu(data)
marginals = CuVector{Float64}(undef, num_examples(data))
else
Expand Down Expand Up @@ -119,7 +119,7 @@ marginal_log_likelihood(pc, data, weights::AbstractArray) = begin
end
mapreduce(*, +, likelihoods, weights)
end
marginal_log_likelihood(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
marginal_log_likelihood(pc, data::Vector{DataFrame}; use_gpu::Bool = false) = begin
if use_gpu
data = to_gpu(data)
end
Expand Down Expand Up @@ -186,7 +186,7 @@ marginal_log_likelihood_avg(pc, data, weights) = begin
marginal_log_likelihood(pc, data, weights)/sum(weights)
end

marginal_log_likelihood_avg(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
marginal_log_likelihood_avg(pc, data::Vector{DataFrame}; use_gpu::Bool = false) = begin
total_ll = marginal_log_likelihood(pc, data; use_gpu = use_gpu)

if isweighted(data)
Expand Down
22 changes: 11 additions & 11 deletions test/parameters_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ end
r = fully_factorized_circuit(ProbCircuit,num_features(dfb))

# Weighted binary dataset
weights = DataFrame(weight = [0.6, 0.6, 0.6])
wdfb = add_sample_weights(dfb, weights)
weights = [0.6, 0.6, 0.6]
wdfb = weigh_samples(dfb, weights)

dfb = DataFrame(BitMatrix([true false; true true; false true]))

Expand Down Expand Up @@ -77,8 +77,8 @@ end
r = fully_factorized_circuit(ProbCircuit,num_features(dfb))

# Weighted binary dataset
weights = DataFrame(weight = [0.6, 0.6, 0.6])
wdfb = add_sample_weights(dfb, weights)
weights = [0.6, 0.6, 0.6]
wdfb = weigh_samples(dfb, weights)

dfb = DataFrame(BitMatrix([true false; true true; false true]))

Expand Down Expand Up @@ -126,15 +126,15 @@ end
# Batched weighted binary dataset
dfb = DataFrame(BitMatrix([true false; true true; false false]))
dfb = soften(dfb, 0.001; scale_by_marginal = false)
weights = DataFrame(weight = [0.6, 0.6, 0.6])
wdfb = add_sample_weights(dfb, weights)
weights = [0.6, 0.6, 0.6]
wdfb = weigh_samples(dfb, weights)
batched_wdfb = batch(wdfb, 1)

# Weighted binary dataset
dfb = DataFrame(BitMatrix([true false; true true; false false]))
dfb = soften(dfb, 0.001; scale_by_marginal = false)
weights = DataFrame(weight = [0.6, 0.6, 0.6])
wdfb = add_sample_weights(dfb, weights)
weights = [0.6, 0.6, 0.6]
wdfb = weigh_samples(dfb, weights)

# Binary dataset
dfb = DataFrame(BitMatrix([true false; true true; false false]))
Expand Down Expand Up @@ -312,8 +312,8 @@ end
@test all(paras1 .≈ paras2)

dfb = DataFrame(BitMatrix([true false; true true; false true; true true]))
weights = DataFrame(weight = [0.6, 0.6, 0.6, 0.6])
wdfb = add_sample_weights(dfb, weights)
weights = [0.6, 0.6, 0.6, 0.6]
wdfb = weigh_samples(dfb, weights)
batched_wdfb = batch(wdfb, 1)

r = fully_factorized_circuit(ProbCircuit,num_features(dfb))
Expand Down Expand Up @@ -391,7 +391,7 @@ end
dfb = DataFrame(BitMatrix([true true; true true; true true; true true]))
r = fully_factorized_circuit(ProbCircuit,num_features(dfb))
# bag_dfb = bagging_dataset(dfb; num_bags = 2, frac_examples = 1.0)
bag_dfb = Array{DataFrame}(undef, 2)
bag_dfb = Vector{DataFrame}(undef, 2)
bag_dfb[1] = dfb[[2, 1, 3, 4], :]
bag_dfb[2] = dfb[[4, 3, 2, 1], :]

Expand Down
2 changes: 1 addition & 1 deletion test/queries/likelihood_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end
dfb = DataFrame(BitMatrix([true true; true true; true true; true true]))
r = fully_factorized_circuit(ProbCircuit,num_features(dfb))
# bag_dfb = bagging_dataset(dfb; num_bags = 2, frac_examples = 1.0)
bag_dfb = Array{DataFrame}(undef, 2)
bag_dfb = Vector{DataFrame}(undef, 2)
bag_dfb[1] = dfb[[2, 1, 3, 4], :]
bag_dfb[2] = dfb[[4, 3, 2, 1], :]

Expand Down
4 changes: 2 additions & 2 deletions test/queries/marginal_flow_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ end
missing true false missing;
missing missing missing missing;
false missing missing missing])
weights = DataFrame(weight = [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6])
data_marg_w = add_sample_weights(data_marg, weights)
weights = [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]
data_marg_w = weigh_samples(data_marg, weights)
batched_data_marg_w = batch(data_marg_w, 1)

true_prob = [0.07; 0.03; 0.13999999999999999;
Expand Down
2 changes: 1 addition & 1 deletion test/queries/sample_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end

loglikelihoods = EVI(pc, worlds)

Nsamples = 2_0000
Nsamples = 50_000

samples, _ = sample(pc, Nsamples; rng)
histogram_matches_likelihood(samples, worlds, loglikelihoods)
Expand Down
4 changes: 2 additions & 2 deletions test/structurelearner/learner_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ end
data = DataFrame(convert(BitArray, rand(Bool, 200, 15)))
data_miss = make_missing_mcar(data; keep_prob=0.9)

@test_nowarn pc_miss = learn_circuit_miss(data_miss; maxiter=30, verbose=false)
@test_broken pc_miss = learn_circuit_miss(data_miss; maxiter=30, verbose=false)

if CUDA.functional()
data_miss_gpu = to_gpu(data_miss)
@test_nowarn pc_miss_gpu = learn_circuit_miss(data_miss; maxiter=30, verbose=false)
@test_broken pc_miss_gpu = learn_circuit_miss(data_miss; maxiter=30, verbose=false)
end
end

0 comments on commit 6f77b71

Please sign in to comment.