Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain how to obtain the model graphviz in a non-Ipython environment #7181

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,27 @@ def broadcast_dist_samples_shape(shapes, size=None):
Examples
--------
.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (size, 5)
shape2 = (size, 4, 5)
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
size=size)
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (size,)
shape1 = (5,)
shape2 = (4, 5)
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
size=size)
assert out == (size, 4, 5)

.. code-block:: python

size = 100
shape0 = (1,)
shape1 = (5,)
Expand Down
11 changes: 10 additions & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,7 +1873,7 @@ def to_graphviz(
.. code-block:: python

import numpy as np
from pymc import HalfCauchy, Model, Normal, model_to_graphviz
from pymc import HalfCauchy, Model, Normal

J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
Expand All @@ -1890,6 +1890,15 @@ def to_graphviz(
obs = Normal("obs", theta, sigma=sigma, observed=y)

schools.to_graphviz()

Note that this code automatically plots the graph if executed in a Jupyter notebook.
If executed non-interactively, such as in a script or python console, the graph
needs to be rendered explicitly:

.. code-block:: python

# creates the file `schools.pdf`
schools.to_graphviz().render("schools")
"""
return model_to_graphviz(
model=self,
Expand Down
9 changes: 9 additions & 0 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,15 @@ def model_to_graphviz(
obs = Normal("obs", theta, sigma=sigma, observed=y)

model_to_graphviz(schools)

Note that this code automatically plots the graph if executed in a Jupyter notebook.
If executed non-interactively, such as in a script or python console, the graph
needs to be rendered explicitly:

.. code-block:: python

# creates the file `schools.pdf`
model_to_graphviz(schools).render("schools")
"""
if "plain" not in formatting:
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
Expand Down
1 change: 1 addition & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ def collect_default_updates(
Examples
--------
.. code:: python

import pymc as pm
from pytensor.scan import scan
from pymc.pytensorf import collect_default_updates
Expand Down
Loading