# Pickling hooks extracted from patch 161344e7b0f37dc4fa9bfd417bec48a594ffc5cd
# ("cpu pickle (#1731)", Weisu Yin, 2022-02-17) against
# gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py.
# In that file both functions are methods of an estimator class whose header
# is not visible in the patch; `pickle`, `logging`, `get_state_dict` and
# `unwrap_model` are expected to be in scope at module level there.

# pylint: disable=redefined-outer-name, reimported
def __getstate__(self):
    """Return a picklable copy of the instance ``__dict__``.

    The live ``net``, ``_model_ema``, ``_optimizer`` and ``_loss_scaler``
    objects are removed from the copy and replaced by a ``save_state`` dict
    holding, when the corresponding object is present:

    * ``'state_dict'`` -- ``get_state_dict(net, unwrap_model)`` for a
      non-custom network (``net.module`` is used for ``DataParallel``), or
    * ``'net_pickle'`` -- ``pickle.dumps(net)`` for a custom network;
    * ``'optimizer'`` -- ``optimizer.state_dict()``;
    * ``loss_scaler.state_dict_key`` -- ``loss_scaler.state_dict()``;
    * ``'state_dict_ema'`` -- ``get_state_dict(model_ema, unwrap_model)``.

    ``_logger`` and ``_reporter`` are always stored as ``None``.

    Returns
    -------
    dict
        The state to pickle. ``save_state`` is always present; it is empty
        when torch cannot be imported or nothing needs saving.
    """
    d = self.__dict__.copy()
    # Bound before anything can fail: the original created this dict inside
    # the ``try`` after ``import torch``, so a failed import ended in a
    # NameError on the ``d['save_state']`` assignment below.
    save_state = {}
    try:
        import torch
    except ImportError:
        # Only the import itself is guarded, so an unrelated ImportError
        # raised while serializing is not silently swallowed.
        torch = None
    if torch is not None:
        net = d.pop('net', None)
        model_ema = d.pop('_model_ema', None)
        optimizer = d.pop('_optimizer', None)
        loss_scaler = d.pop('_loss_scaler', None)
        if net is not None:
            if self._custom_net:
                # A user-supplied network cannot be rebuilt from config,
                # so the whole object is pickled.
                save_state['net_pickle'] = pickle.dumps(net)
            else:
                wrapped = isinstance(net, torch.nn.DataParallel)
                module = net.module if wrapped else net
                save_state['state_dict'] = get_state_dict(module, unwrap_model)
        if optimizer is not None:
            save_state['optimizer'] = optimizer.state_dict()
        if loss_scaler is not None:
            save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict()
        if model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
    d['save_state'] = save_state
    d['_logger'] = None
    d['_reporter'] = None
    return d


def __setstate__(self, state):
    """Restore an instance from a dict produced by ``__getstate__``.

    Steps, in order:

    1. ``save_state`` is popped from ``state`` and the rest is copied into
       ``self.__dict__``.
    2. ``self._logger`` is re-created from ``state['_name']`` (falling back
       to the class name) and held at ``ERROR`` while restoring.
    3. A ``FileHandler`` for ``self._log_file`` is attached on a best-effort
       basis, unless the logger already has one for that file.
    4. ``net`` and ``_optimizer`` are reset to ``None``, then rebuilt from
       ``save_state`` if it is non-empty and torch is importable.
    5. The logger level is set to ``INFO``, even if restoring raised.

    Parameters
    ----------
    state : dict
        The pickled state. It is mutated: ``'save_state'`` is removed.
    """
    save_state = state.pop('save_state', None)
    self.__dict__.update(state)
    # logger
    self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
    self._logger.setLevel(logging.ERROR)
    try:
        log_file = getattr(self, '_log_file', None)
        if log_file:
            try:
                import os
                target = os.path.abspath(os.fspath(log_file))
                # Loggers are process-wide singletons keyed by name, so
                # repeated unpickling must not stack handlers for one file.
                already_attached = any(
                    isinstance(handler, logging.FileHandler)
                    and handler.baseFilename == target
                    for handler in self._logger.handlers
                )
                if not already_attached:
                    self._logger.addHandler(logging.FileHandler(log_file))
            except (OSError, TypeError, ValueError):
                # Best effort: an unusable log path must not block loading.
                pass

        # Reset first so both attributes exist on every exit path.
        self.net = None
        self._optimizer = None
        if not save_state:
            return
        try:
            import torch
        except ImportError:
            return

        if self._custom_net:
            if save_state.get('net_pickle', None):
                # NOTE: pickle.loads can execute arbitrary code. These bytes
                # come from the same pickle stream that is already being
                # loaded, so they carry the same trust level as that stream.
                self.net = pickle.loads(save_state['net_pickle'])
        elif save_state.get('state_dict', None):
            self._init_network(load_only=True)
            net_state_dict = self._reconstruct_state_dict(save_state['state_dict'])
            wrapped = isinstance(self.net, torch.nn.DataParallel)
            module = self.net.module if wrapped else self.net
            module.load_state_dict(net_state_dict)

        if save_state.get('optimizer', None):
            self._init_trainer()
            self._optimizer.load_state_dict(save_state['optimizer'])

        # `_loss_scaler` is stripped by __getstate__, so it only exists here
        # if one of the init calls above re-created it.
        loss_scaler = getattr(self, '_loss_scaler', None)
        if loss_scaler and loss_scaler.state_dict_key in save_state:
            loss_scaler.load_state_dict(save_state[loss_scaler.state_dict_key])

        if save_state.get('state_dict_ema', None):
            self._init_model_ema()
            model_ema_dict = self._reconstruct_state_dict(save_state['state_dict_ema'])
            if isinstance(self.net, torch.nn.DataParallel):
                self._model_ema.module.module.load_state_dict(model_ema_dict)
            else:
                self._model_ema.module.load_state_dict(model_ema_dict)
    finally:
        # Runs on early return and on error, so the shared logger is never
        # left at ERROR.
        self._logger.setLevel(logging.INFO)


# NOTE: the patch continues with unchanged context for
# `class ImageListDataset(torch.utils.data.Dataset)` ("An internal image list
# dataset for batch predict"); its body is cut off in this excerpt.