diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 4a5330ea2c5a..0bea260cd3d9 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -92,6 +92,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging, self._curr_module = None self._curr_bucket_key = None self._params_dirty = False + self._monitor = None def _reset_bind(self): """Internal utility function to reset binding.""" @@ -367,6 +368,8 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): module.bind(data_shapes, label_shapes, self._curr_module.for_training, self._curr_module.inputs_need_grad, force_rebind=False, shared_module=self._buckets[self._default_bucket_key]) + if self._monitor is not None: + module.install_monitor(self._monitor) self._buckets[bucket_key] = module self._curr_module = self._buckets[bucket_key] @@ -510,5 +513,6 @@ def symbol(self): def install_monitor(self, mon): """Installs monitor on all executors """ assert self.binded + self._monitor = mon for mod in self._buckets.values(): mod.install_monitor(mon)