diff --git a/python/nano/src/bigdl/nano/pytorch/utils.py b/python/nano/src/bigdl/nano/pytorch/utils.py index a86fc97b8e9..5e9a229d8c8 100644 --- a/python/nano/src/bigdl/nano/pytorch/utils.py +++ b/python/nano/src/bigdl/nano/pytorch/utils.py @@ -57,6 +57,7 @@ class ChannelsLastCallback(pl.Callback): def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None: """Override hook setup to convert model to channels_last and wrap DataHook.""" + # TODO: add check for module_states fn_old = getattr(pl_module, "on_before_batch_transfer") fn = batch_call(fn_old) setattr(pl_module, "on_before_batch_transfer_origin", fn_old)