Skip to content

Commit

Permalink
Fixing the monitor callback of the bucketing module. (apache#8696)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdomhan authored and piiswrong committed Nov 19, 2017
1 parent 2341420 commit cbce299
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/mxnet/module/bucketing_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
self._curr_module = None
self._curr_bucket_key = None
self._params_dirty = False
self._monitor = None

def _reset_bind(self):
"""Internal utility function to reset binding."""
Expand Down Expand Up @@ -367,6 +368,8 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
if self._monitor is not None:
module.install_monitor(self._monitor)
self._buckets[bucket_key] = module

self._curr_module = self._buckets[bucket_key]
Expand Down Expand Up @@ -510,5 +513,6 @@ def symbol(self):
def install_monitor(self, mon):
"""Installs monitor on all executors """
assert self.binded
self._monitor = mon
for mod in self._buckets.values():
mod.install_monitor(mon)

0 comments on commit cbce299

Please sign in to comment.