Skip to content

Commit

Permalink
cpu pickle (#1731)
Browse files Browse the repository at this point in the history
Co-authored-by: Weisu Yin <[email protected]>
  • Loading branch information
yinweisu and yinweisu authored Feb 18, 2022
1 parent aeca782 commit 161344e
Showing 1 changed file with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -824,10 +824,85 @@ def load(cls, filename, ctx='auto'):
# pylint: disable=redefined-outer-name, reimported
def __getstate__(self):
d = self.__dict__.copy()
try:
import torch
net = d.pop('net', None)
model_ema = d.pop('_model_ema', None)
optimizer = d.pop('_optimizer', None)
loss_scaler = d.pop('_loss_scaler', None)
save_state = {}
if net is not None:
if not self._custom_net:
if isinstance(net, torch.nn.DataParallel):
save_state['state_dict'] = get_state_dict(net.module, unwrap_model)
else:
save_state['state_dict'] = get_state_dict(net, unwrap_model)
else:
net_pickle = pickle.dumps(net)
save_state['net_pickle'] = net_pickle
if optimizer is not None:
save_state['optimizer'] = optimizer.state_dict()
if loss_scaler is not None:
save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict()
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
except ImportError:
pass
d['save_state'] = save_state
d['_logger'] = None
d['_reporter'] = None
return d

def __setstate__(self, state):
save_state = state.pop('save_state', None)
self.__dict__.update(state)
# logger
self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
self._logger.setLevel(logging.ERROR)
try:
fh = logging.FileHandler(self._log_file)
self._logger.addHandler(fh)
#pylint: disable=bare-except
except:
pass
if not save_state:
self.net = None
self._optimizer = None
self._logger.setLevel(logging.INFO)
return
try:
import torch
self.net = None
self._optimizer = None
if self._custom_net:
if save_state.get('net_pickle', None):
self.net = pickle.loads(save_state['net_pickle'])
else:
if save_state.get('state_dict', None):
self._init_network(load_only=True)
net_state_dict = self._reconstruct_state_dict(save_state['state_dict'])
if isinstance(self.net, torch.nn.DataParallel):
self.net.module.load_state_dict(net_state_dict)
else:
self.net.load_state_dict(net_state_dict)
if save_state.get('optimizer', None):
self._init_trainer()
self._optimizer.load_state_dict(save_state['optimizer'])
if hasattr(self, '_loss_scaler') and self._loss_scaler and self._loss_scaler.state_dict_key in save_state:
loss_scaler_dict = save_state[self._loss_scaler.state_dict_key]
self._loss_scaler.load_state_dict(loss_scaler_dict)
if save_state.get('state_dict_ema', None):
self._init_model_ema()
model_ema_dict = save_state.get('state_dict_ema')
model_ema_dict = self._reconstruct_state_dict(model_ema_dict)
if isinstance(self.net, torch.nn.DataParallel):
self._model_ema.module.module.load_state_dict(model_ema_dict)
else:
self._model_ema.module.load_state_dict(model_ema_dict)
except ImportError:
pass
self._logger.setLevel(logging.INFO)


class ImageListDataset(torch.utils.data.Dataset):
"""An internal image list dataset for batch predict"""
Expand Down

0 comments on commit 161344e

Please sign in to comment.