Skip to content

Commit

Permalink
Implement a basic callback system for calling a custom function after…
Browse files Browse the repository at this point in the history
… each iteration.
  • Loading branch information
ftsamis committed Aug 19, 2016
1 parent 8bc1341 commit a12700b
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from astropy import units as u
from collections import OrderedDict

from tardis.montecarlo.base import MontecarloRunner
from tardis.model.base import Radial1DModel
Expand Down Expand Up @@ -39,6 +40,9 @@ def __init__(self, iterations, hold_iterations, model, plasma, runner,
'neither damped nor specific '
'- input is {0}'.format(convergence_strategy.type))

self._callbacks = OrderedDict()
self._cb_next_id = 0

def estimate_t_inner(self, input_t_inner, luminosity_requested,
t_inner_update_exponent=-0.5):
emitted_luminosity = self.runner.calculate_emitted_luminosity(
Expand Down Expand Up @@ -177,6 +181,7 @@ def run(self):
self.hold_iterations + 1))
else:
times_converged = 0
self._call_back()
self.converged = times_converged == self.hold_iterations + 1
# Last iteration
self.iterate(self.last_no_of_packets, self.no_of_virtual_packets, True)
Expand All @@ -185,6 +190,7 @@ def run(self):
logger.info("Simulation finished in {0:d} iterations "
"and took {1:.2f} s".format(
self.iterations_executed, time.time() - start_time))
self._call_back()

def log_plasma_state(self, t_rad, w, t_inner, next_t_rad, next_w,
next_t_inner, log_sampling=5):
Expand Down Expand Up @@ -235,7 +241,8 @@ def log_run_results(self, emitted_luminosity, absorbed_luminosity):
emitted_luminosity, absorbed_luminosity,
self.luminosity_requested))

def to_hdf(self, path_or_buf, path='', plasma_properties=None):
def to_hdf(self, path_or_buf, path='', plasma_properties=None,
suffix_count=True):
"""
Store the simulation to an HDF structure.
Expand All @@ -249,14 +256,69 @@ def to_hdf(self, path_or_buf, path='', plasma_properties=None):
`None` or a `PlasmaPropertyCollection` which will
be passed as the collection argument to the
plasma.to_hdf method.
suffix_count : bool
If True, the path inside the HDF will be suffixed with the
number of the iteration being stored.
Returns
-------
None
"""
if suffix_count:
path += str(self.iterations_executed)
self.runner.to_hdf(path_or_buf, path)
self.model.to_hdf(path_or_buf, path)
self.plasma.to_hdf(path_or_buf, path, plasma_properties)

def _call_back(self):
for cb, args in self._callbacks.values():
cb(self, *args)

def add_callback(self, cb_func, *args):
"""
Add a function which will be called
after every iteration.
The cb_func signature must look like:
cb_func(simulation, extra_arg1, ...)
Parameters
----------
cb_func: callable
The callback function
arg1:
The first additional arguments passed to the callable function
...
Returns
-------
: int
The callback ID
"""
cb_id = self._cb_next_id
self._callbacks[cb_id] = (cb_func, args)
self._cb_next_id += 1
return cb_id

def remove_callback(self, id):
"""
Remove the callback with a specific ID (which was returned by
add_callback)
Parameters
----------
id: int
The callback ID
Returns
-------
: True if the callback was successfully removed.
"""
try:
del self._callbacks[id]
return True
except KeyError:
return False

@classmethod
def from_config(cls, config):
model = Radial1DModel.from_config(config)
Expand Down

0 comments on commit a12700b

Please sign in to comment.