Skip to content

Commit

Permalink
update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
khosravipasha committed Mar 16, 2022
1 parent 4416c72 commit c50f1a2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/binomial_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function truncate(data::Matrix; bits)
end

function run(; batch_size = 512, num_epochs1 = 1, num_epochs2 = 1, num_epochs3 = 20,
pseudocount = 0.1, latents = 32, param_inertia1 = 0.2, param_inertia2 = 0.9, param_inertia3 = 0.95)
pseudocount = 0.01, latents = 32, param_inertia1 = 0.2, param_inertia2 = 0.9, param_inertia3 = 0.95)
train, test = mnist_cpu()
train_gpu, test_gpu = mnist_gpu()
# train_gpu = train_gpu[1:1024, :]
Expand Down
63 changes: 39 additions & 24 deletions examples/cat_rat_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MLDatasets
using CUDA
using Images

device!(collect(devices())[2])
# device!(collect(devices())[2])

function mnist_cpu()
train_int = transpose(reshape(MNIST.traintensor(UInt8), 28*28, :));
Expand Down Expand Up @@ -39,45 +39,48 @@ function generate_rat(train)
RAT(num_features; num_nodes_region, num_nodes_leaf, rg_depth, rg_replicas, input_type, balance_childs_parents)
end

function run()
function run(; batch_size = 256, num_epochs1 = 1, num_epochs2 = 1, num_epochs3 = 20,
pseudocount = 0.01, param_inertia1 = 0.2, param_inertia2 = 0.9, param_inertia3 = 0.9)

train, test = mnist_cpu();
train_gpu, test_gpu = mnist_gpu();
trunc_train = truncate(train; bits = 5);

# println("Generating HCLT structure with $latents latents... ");
# @time pc = hclt(trunc_train[1:5000,:], latents; num_cats = 256, pseudocount = 0.1, input_type = CategoricalDist);
# init_parameters(pc; perturbation = 0.4);
print("Generating RAT SPN....")
@info "Generating RAT SPN...."
@time pc = generate_rat(trunc_train);
init_parameters(pc; perturbation = 0.4);

println("Number of free parameters: $(num_parameters(pc))")

print("Moving circuit to GPU... ")
@info "Moving circuit to GPU... "
CUDA.@time bpc = CuBitsProbCircuit(BitsProbCircuit(pc));

batch_size = 2048
pseudocount = 0.01
@show length(bpc.nodes)

@info "EM"
softness = 0
epochs_1 = 5
epochs_2 = 5
epochs_3 = 10
@time mini_batch_em(bpc, train_gpu, epochs_1; batch_size, pseudocount,
softness, param_inertia = 0.2, param_inertia_end = 0.9)
@time mini_batch_em(bpc, train_gpu, num_epochs1; batch_size, pseudocount,
softness, param_inertia = param_inertia1, param_inertia_end = param_inertia2)

@time mini_batch_em(bpc, train_gpu, epochs_2; batch_size, pseudocount,
softness, param_inertia = 0.9, param_inertia_end = 0.95)
@time mini_batch_em(bpc, train_gpu, num_epochs2; batch_size, pseudocount,
softness, param_inertia = param_inertia2, param_inertia_end = param_inertia3)

@time full_batch_em(bpc, train_gpu, epochs_3; batch_size, pseudocount, softness)
for iter=1:num_epochs3
@info "Iter $iter"
@time full_batch_em(bpc, train_gpu, 5; batch_size, pseudocount, softness)

ll3 = loglikelihood(bpc, test_gpu; batch_size)
println("test LL: $(ll3)")

@time do_sample(bpc, iter)
end

print("update parameters")
@time ProbabilisticCircuits.update_parameters(bpc);
print("Save to file")
@time write("rat_cat.jpc.gz", pc);
return circuit, bpc
return pc, bpc
end

function do_sample(bpc)
function do_sample(bpc, iter=999)
CUDA.@time sms = sample(bpc, 100, 28*28, [UInt32]);

do_img(i) = begin
Expand All @@ -88,8 +91,20 @@ function do_sample(bpc)

arr = [do_img(i) for i=1:size(sms, 1)]
imgs = mosaicview(arr, fillvalue=1, ncol=10, npad=4)
save("samples.png", imgs)
save("samples/rat_samples_$(iter).png", imgs)
end

function try_map(pc, bpc)
@info "MAP"
train_gpu, _ = mnist_gpu();
data = Array{Union{Missing, UInt32}}(train_gpu[1:10, :]);
data[:, 1:400] .= missing;
data_gpu = cu(data);

# @time MAP(pc, data; batch_size=10)
MAP(bpc, data_gpu; batch_size=10)
end

# circuit, bpc = run();
#do_sample(bpc)
pc, bpc = run(; batch_size = 128, num_epochs1 = 2, num_epochs2 = 2, num_epochs3 = 2);
# do_sample(bpc)
# try_map(pc, bpc)

0 comments on commit c50f1a2

Please sign in to comment.