Skip to content

Commit

Permalink
prettier categorical leaf plots
Browse files Browse the repository at this point in the history
  • Loading branch information
guyvdbroeck committed Feb 22, 2022
1 parent afae957 commit e06f8da
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 0 additions & 6 deletions src/input_distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ value(d::Indicator) = d.value

params(d::Indicator) = value(d)

latex(ind::Indicator) =
"\\mathbf{1}_{$(value(ind))}"

bits(d::Indicator, _ = nothing) = d

unbits(d::Indicator, _ = nothing) = d
Expand Down Expand Up @@ -81,9 +78,6 @@ num_categories(d::Categorical) = length(logps(d))
num_parameters(n::Categorical, independent) =
num_categories(n) - (independent ? 1 : 0)

latex(d::Categorical) =
"Categorical("*join(params(d), ", ")*")"

init_params(d::Categorical, perturbation::Float32) = begin
unnormalized_probs = map(rand(Float32, num_categories(d))) do x
Float32(1.0 - perturbation + x * 2.0 * perturbation)
Expand Down
8 changes: 8 additions & 0 deletions src/io/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,12 @@ function plot(pc::ProbCircuit)
end

TikzGraphs.plot(g, node_labels; edge_labels, edge_style="font=\\tiny")
end

latex(ind::Indicator) =
"\\mathbf{1}_{$(value(ind))}"

function latex(d::Categorical)
p = round.(exp.(params(d)), digits=3)
"Cat(" * join(p, ", ") * ")"
end

0 comments on commit e06f8da

Please sign in to comment.