Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform argument to plots #1036

Merged
merged 17 commits into from
Feb 6, 2020
6 changes: 6 additions & 0 deletions arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def plot_trace(
var_names=None,
coords=None,
divergences="bottom",
transform=None,
figsize=None,
rug=False,
lines=None,
Expand Down Expand Up @@ -46,6 +47,8 @@ def plot_trace(
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
divergences : {"bottom", "top", None, False}
Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
transform : callable
Function to transform data (defaults to identity)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a better wording will be "Function to transform data (defaults to None i.e. the identity function)" or probably just "Function to transform data (defaults to None)"

figsize : figure size tuple
If None, size is (12, variables * 2)
rug : bool
Expand Down Expand Up @@ -137,6 +140,9 @@ def plot_trace(
divergence_data = False

data = get_coords(convert_to_dataset(data, group="posterior"), coords)

data = transform(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By doing this you are basically writing data = None(data), which is not legal Python. Instead you should do something like:

if transform is not None:
    data = transform(data)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aloctavodia Sorry for this elementary mistake! I fixed it.


var_names = _var_names(var_names, data)

if lines is None:
Expand Down