From e06f8da2861d0989d5c9cabaa0425d8392f7fa9d Mon Sep 17 00:00:00 2001 From: Guy Van den Broeck Date: Tue, 22 Feb 2022 08:34:37 -0800 Subject: [PATCH] prettier categorical leaf plots --- src/input_distributions.jl | 6 ------ src/io/plot.jl | 8 ++++++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/input_distributions.jl b/src/input_distributions.jl index b2683778..b89a6bbf 100644 --- a/src/input_distributions.jl +++ b/src/input_distributions.jl @@ -27,9 +27,6 @@ value(d::Indicator) = d.value params(d::Indicator) = value(d) -latex(ind::Indicator) = - "\\mathbf{1}_{$(value(ind))}" - bits(d::Indicator, _ = nothing) = d unbits(d::Indicator, _ = nothing) = d @@ -81,9 +78,6 @@ num_categories(d::Categorical) = length(logps(d)) num_parameters(n::Categorical, independent) = num_categories(n) - (independent ? 1 : 0) -latex(d::Categorical) = - "Categorical("*join(params(d), ", ")*")" - init_params(d::Categorical, perturbation::Float32) = begin unnormalized_probs = map(rand(Float32, num_categories(d))) do x Float32(1.0 - perturbation + x * 2.0 * perturbation) diff --git a/src/io/plot.jl b/src/io/plot.jl index 32102d33..2c24162f 100644 --- a/src/io/plot.jl +++ b/src/io/plot.jl @@ -35,4 +35,12 @@ function plot(pc::ProbCircuit) end TikzGraphs.plot(g, node_labels; edge_labels, edge_style="font=\\tiny") +end + +latex(ind::Indicator) = + "\\mathbf{1}_{$(value(ind))}" + +function latex(d::Categorical) + p = round.(exp.(params(d)), digits=3) + "Cat(" * join(p, ", ") * ")" end \ No newline at end of file