Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for indicator and categorical input distributions in JPC format #111

Merged
merged 4 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/abstract_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ DAGs.children(pc::ProbCircuit) = inputs(pc)
"Get the distribution of a PC input node"
function dist end

"Get the parameters associated with a sum node"
"Get the parameters associated with a node"
params(n::ProbCircuit) = n.params

"Count the number of parameters in the node"
Expand Down
11 changes: 10 additions & 1 deletion src/input_distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ export InputDist, Indicator, Literal, Bernoulli, Categorical, loglikelihood

abstract type InputDist end

import Base: isapprox #extend

"Two input distributions are approximately equal when they have the same type and their `params` are `≈`"
function isapprox(x::InputDist, y::InputDist)
    # Distributions of different concrete types never compare as approximately equal
    typeof(x) == typeof(y) || return false
    return params(x) ≈ params(y)
end

#####################
# indicators or logical literals
#####################
Expand All @@ -18,7 +23,9 @@ const Literal = Indicator{Bool}

"An indicator contributes zero parameters, whatever the value of `independent`"
function num_parameters(n::Indicator, independent)
    return 0
end

value(d) = d.value
value(d::Indicator) = d.value

"The parameters of an indicator are its `value`"
function params(d::Indicator)
    return value(d)
end

"Return the indicator `d` unchanged; the optional second argument is ignored"
function bits(d::Indicator, _ = nothing)
    return d
end

Expand Down Expand Up @@ -64,6 +71,8 @@ Bernoulli(logp) =

"Get the `logps` field of a categorical distribution"
function logps(d::Categorical)
    return d.logps
end

"The parameters of a categorical distribution are its `logps`"
function params(d::Categorical)
    return logps(d)
end

"Number of categories of a categorical distribution, i.e. the length of its `logps`"
function num_categories(d::Categorical)
    return length(logps(d))
end

num_parameters(n::Categorical, independent) =
Expand Down
31 changes: 26 additions & 5 deletions src/io/jpc_io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const jpc_grammar = raw"""
header : "jpc" _WS INT

node : "L" _WS INT _WS INT _WS SIGNED_INT -> literal_node
| "C" _WS INT _WS INT _WS INT (_WS LOGPROB)+ -> categorical_node
| "P" _WS INT _WS INT _WS INT child_nodes -> prod_node
| "S" _WS INT _WS INT _WS INT weighted_child_nodes -> sum_node

Expand Down Expand Up @@ -57,6 +58,12 @@ end
t.nodes[x[1]] = PlainInputNode(var, Literal(sign))
end

# Transformer rule for a "C" line of the form
#   C id-of-jpc-node id-of-vtree variable {log-probability}+
# x[1] is used as the key into `t.nodes`, x[3] is the variable and every token
# from x[4] onward is one log-probability; x[2] is not used here.
@rule categorical_node(t::PlainJpcParse, x) = begin
    variable = Base.parse(Int, x[3])
    logprob_tokens = x[4:end]
    t.nodes[x[1]] = PlainInputNode(
        variable,
        Categorical(Base.parse.(Float64, logprob_tokens)),
    )
end

@rule prod_node(t::PlainJpcParse,x) = begin
@assert length(x[4]) == Base.parse(Int,x[3])
t.nodes[x[1]] = PlainMulNode(x[4])
Expand All @@ -72,7 +79,7 @@ function Base.parse(::Type{PlainProbCircuit}, str, ::JpcFormat)
Lerche.transform(PlainJpcParse(), ast)
end

function Base.read(io::IO, ::Type{PlainProbCircuit}, ::JpcFormat, fast=true)
function Base.read(io::IO, ::Type{PlainProbCircuit}, ::JpcFormat, fast = true)
if fast
read_fast(io, PlainProbCircuit, JpcFormat())
else
Expand All @@ -99,6 +106,10 @@ function read_fast(input, ::Type{<:ProbCircuit} = PlainProbCircuit, ::JpcFormat
var = abs(lit)
sign = lit > 0
nodes[id] = PlainInputNode(var, Literal(sign))
elseif startswith(line, "C")
var = Base.parse(Int,tokens[4])
log_probs = Base.parse.(Float64, tokens[5:end])
nodes[id] = PlainInputNode(var, Categorical(log_probs))
elseif startswith(line, "P")
child_ids = Base.parse.(Int, tokens[5:end]) .+ 1
children = nodes[child_ids]
Expand Down Expand Up @@ -127,7 +138,8 @@ c jpc nodes appear bottom-up, children before parents
c
c file syntax:
c jpc count-of-jpc-nodes
c L id-of-literal-jpc-node id-of-vtree literal
c L id-of-jpc-node id-of-vtree literal
c C id-of-jpc-node id-of-vtree variable {log-probability}+
c P id-of-product-jpc-node id-of-vtree number-of-children {child-id}+
c S id-of-sum-jpc-node id-of-vtree number-of-children {child-id log-probability}+
c"""
Expand All @@ -141,9 +153,18 @@ function Base.write(io::IO, circuit::ProbCircuit, ::JpcFormat, vtreeid::Function
println(io, "jpc $(num_nodes(circuit))")
foreach(circuit) do n
if isinput(n)
@assert dist(n) isa Literal
literal = value(dist(n)) ? randvar(n) : -randvar(n)
println(io, "L $(labeling[n]) $(vtreeid(n)) $literal")
var = randvar(n)
d = dist(n)
if d isa Literal
literal = value(d) ? var : -var
println(io, "L $(labeling[n]) $(vtreeid(n)) $literal")
elseif d isa Categorical
print(io, "C $(labeling[n]) $(vtreeid(n)) $var")
foreach(p -> print(io, " $p"), params(d))
println(io)
else
error("Input distribution type $(typeof(d)) is unknown to the JPC file format")
end
else
t = ismul(n) ? "P" : "S"
print(io, "$t $(labeling[n]) $(vtreeid(n)) $(num_inputs(n))")
Expand Down
2 changes: 2 additions & 0 deletions src/plain_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ dist(n::PlainInputNode) = n.dist

randvars(n::PlainInputNode) = n.randvars

"The parameters of an input node are the parameters of its distribution"
function params(n::PlainInputNode)
    return params(dist(n))
end

num_parameters_node(n::PlainInputNode, independent) =
num_parameters(dist(n), independent)
num_parameters_node(n::PlainMulNode, _) = 0
Expand Down
2 changes: 1 addition & 1 deletion test/helper/pc_equals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function test_pc_equals(c1, c2)
@test isinput(n1)
@test isinput(n2)
# TODO: might need to fix for non-literal dists
guyvdbroeck marked this conversation as resolved.
Show resolved Hide resolved
@test dist(n1) == dist(n2)
@test dist(n1) ≈ dist(n2)
end
end
end
30 changes: 29 additions & 1 deletion test/io/jpc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ProbabilisticCircuits: JpcFormat
include("../helper/plain_dummy_circuits.jl")
include("../helper/pc_equals.jl")

@testset "Jpc IO tests" begin
@testset "Jpc IO tests Literal" begin

# Indicators
pc = little_3var()
Expand Down Expand Up @@ -32,3 +32,31 @@ include("../helper/pc_equals.jl")
end

end

@testset "Jpc IO tests categorical" begin

    # Fixture circuit that is written out and read back below
    original = little_3var_categorical()

    mktempdir() do dir

        # Round trip through a ".jpc" file: write once, then read it back through
        # the default reader and through both settings of the `fast` flag.
        jpc_path = "$dir/example.jpc"
        write(jpc_path, original)

        test_pc_equals(original, read(jpc_path, ProbCircuit))

        for fast in (true, false)
            test_pc_equals(original, read(jpc_path, ProbCircuit, JpcFormat(), fast))
        end

        # Same round trip through a file whose name ends in ".gz"
        gz_path = "$dir/example.jpc.gz"
        write(gz_path, original)

        test_pc_equals(original, read(gz_path, ProbCircuit))

    end

end