Skip to content

Commit

Permalink
Add time remaining column to progress bars (#7273)
Browse files Browse the repository at this point in the history
* Add time remaining column to progress bars

* Consistent order remaining/elapsed

* Disable sample_posterior_predictive taskbar when progressbar=False

* Formatting

* More formatting

* More formatting (why doesnt pre-commit fix this?)

* Disable progress bar when progress=False

* Set refresh flag in progress bar updates

* Typo
  • Loading branch information
fonnesbeck authored Apr 26, 2024
1 parent 6761c0c commit 60a6314
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def apply_function_over_dataset(
out_dict = _DefaultTrace(n_pts)
indices = range(n_pts)

with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
for idx in indices:
out = fn(posterior_pts[idx])
Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
Expand Down
4 changes: 2 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,8 @@ def _sample(
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, advance=1)
progress.update(task, advance=1, completed=True)
progress.update(task, refresh=True, advance=1)
progress.update(task, refresh=True, advance=1, completed=True)
except KeyboardInterrupt:
pass

Expand Down
6 changes: 5 additions & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -428,7 +428,10 @@ def __init__(
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self._show_progress = progressbar
self._divergences = 0
Expand Down Expand Up @@ -465,6 +468,7 @@ def __iter__(self):
self._divergences += 1
progress.update(
task,
refresh=True,
completed=self._completed_draws,
total=self._total_draws,
description=self._desc.format(self),
Expand Down
6 changes: 4 additions & 2 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cloudpickle
import numpy as np

from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from pymc.backends.base import BaseTrace
from pymc.initial_point import PointType
Expand Down Expand Up @@ -104,7 +104,7 @@ def _sample_population(
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)

for _ in sampling:
progress.update(task, advance=1)
progress.update(task, advance=1, refresh=True)

return

Expand Down Expand Up @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
) as self._progress:
for c, stepper in enumerate(steppers):
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
Expand Down
14 changes: 12 additions & 2 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
import numpy as np

from arviz import InferenceData
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

import pymc

Expand Down Expand Up @@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with Progress(
TextColumn("{task.description}"),
SpinnerColumn(),
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
TextColumn("{task.fields[status]}"),
) as progress:
Expand Down Expand Up @@ -403,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
stage = update_data["stage"]
beta = update_data["beta"]
# update the progress bar for this task:
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)
progress.update(
status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True
)

return tuple(cloudpickle.loads(r.result()) for r in futures)
3 changes: 2 additions & 1 deletion pymc/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def find_MAP(
if isinstance(e, StopIteration):
pm._log.info(e)
finally:
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval)
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)
print(file=sys.stdout)

mx0 = RaveledVars(mx0, x0.point_map_info)
Expand Down Expand Up @@ -223,6 +223,7 @@ def __init__(
*Progress.get_default_columns(),
TextColumn("{task.fields[loss]}"),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")

Expand Down
4 changes: 3 additions & 1 deletion pymc/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def fit(
def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks):
i = 0
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Fitting", total=n, visible=progressbar)
for i in range(n):
step_func()
Expand Down

0 comments on commit 60a6314

Please sign in to comment.