Skip to content

Commit

Permalink
Make Model.str_repr robust to variables without monkey-patch
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 11, 2023
1 parent df7b267 commit 6f4a040
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
38 changes: 23 additions & 15 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from typing import Union

from pytensor.compile import SharedVariable
Expand Down Expand Up @@ -98,36 +96,46 @@ def str_for_dist(
def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str:
"""Make a human-readable string representation of Model, listing all random variables
and their distributions, optionally including parameter values."""
all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials)

rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv]
rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr]
kwargs = dict(formatting=formatting, include_params=include_params)
free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs]
observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs]
det_reprs = [
str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic")
for dist in model.deterministics
]
potential_reprs = [
str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential")
for pot in model.potentials
]

var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs

if not rv_reprs:
if not var_reprs:
return ""
if "latex" in formatting:
rv_reprs = [
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
for rv_repr in rv_reprs
if rv_repr is not None
var_reprs = [
var_repr.replace(r"\sim", r"&\sim &").strip("$")
for var_repr in var_reprs
if var_repr is not None
]
return r"""$$
\begin{{array}}{{rcl}}
{}
\end{{array}}
$$""".format(
"\\\\".join(rv_reprs)
"\\\\".join(var_reprs)
)
else:
# align vars on their ~
names = [s[: s.index("~") - 1] for s in rv_reprs]
distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
names = [s[: s.index("~") - 1] for s in var_reprs]
distrs = [s[s.index("~") + 2 :] for s in var_reprs]
maxlen = str(max(len(x) for x in names))
rv_reprs = [
var_reprs = [
("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
for n, d in zip(names, distrs)
]
return "\n".join(rv_reprs)
return "\n".join(var_reprs)


def str_for_potential_or_deterministic(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
import numpy as np

from pytensor.tensor.random import normal

from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
from pymc.distributions import (
Dirichlet,
Expand Down Expand Up @@ -274,3 +276,15 @@ def test_model_latex_repr_mixture_model():
"$$",
]
assert [line.strip() for line in latex_repr.split("\n")] == expected


def test_model_repr_variables_without_monkey_patched_repr():
"""Test that model repr does not rely on individual variables having the str_repr method monkey patched."""
x = normal(name="x")
assert not hasattr(x, "str_repr")

model = Model()
model.register_rv(x, "x")

str_repr = model.str_repr()
assert str_repr == "x ~ Normal(0, 1)"

0 comments on commit 6f4a040

Please sign in to comment.