Skip to content

Commit

Permalink
api support for learning from missing data
Browse files Browse the repository at this point in the history
  • Loading branch information
khosravipasha committed Mar 3, 2021
1 parent eeadf82 commit 5b227c2
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 24 deletions.
3 changes: 2 additions & 1 deletion src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ end
function estimate_parameters_em(pbc::ParamBitCircuit, data; pseudocount::Float64, entropy_reg::Float64 = 0.0,
use_sample_weights::Bool = true, use_gpu::Bool = false,
reuse_v = nothing, reuse_f = nothing, reuse_counts = nothing,
exp_update_factor = 0.0)
exp_update_factor = 0.0
)
if isweighted(data)
# `data' is weighted according to its `weight' column
data, weights = split_sample_weights(data)
Expand Down
23 changes: 16 additions & 7 deletions src/queries/marginal_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ const MAR = marginal
Compute the marginal likelihood of the PC given the data
"""
marginal_log_likelihood(pc, data) = begin
marginal_log_likelihood(pc, data; use_gpu::Bool = false) = begin
if use_gpu
data = to_gpu(data)
end

if isweighted(data)
# `data' is weighted according to its `weight' column
data, weights = split_sample_weights(data)
Expand All @@ -102,7 +106,9 @@ marginal_log_likelihood(pc, data) = begin
sum(marginal(pc, data))
end
end
# Convenience overload for weights delivered as a one-column DataFrame
# (the shape produced by `split_sample_weights`): unwrap the first column
# and delegate to the vector-weights method.
marginal_log_likelihood(pc, data, weights::DataFrame; use_gpu::Bool = false) =
    marginal_log_likelihood(pc, data, weights[:, 1]; use_gpu = use_gpu)

marginal_log_likelihood(pc, data, weights::AbstractArray) = begin
if isgpu(weights)
weights = to_cpu(weights)
Expand Down Expand Up @@ -160,23 +166,26 @@ end
"""
Compute the marginal likelihood of the PC given the data, averaged over all instances in the data
"""
marginal_log_likelihood_avg(pc, data; use_gpu::Bool = false) = begin
    # Weighted datasets carry their per-sample weights in a `weight` column;
    # peel that column off and use the weight-aware average instead.
    if isweighted(data)
        unweighted, sample_weights = split_sample_weights(data)
        marginal_log_likelihood_avg(pc, unweighted, sample_weights; use_gpu)
    else
        # Unweighted: plain arithmetic mean over all examples.
        marginal_log_likelihood(pc, data; use_gpu) / num_examples(data)
    end
end
# One-column-DataFrame weights (as produced by `split_sample_weights`):
# extract the column and reuse the vector-weights average.
marginal_log_likelihood_avg(pc, data, weights::DataFrame; use_gpu::Bool = false) =
    marginal_log_likelihood_avg(pc, data, weights[:, 1]; use_gpu)

"""
Compute the weighted average marginal log-likelihood of `pc` over `data`,
normalizing by the total weight `sum(weights)`.

Accepts `use_gpu` so that the weighted dispatch path (which forwards
`use_gpu = use_gpu`) does not raise a MethodError for an unsupported
keyword argument; without this keyword, any weighted call requesting GPU
evaluation crashes.
"""
marginal_log_likelihood_avg(pc, data, weights; use_gpu::Bool = false) = begin
    # GPU-resident weight vectors are brought back to the host so the
    # normalizing `sum` below runs on the CPU.
    if isgpu(weights)
        weights = to_cpu(weights)
    end
    # Mirror the two-argument method's GPU handling: move the data over
    # before evaluating. NOTE(review): assumes the three-argument
    # `marginal_log_likelihood` accepts GPU-resident data — confirm.
    if use_gpu
        data = to_gpu(data)
    end
    marginal_log_likelihood(pc, data, weights) / sum(weights)
end

marginal_log_likelihood_avg(pc, data::Array{DataFrame}; use_gpu::Bool = false) = begin
total_ll = marginal_log_likelihood(pc, data; use_gpu = use_gpu)

Expand Down
19 changes: 17 additions & 2 deletions src/structurelearner/heuristics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,31 @@ function vRand(vars::Vector{Var})
return Var(rand(vars))
end

function heuristic_loss(circuit::LogicCircuit, train_x; pick_edge="eFlow", pick_var="vMI")
function heuristic_loss(circuit::LogicCircuit, train_x;
pick_edge = "eFlow",
pick_var = "vMI",
miss_data::Bool = false
)
if isweighted(train_x)
train_x, weights = split_sample_weights(train_x)
else
weights = nothing
end


@assert !(miss_data && pick_var=="vMI") "Cannot use vMI for picking vars for missing data. Use vRand instead."

candidates, variable_scope = split_candidates(circuit)
if isempty(candidates) return nothing end
values, flows = satisfies_flows(circuit, train_x; weights = nothing) # Do not use samples weights here

if miss_data
# Have to use marginal flows when have missing data
values, flows = marginal_flows(circuit, train_x; weights = nothing) # Do not use samples weights here
else
# Satisfies Flows much faster than marginal flows
values, flows = satisfies_flows(circuit, train_x; weights = nothing) # Do not use samples weights here
end

if pick_edge == "eFlow"
edge, flow = eFlow(values, flows, candidates)
elseif pick_edge == "eRand"
Expand Down
90 changes: 76 additions & 14 deletions src/structurelearner/learner.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
export learn_circuit
export learn_circuit,
learn_circuit_miss

using LogicCircuits: split_step, struct_learn
using Statistics: mean
using Random

"""
Learn structure of a single structured decomposable circuit
"""
Expand All @@ -13,14 +16,49 @@ function learn_circuit(train_x;
maxiter=100,
seed=nothing,
return_vtree=false,
verbose=true)
verbose=true,
max_circuit_nodes=nothing)

# Initial Structure
pc, vtree = learn_chow_liu_tree_circuit(train_x)

learn_circuit(train_x, pc, vtree; pick_edge, pick_var, depth, pseudocount, sanity_check,
maxiter, seed, return_vtree, entropy_reg, verbose)
learn_circuit(train_x, pc, vtree; pick_edge, pick_var,
depth, pseudocount, sanity_check,
maxiter, seed, return_vtree, entropy_reg, verbose, max_circuit_nodes)
end


"""
Learn structure of a single structured decomposable circuit from missing data
"""
function learn_circuit_miss(train_x;
impute_method::Symbol=:median,
pick_edge="eFlow", depth=1,
pseudocount=1.0,
entropy_reg=0.0,
sanity_check=true,
maxiter=100,
seed=nothing,
return_vtree=false,
verbose=true,
max_circuit_nodes=nothing)

# Initial Structure
train_x_impute = impute(train_x; method=impute_method)
pc, vtree = learn_chow_liu_tree_circuit(train_x_impute)

# Only vRand supported for missing data
pick_var="vRand"

learn_circuit(train_x, pc, vtree; pick_edge, pick_var,
depth, pseudocount, sanity_check,
maxiter, seed, return_vtree, entropy_reg,
max_circuit_nodes, verbose,
has_missing=true)
end



function learn_circuit(train_x, pc, vtree;
pick_edge="eFlow", pick_var="vMI", depth=1,
pseudocount=1.0,
Expand All @@ -32,40 +70,65 @@ function learn_circuit(train_x, pc, vtree;
splitting_data=nothing,
use_gpu=false,
entropy_reg=0.0,
verbose=true)
verbose=true,
max_circuit_nodes=nothing,
has_missing::Bool=false
)

if seed !== nothing
Random.seed!(seed)
end

if has_missing
estimate_parameters_func = estimate_parameters_em
likelihood_avg_func = marginal_log_likelihood_avg
else
estimate_parameters_func = estimate_parameters
likelihood_avg_func = log_likelihood_avg
end



# structure_update
loss(circuit) = heuristic_loss(circuit, splitting_data == nothing ? train_x : splitting_data;
pick_edge=pick_edge, pick_var=pick_var)
loss(circuit) = heuristic_loss(circuit,
splitting_data == nothing ? train_x : splitting_data;
pick_edge=pick_edge,
pick_var=pick_var,
miss_data=has_missing)


pc_split_step(circuit) = begin
r = split_step(circuit; loss=loss, depth=depth, sanity_check=sanity_check)
if isnothing(r) return nothing end
c, = r
if batch_size > 0
estimate_parameters(c, batch(train_x, batch_size); pseudocount, use_gpu, entropy_reg)
estimate_parameters_func(c, batch(train_x, batch_size); pseudocount, use_gpu, entropy_reg)
else
estimate_parameters(c, train_x; pseudocount, use_gpu, entropy_reg)
estimate_parameters_func(c, train_x; pseudocount, use_gpu, entropy_reg)
end
return c, missing
end
iter = 0
log_per_iter(circuit) = begin
# ll = EVI(circuit, train_x);
if batch_size > 0
ll = log_likelihood_avg(circuit, batch(train_x, batch_size); use_gpu)
ll = likelihood_avg_func(circuit, batch(train_x, batch_size); use_gpu)
else
ll = log_likelihood_avg(circuit, train_x; use_gpu)
ll = likelihood_avg_func(circuit, train_x; use_gpu)
end
verbose && println("Iteration $iter/$maxiter. LogLikelihood = $(ll); nodes = $(num_nodes(circuit)); edges = $(num_edges(circuit)); params = $(num_parameters(circuit))")
iter += 1

if !isnothing(max_circuit_nodes) && num_nodes(circuit) > max_circuit_nodes
epoch_printer("Stopping early, circuit node count ($(num_nodes(circuit))) is above max threshold $(max_circuit_nodes).");
return true; # stop
end

false
end

# Log Before Learning
log_per_iter(pc)

pc = struct_learn(pc;
primitives=[pc_split_step], kwargs=Dict(pc_split_step=>()),
maxiter=maxiter, stop=log_per_iter, verbose=verbose)
Expand All @@ -75,5 +138,4 @@ function learn_circuit(train_x, pc, vtree;
else
pc
end
end

end

0 comments on commit 5b227c2

Please sign in to comment.