From d9687a88510c8bf3b5fa2fa8963aaf000117045d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 21 Mar 2024 12:06:20 +0100 Subject: [PATCH] Add `print_name` to Truncated and CustomDists --- pymc/distributions/distribution.py | 2 ++ pymc/distributions/truncated.py | 4 ++++ tests/test_printing.py | 26 +++++++++++++++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 7035e4e027..0d3aee8862 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -663,6 +663,7 @@ def rv_op( ndim_supp=ndim_supp, ndims_params=ndims_params, dtype=dtype, + _print_name=(class_name, f"\\operatorname{{{class_name}}}"), # Specific to CustomDist _random_fn=random, ), @@ -802,6 +803,7 @@ def rv_op( # If logp is not provided, we try to infer it from the dist graph dict( inline_logprob=logp is None, + _print_name=(class_name, f"\\operatorname{{{class_name}}}"), ), ) diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index d7c8181f2d..2a1618348a 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -56,6 +56,10 @@ class TruncatedRV(SymbolicRandomVariable): def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs): self.base_rv_op = base_rv_op self.max_n_steps = max_n_steps + self._print_name = ( + f"Truncated{self.base_rv_op._print_name[0]}", + f"\\operatorname{{{self.base_rv_op._print_name[1]}}}", + ) super().__init__(*args, **kwargs) def update(self, node: Node): diff --git a/tests/test_printing.py b/tests/test_printing.py index b2577768f1..26692ca569 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -15,7 +15,7 @@ from pytensor.tensor.random import normal -from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT +from pymc import Bernoulli, Censored, CustomDist, Gamma, HalfCauchy, Mixture, StudentT, Truncated from pymc.distributions import ( Dirichlet, DirichletMultinomial, @@ -285,3 +285,27 @@ def test_model_repr_variables_without_monkey_patched_repr(): str_repr = model.str_repr() assert str_repr == "x ~ Normal(0, 1)" + + +def test_truncated_repr(): + with Model() as model: + x = Truncated("x", Gamma.dist(1, 1), lower=0, upper=20) + + str_repr = model.str_repr(include_params=False) + assert str_repr == "x ~ TruncatedGamma" + + +def test_custom_dist_repr(): + with Model() as model: + + def dist(mu, size): + return Normal.dist(mu, 1, size=size) + + def random(rng, mu, size): + return rng.normal(mu, size=size) + + x = CustomDist("x", 0, dist=dist, class_name="CustomDistNormal") + x = CustomDist("y", 0, random=random, class_name="CustomRandomNormal") + + str_repr = model.str_repr(include_params=False) + assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])