Skip to content

Commit

Permalink
Add print_name to Truncated and CustomDists
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 21, 2024
1 parent 30d00fe commit aa679f3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ def rv_op(
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
# Specific to CustomDist
_random_fn=random,
),
Expand Down Expand Up @@ -802,6 +803,7 @@ def rv_op(
# If logp is not provided, we try to infer it from the dist graph
dict(
inline_logprob=logp is None,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
),
)

Expand Down
4 changes: 4 additions & 0 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class TruncatedRV(SymbolicRandomVariable):
def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
self.base_rv_op = base_rv_op
self.max_n_steps = max_n_steps
self._print_name = (
f"Truncated{self.base_rv_op._print_name[0]}",
f"\\operatorname{{{self.base_rv_op._print_name[1]}}}",
)
super().__init__(*args, **kwargs)

def update(self, node: Node):
Expand Down
26 changes: 25 additions & 1 deletion tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pytensor.tensor.random import normal

from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
from pymc import Bernoulli, Censored, CustomDist, Gamma, HalfCauchy, Mixture, StudentT, Truncated
from pymc.distributions import (
Dirichlet,
DirichletMultinomial,
Expand Down Expand Up @@ -285,3 +285,27 @@ def test_model_repr_variables_without_monkey_patched_repr():

str_repr = model.str_repr()
assert str_repr == "x ~ Normal(0, 1)"


def test_truncated_repr():
with Model() as model:
x = Truncated("x", Gamma.dist(1, 1), lower=0, upper=20)

str_repr = model.str_repr(include_params=False)
assert str_repr == "x ~ TruncatedGamma"


def test_custom_dist_repr():
with Model() as model:

def dist(mu, size):
return Normal.dist(mu, 1, size=size)

def random(rng, mu, size):
return rng.normal(mu, size=size)

x = CustomDist("x", 0, dist=dist, class_name="CustomDistNormal")
x = CustomDist("y", 0, random=random, class_name="CustomRandomNormal")

str_repr = model.str_repr(include_params=False)
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])

0 comments on commit aa679f3

Please sign in to comment.