Skip to content

Commit

Permalink
Summarize model as rich table
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 15, 2024
1 parent 2accca9 commit 283cc10
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,12 @@ Model Transforms

autoreparam.vip_reparametrize
autoreparam.VIP


Printing
========
.. currentmodule:: pymc_experimental.printing
.. autosummary::
:toctree: generated/

model_table
182 changes: 182 additions & 0 deletions pymc_experimental/printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import numpy as np

from pymc import Model
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
from pytensor import Mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.type import Constant, Variable
from rich.box import SIMPLE_HEAD
from rich.table import Table


def variable_expression(
model: Model,
var: Variable,
truncate_deterministic: int | None,
) -> str:
"""Get the expression of a variable in a human-readable format."""
if var in model.data_vars:
var_expr = "Data"
elif var in model.deterministics:
str_repr = str_for_potential_or_deterministic(var, dist_name="")
_, var_expr = str_repr.split(" ~ ")
var_expr = var_expr[1:-1] # Remove outer parentheses (f(...))
if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
contents = var_expr[2:-1].split(", ")
str_len = 0
for show_n, content in enumerate(contents):
str_len += len(content) + 2
if str_len > truncate_deterministic:
break
var_expr = f"f({', '.join(contents[:show_n])}, ...)"
elif var in model.potentials:
var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]
else: # basic_RVs
var_expr = str_for_dist(var).split(" ~ ")[1]
return var_expr


def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
if isinstance(var, SharedVariable):
return var.get_value(borrow=True)
else:
return var.data


def dims_expression(model: Model, var: Variable) -> str:
"""Get the dimensions of a variable in a human-readable format."""
if (dims := model.named_vars_to_dims.get(var.name)) is not None:
dim_sizes = {dim: _extract_dim_value(model.dim_lengths[dim]) for dim in dims}
return " × ".join(f"{dim}[{dim_size}]" for dim, dim_size in dim_sizes.items())
else:
dim_sizes = list(var.shape.eval(mode=Mode(linker="py", optimizer=None)))
return f"[{', '.join(map(str, dim_sizes))}]" if dim_sizes else ""


def model_parameter_count(model: Model) -> int:
"""Count the number of parameters in the model."""
rv_shapes = model.eval_rv_shapes() # Includes transformed variables
return np.sum([np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs])


def model_table(
model: Model,
*,
split_groups: bool = True,
truncate_deterministic: int | None = None,
parameter_count: bool = True,
) -> Table:
"""Create a rich table with a summary of the model's variables and their expressions.
Parameters
----------
model : Model
The PyMC model to summarize.
split_groups : bool
If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
will be separated by a section.
truncate_deterministic : int | None
If not None, truncate the expression of deterministic variables that go beyond this length.
empty_dims : bool
If True, show the dimensions of scalar variables as an empty list.
parameter_count : bool
If True, add a row with the total number of parameters in the model.
Returns
-------
Table
A rich table with the model's variables, their expressions and dims.
Examples
--------
.. code-block:: python
import numpy as np
import pymc as pm
from pymc_experimental.printing import model_table
coords = {"subject": range(20), "param": ["a", "b"]}
with pm.Model(coords=coords) as m:
x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")
beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
sigma = pm.HalfNormal("sigma", sigma=1)
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")
table = model_table(m)
table # Displays the following table in an interactive environment
'''
Variable Expression Dimensions
─────────────────────────────────────────────────────
x = Data subject[20] × param[2]
y = Data subject[20]
beta ~ Normal(0, 1) param[2]
sigma ~ HalfNormal(0, 1)
Parameter count = 3
mu = f(beta) subject[20]
y_obs ~ Normal(mu, sigma) subject[20]
'''
Output can be explicitly rendered in a rich console or exported to text, html or svg.
.. code-block:: python
from rich.console import Console
console = Console(record=True)
console.print(table)
text_export = console.export_text()
html_export = console.export_html()
svg_export = console.export_svg()
"""
table = Table(
show_header=True,
show_edge=False,
box=SIMPLE_HEAD,
highlight=False,
collapse_padding=True,
)
table.add_column("Variable", justify="right")
table.add_column("Expression", justify="left")
table.add_column("Dimensions")

if split_groups:
groups = (
model.data_vars,
model.free_RVs,
model.deterministics,
model.potentials,
model.observed_RVs,
)
else:
# Show variables in the order they were defined
groups = (model.named_vars.values(),)

for group in groups:
if not group:
continue

for var in group:
var_name = var.name
sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
var_expr = variable_expression(model, var, truncate_deterministic)
dims_expr = dims_expression(model, var)
if dims_expr == "[]":
dims_expr = ""
table.add_row(var_name + sep, var_expr, dims_expr)

if parameter_count and (not split_groups or group == model.free_RVs):
n_parameters = model_parameter_count(model)
table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

table.add_section()

return table
98 changes: 98 additions & 0 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
import pymc as pm

from rich.console import Console

from pymc_experimental.printing import model_table


def get_text(table) -> str:
console = Console(width=80)
with console.capture() as capture:
console.print(table)
return capture.get()


def test_model_table():
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))

mu = pm.Normal("mu", mu=0, sigma=1)
sigma = pm.HalfNormal("sigma", sigma=1)
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1))
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")

mu_trial = pm.Deterministic(
"mu_trial",
global_intercept.squeeze() + intercept_subject + beta_subject * x_data,
dims=["trial", "subject"],
)
noise = pm.Exponential("noise", lam=1)
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))

pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")

table_txt = get_text(model_table(model))
expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) [20, 1]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
Parameter count = 44
mu_trial = f(intercept_subject, trial[6] × subject[20]
beta_subject,
global_intercept)
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
"""
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

table_txt = get_text(model_table(model, split_groups=False))
expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) [20, 1]
beta_subject ~ Normal(mu, sigma) subject[20]
mu_trial = f(intercept_subject, trial[6] × subject[20]
beta_subject,
global_intercept)
noise ~ Exponential(f())
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
Parameter count = 44
"""
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

table_txt = get_text(
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
)
expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) [20, 1]
beta_subject ~ Normal(mu, sigma) subject[20]
mu_trial = f(intercept_subject, ...) trial[6] × subject[20]
noise ~ Exponential(f())
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
"""
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

0 comments on commit 283cc10

Please sign in to comment.