Skip to content

Commit

Permalink
Add option to enable/disble the progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-2001 committed Jul 29, 2021
1 parent c8be515 commit 69556f2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
4 changes: 4 additions & 0 deletions tardis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def run_tardis(
show_cplots=True,
log_level=None,
specific=None,
show_progress_bar=True,
**kwargs,
):
"""
Expand Down Expand Up @@ -52,6 +53,8 @@ def run_tardis(
The default value None means that the `specific` specified in the configuration file will be used.
show_cplots : bool, default: True, optional
Option to enable tardis convergence plots.
show_progress_bar : bool, default: True, optional
Option to enable the progress bar.
**kwargs : dict, optional
Optional keyword arguments including those
supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`.
Expand Down Expand Up @@ -101,6 +104,7 @@ def run_tardis(
atom_data=atom_data,
virtual_packet_logging=virtual_packet_logging,
show_cplots=show_cplots,
show_progress_bar=show_progress_bar,
**kwargs,
)
for cb in simulation_callbacks:
Expand Down
2 changes: 2 additions & 0 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def run(
last_run=False,
iteration=0,
total_iterations=0,
show_progress_bar=True,
):
"""
Run the montecarlo calculation
Expand Down Expand Up @@ -310,6 +311,7 @@ def run(
iteration,
total_packets,
total_iterations,
show_progress_bar,
self,
)
self._integrator = FormalIntegrator(model, plasma, self)
Expand Down
28 changes: 21 additions & 7 deletions tardis/montecarlo/montecarlo_numba/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
dynamic_ncols=True,
bar_format="{bar}{percentage:3.0f}% of packets propagated, iteration 0/?",
)
packet_pbar.container.close()


def update_packet_pbar(i, current_iteration, total_iterations, total_packets):
Expand All @@ -62,7 +63,16 @@ def update_packet_pbar(i, current_iteration, total_iterations, total_packets):

# set bar total when first called
if packet_pbar.total == None:
packet_pbar.ncols = "100%"
packet_pbar.container = packet_pbar.status_printer(
packet_pbar.fp,
packet_pbar.total,
packet_pbar.desc,
packet_pbar.ncols,
)
display(packet_pbar.container)
packet_pbar.reset(total=total_packets)
packet_pbar.display()

# display and reset progress bar when run_tardis is called again
if bar_iteration > current_iteration:
Expand Down Expand Up @@ -112,6 +122,7 @@ def montecarlo_radial1d(
iteration,
total_packets,
total_iterations,
show_progress_bar,
runner,
):
packet_collection = PacketCollection(
Expand Down Expand Up @@ -164,6 +175,7 @@ def montecarlo_radial1d(
packet_seeds,
iteration=iteration,
total_iterations=total_iterations,
show_progress_bar=show_progress_bar,
)

runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist
Expand Down Expand Up @@ -211,6 +223,7 @@ def montecarlo_main_loop(
packet_seeds,
iteration,
total_iterations,
show_progress_bar,
):
"""
This is the main loop of the MonteCarlo routine that generates packets
Expand Down Expand Up @@ -269,13 +282,14 @@ def montecarlo_main_loop(
virt_packet_last_line_interaction_out_id = []

for i in prange(len(output_nus)):
with objmode:
update_packet_pbar(
1,
current_iteration=iteration,
total_iterations=total_iterations,
total_packets=total_packets,
)
if show_progress_bar:
with objmode:
update_packet_pbar(
1,
current_iteration=iteration,
total_iterations=total_iterations,
total_packets=total_packets,
)

if montecarlo_configuration.single_packet_seed != -1:
seed = packet_seeds[montecarlo_configuration.single_packet_seed]
Expand Down
5 changes: 5 additions & 0 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
convergence_strategy,
nthreads,
show_cplots,
show_progress_bar,
cplots_kwargs,
):

Expand All @@ -153,6 +154,7 @@ def __init__(
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.nthreads = nthreads
self.show_progress_bar = show_progress_bar

if convergence_strategy.type in ("damped"):
self.convergence_strategy = convergence_strategy
Expand Down Expand Up @@ -370,6 +372,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False):
last_run=last_run,
iteration=self.iterations_executed,
total_iterations=self.iterations,
show_progress_bar=self.show_progress_bar,
)
output_energy = self.runner.output_energy
if np.sum(output_energy < 0) == len(output_energy):
Expand Down Expand Up @@ -587,6 +590,7 @@ def from_config(
packet_source=None,
virtual_packet_logging=False,
show_cplots=True,
show_progress_bar=True,
**kwargs,
):
"""
Expand Down Expand Up @@ -683,4 +687,5 @@ def from_config(
convergence_strategy=config.montecarlo.convergence_strategy,
nthreads=config.montecarlo.nthreads,
cplots_kwargs=cplots_kwargs,
show_progress_bar=show_progress_bar,
)

0 comments on commit 69556f2

Please sign in to comment.