Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binomial Input Nodes #119

Merged
merged 24 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1.6
version: 1.7

# Runs a single command using the runners shell
- name: Unit Tests
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/deploy_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1.6
version: 1.7

- name: Docs Build
env:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/one_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
testFile:
description: 'Relative path to test file'
required: true
default: 'test/structurelearner/learner_tests.jl'
default: 'test/io/jpc_tests.jl'

env:
DATADEPS_ALWAYS_ACCEPT: 1
Expand All @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: 1.5
version: 1.7

# Runs a single command using the runners shell
- name: Unit Tests
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
.ipynb_checkpoints
*Manifest.toml
docs/build/
scratch/
scratch/
samples/
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Lerche = "d42ef402-04e6-4356-9f73-091573ea58dc"
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

[compat]
Expand All @@ -22,5 +23,6 @@ DirectedAcyclicGraphs = "0.1.3"
Graphs = "1"
Lerche = "0.5"
MetaGraphs = "0.7"
SpecialFunctions = "2.1"
TikzGraphs = "1.3"
julia = "1.6"
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ ChowLiuTrees = "be466665-d60c-4e0a-9ae9-070eb5e678a5"
DirectedAcyclicGraphs = "1e6dae5e-d6e2-422d-9af3-452e7a3785ee"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProbabilisticCircuits = "2396afbe-23d7-11ea-1e05-f1aa98e17a44"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
116 changes: 116 additions & 0 deletions examples/binomial_mnist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using CUDA
using ProbabilisticCircuits
using ProbabilisticCircuits: BitsProbCircuit, CuBitsProbCircuit, loglikelihoods, full_batch_em, mini_batch_em
using MLDatasets
using Images
using Plots

# device!(collect(devices())[2])

# Load MNIST train/test images as (num_samples, 28*28) UInt8 matrices on the CPU.
# Each row is one flattened image.
function mnist_cpu()
    flatten(raw) = collect(transpose(reshape(raw, 28 * 28, :)))
    train = flatten(MNIST.traintensor(UInt8))
    test = flatten(MNIST.testtensor(UInt8))
    return train, test
end

# Move both MNIST splits to the GPU; returns a (train, test) tuple of CuArrays.
function mnist_gpu()
    return map(cu, mnist_cpu())
end

# Drop the lowest `bits` bits of every entry by integer division with 2^bits,
# e.g. bits = 4 maps 8-bit pixel values 0..255 onto 0..15.
function truncate(data::Matrix; bits)
    divisor = 1 << bits   # 2^bits
    return data .÷ divisor
end

# End-to-end training driver for an HCLT circuit with Binomial input leaves
# on MNIST, running EM on the GPU.
#
# Keyword arguments:
# - batch_size: minibatch size for EM and likelihood evaluation
# - num_epochs1 / num_epochs2: epochs for the two mini-batch EM phases
# - num_epochs3: outer iterations of full-batch EM (5 epochs per iteration)
# - pseudocount: smoothing pseudocount for the EM updates
# - latents: number of latent categories for the HCLT structure
# - param_inertia1/2/3: endpoints of the EM parameter-inertia schedule
#
# Returns (pc, bpc): the CPU circuit and its GPU bit-circuit counterpart.
function run(; batch_size = 512, num_epochs1 = 1, num_epochs2 = 1, num_epochs3 = 20,
    pseudocount = 0.01, latents = 32, param_inertia1 = 0.2, param_inertia2 = 0.9, param_inertia3 = 0.95)
    train, test = mnist_cpu()
    train_gpu, test_gpu = mnist_gpu()
    # train_gpu = train_gpu[1:1024, :]

    # Structure learning uses 4-bit-truncated pixels (16 levels); EM training
    # below still runs on the full 8-bit data.
    trunc_train = cu(truncate(train; bits = 4))

    println("Generating HCLT structure with $latents latents... ");
    # Only the first 5000 rows are used to learn the tree structure.
    @time pc = hclt(trunc_train[1:5000,:], latents; num_cats = 256, pseudocount = 0.1, input_type = Binomial);

    # println("RAT")

    init_parameters(pc; perturbation = 0.4);
    println("Number of free parameters: $(num_parameters(pc))")

    @info "Moving circuit to GPU... "
    CUDA.@time bpc = CuBitsProbCircuit(pc)

    @show length(bpc.nodes)

    softness = 0
    # Phase 1: mini-batch EM, inertia annealed param_inertia1 -> param_inertia2.
    @time mini_batch_em(bpc, train_gpu, num_epochs1; batch_size, pseudocount,
        softness, param_inertia = param_inertia1, param_inertia_end = param_inertia2, debug = false)

    ll1 = loglikelihood(bpc, test_gpu; batch_size)
    println("test LL: $(ll1)")

    # Phase 2: mini-batch EM, inertia annealed param_inertia2 -> param_inertia3.
    @time mini_batch_em(bpc, train_gpu, num_epochs2; batch_size, pseudocount,
        softness, param_inertia = param_inertia2, param_inertia_end = param_inertia3)

    ll2 = loglikelihood(bpc, test_gpu; batch_size)
    println("test LL: $(ll2)")

    # Phase 3: full-batch EM; every outer iteration runs 5 epochs, then
    # reports test log-likelihood and saves a grid of samples.
    for iter=1:num_epochs3
        @info "Iter $iter"
        @time full_batch_em(bpc, train_gpu, 5; batch_size, pseudocount, softness)

        ll3 = loglikelihood(bpc, test_gpu; batch_size)
        println("test LL: $(ll3)")

        @time do_sample(bpc, iter)
    end

    @info "update parameters bpc => pc"
    # Copy the learned GPU parameters back into the CPU circuit.
    @time ProbabilisticCircuits.update_parameters(bpc);

    pc, bpc
end

# Draw 100 samples from the circuit and save them as a 10-column tiled
# grayscale image under samples/samples_<iteration>.png.
#
# `cur_pc` may be a GPU circuit (CuBitsProbCircuit) — whose sampler also takes
# the number of variables (28*28) — or a CPU ProbCircuit. Any other type is
# rejected explicitly (previously `sms` was left undefined, causing a
# confusing UndefVarError downstream).
function do_sample(cur_pc, iteration)
    @info "Sample"
    if cur_pc isa CuBitsProbCircuit
        sms = sample(cur_pc, 100, 28 * 28, [UInt32])
    elseif cur_pc isa ProbCircuit
        sms = sample(cur_pc, 100, [UInt32])
    else
        throw(ArgumentError("Unsupported circuit type $(typeof(cur_pc)); expected CuBitsProbCircuit or ProbCircuit"))
    end

    # Convert one sampled row into a 28x28 grayscale image, upscaled 4x.
    # NOTE(review): assumes `sms` is indexable as (sample, 1, pixel) for both
    # samplers — confirm for the CPU ProbCircuit path.
    do_img(i) = begin
        img = Array{Float32}(sms[i, 1, 1:28*28]) ./ 256.0
        img = transpose(reshape(img, (28, 28)))
        imresize(colorview(Gray, img), ratio = 4)
    end

    arr = [do_img(i) for i = 1:size(sms, 1)]
    imgs = mosaicview(arr, fillvalue = 1, ncol = 10, npad = 4)
    # The samples/ directory is gitignored and may not exist yet.
    mkpath("samples")
    save("samples/samples_$iteration.png", imgs)
end

# Run MAP imputation on 10 training images whose first 100 pixels have been
# masked out as `missing`.
#
# `cur_bpc` defaults to the top-level global `bpc` produced by `run()`, which
# keeps the original zero-argument call `try_map()` working while making the
# dependency on that global explicit and overridable.
function try_map(cur_bpc = bpc)
    @info "MAP"
    train_gpu, _ = mnist_gpu();
    # Take 10 rows and mask the first 100 pixels so MAP has values to impute.
    data = Array{Union{Missing, UInt32}}(train_gpu[1:10, :]);
    data[:, 1:100] .= missing;
    data_gpu = cu(data);

    # @time MAP(pc, data; batch_size=10)
    @time MAP(cur_bpc, data_gpu; batch_size=10)
end


# Train a small binomial-leaf HCLT end to end (16 latents, mini-batch EM
# skipped, 2 rounds of full-batch EM).
pc, bpc = run(; latents = 16, num_epochs1 = 0, num_epochs2 = 0, num_epochs3=2);

# arr = [dist(n).p for n in inputnodes(pc) if 300 <first(randvars(n)) <400];
# Plots.histogram(arr, normed=true, bins=50)

# do_sample(bpc, 999);
# do_sample(pc, 999);

# Demo MAP imputation with the freshly trained GPU circuit (global `bpc`).
try_map()



63 changes: 39 additions & 24 deletions examples/cat_rat_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MLDatasets
using CUDA
using Images

device!(collect(devices())[2])
# device!(collect(devices())[2])

function mnist_cpu()
train_int = transpose(reshape(MNIST.traintensor(UInt8), 28*28, :));
Expand Down Expand Up @@ -39,45 +39,48 @@ function generate_rat(train)
RAT(num_features; num_nodes_region, num_nodes_leaf, rg_depth, rg_replicas, input_type, balance_childs_parents)
end

function run()
function run(; batch_size = 256, num_epochs1 = 1, num_epochs2 = 1, num_epochs3 = 20,
pseudocount = 0.01, param_inertia1 = 0.2, param_inertia2 = 0.9, param_inertia3 = 0.9)

train, test = mnist_cpu();
train_gpu, test_gpu = mnist_gpu();
trunc_train = truncate(train; bits = 5);

# println("Generating HCLT structure with $latents latents... ");
# @time pc = hclt(trunc_train[1:5000,:], latents; num_cats = 256, pseudocount = 0.1, input_type = CategoricalDist);
# init_parameters(pc; perturbation = 0.4);
print("Generating RAT SPN....")
@info "Generating RAT SPN...."
@time pc = generate_rat(trunc_train);
init_parameters(pc; perturbation = 0.4);

println("Number of free parameters: $(num_parameters(pc))")

print("Moving circuit to GPU... ")
@info "Moving circuit to GPU... "
CUDA.@time bpc = CuBitsProbCircuit(BitsProbCircuit(pc));

batch_size = 2048
pseudocount = 0.01
@show length(bpc.nodes)

@info "EM"
softness = 0
epochs_1 = 5
epochs_2 = 5
epochs_3 = 10
@time mini_batch_em(bpc, train_gpu, epochs_1; batch_size, pseudocount,
softness, param_inertia = 0.2, param_inertia_end = 0.9)
@time mini_batch_em(bpc, train_gpu, num_epochs1; batch_size, pseudocount,
softness, param_inertia = param_inertia1, param_inertia_end = param_inertia2)

@time mini_batch_em(bpc, train_gpu, epochs_2; batch_size, pseudocount,
softness, param_inertia = 0.9, param_inertia_end = 0.95)
@time mini_batch_em(bpc, train_gpu, num_epochs2; batch_size, pseudocount,
softness, param_inertia = param_inertia2, param_inertia_end = param_inertia3)

@time full_batch_em(bpc, train_gpu, epochs_3; batch_size, pseudocount, softness)
for iter=1:num_epochs3
@info "Iter $iter"
@time full_batch_em(bpc, train_gpu, 5; batch_size, pseudocount, softness)

ll3 = loglikelihood(bpc, test_gpu; batch_size)
println("test LL: $(ll3)")

@time do_sample(bpc, iter)
end

print("update parameters")
@time ProbabilisticCircuits.update_parameters(bpc);
print("Save to file")
@time write("rat_cat.jpc.gz", pc);
return circuit, bpc
return pc, bpc
end

function do_sample(bpc)
function do_sample(bpc, iter=999)
CUDA.@time sms = sample(bpc, 100, 28*28, [UInt32]);

do_img(i) = begin
Expand All @@ -88,8 +91,20 @@ function do_sample(bpc)

arr = [do_img(i) for i=1:size(sms, 1)]
imgs = mosaicview(arr, fillvalue=1, ncol=10, npad=4)
save("samples.png", imgs)
save("samples/rat_samples_$(iter).png", imgs)
end

function try_map(pc, bpc)
@info "MAP"
train_gpu, _ = mnist_gpu();
data = Array{Union{Missing, UInt32}}(train_gpu[1:10, :]);
data[:, 1:400] .= missing;
data_gpu = cu(data);

# @time MAP(pc, data; batch_size=10)
MAP(bpc, data_gpu; batch_size=10)
end

# circuit, bpc = run();
#do_sample(bpc)
pc, bpc = run(; batch_size = 128, num_epochs1 = 2, num_epochs2 = 2, num_epochs3 = 2);
# do_sample(bpc)
# try_map(pc, bpc)
5 changes: 5 additions & 0 deletions src/ProbabilisticCircuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ module ProbabilisticCircuits
export num_nodes, num_edges

include("nodes/abstract_nodes.jl")

include("nodes/input_distributions.jl")
include("nodes/indicator_dist.jl")
include("nodes/categorical_dist.jl")
include("nodes/binomial_dist.jl")

include("nodes/plain_nodes.jl")

include("bits_circuit.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/bits_circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ import Base: size, getindex #extend

size(fv::FlatVectors) = size(fv.vectors)

num_layers(fv::FlatVectors) = length(fv.ends)

getindex(fv::FlatVectors, idx) = getindex(fv.vectors, idx)

###############################################
Expand Down
17 changes: 17 additions & 0 deletions src/io/jpc_io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const jpc_grammar = raw"""

node : "L" _WS INT _WS INT _WS SIGNED_INT -> literal_node
| "I" _WS INT _WS INT _WS INT _WS INT -> indicator_node
| "B" _WS INT _WS INT _WS INT _WS INT _WS LOGPROB -> binomial_node
| "C" _WS INT _WS INT _WS INT (_WS LOGPROB)+ -> categorical_node
| "P" _WS INT _WS INT _WS INT child_nodes -> prod_node
| "S" _WS INT _WS INT _WS INT weighted_child_nodes -> sum_node
Expand Down Expand Up @@ -65,6 +66,13 @@ end
t.nodes[x[1]] = PlainInputNode(var, Indicator(value))
end

# Parse a "B" line of the JPC format into a Binomial input node:
#   B <node-id> <vtree-id> <variable> <N> <log-probability>
# The success probability is serialized in log space, hence exp(logp) here.
@rule binomial_node(t::PlainJpcParse, x) = begin
    var = Base.parse(Int,x[3])
    N = Base.parse(UInt32, x[4])
    logp = Base.parse(Float64, x[5])
    t.nodes[x[1]] = PlainInputNode(var, Binomial(N, exp(logp)))
end

@rule categorical_node(t::PlainJpcParse, x) = begin
var = Base.parse(Int,x[3])
log_probs = Base.parse.(Float64, x[4:end])
Expand Down Expand Up @@ -121,6 +129,11 @@ function read_fast(input, ::Type{<:ProbCircuit} = PlainProbCircuit, ::JpcFormat
var = Base.parse(Int,tokens[4])
log_probs = Base.parse.(Float64, tokens[5:end])
nodes[id] = PlainInputNode(var, Categorical(log_probs))
elseif startswith(line, "B")
var = Base.parse(Int,tokens[4])
N = Base.parse(UInt32, tokens[5])
logp = Base.parse(Float64, tokens[6])
nodes[id] = PlainInputNode(var, Binomial(N, exp(logp)))
elseif startswith(line, "P")
child_ids = Base.parse.(Int, tokens[5:end]) .+ 1
children = nodes[child_ids]
Expand Down Expand Up @@ -152,6 +165,7 @@ c jpc count-of-jpc-nodes
c L id-of-jpc-node id-of-vtree literal
c I id-of-jpc-node id-of-vtree variable indicator-value
c C id-of-jpc-node id-of-vtree variable {log-probability}+
c B id-of-jpc-node id-of-vtree variable binomial-N binomial-P
c P id-of-sum-jpc-node id-of-vtree number-of-children {child-id}+
c S id-of-product-jpc-node id-of-vtree number-of-children {child-id log-probability}+
c"""
Expand All @@ -176,6 +190,9 @@ function Base.write(io::IO, circuit::ProbCircuit, ::JpcFormat, vtreeid::Function
print(io, "C $(labeling[n]) $(vtreeid(n)) $var")
foreach(p -> print(io, " $p"), params(d))
println(io)
elseif d isa Binomial
print(io, "B $(labeling[n]) $(vtreeid(n)) $var $(d.N) $(log(d.p))")
println(io)
else
error("Input distribution type $(typeof(d)) is unknown to the JPC file format")
end
Expand Down
5 changes: 5 additions & 0 deletions src/io/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@ latex(ind::Indicator) =
# Plot/LaTeX label for a categorical leaf: probabilities (exp of the stored
# log-params) rounded to 3 digits, e.g. "Cat(0.1, 0.9)".
function latex(d::Categorical)
    p = round.(exp.(params(d)), digits=3)
    "Cat(" * join(p, ", ") * ")"
end

# Plot/LaTeX label for a binomial leaf, e.g. "Binomial(255, 0.5)".
# The success probability is rounded to 3 digits; N is printed verbatim.
function latex(d::Binomial)
    rounded_p = round(d.p, digits=3)
    return "Binomial($(d.N), $(rounded_p))"
end
Loading