Hi, I have tried to save the checkpoint and resume training. It seems that the parameters have been loaded, but the result is worse than training from scratch.
Here is the code I modified.
if resume:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.train()
Thanks for your interest in our library. Could you share which loss function you are using? Some loss function Classes also contain optimization parameters that are updated at each iteration (e.g., the moving-average estimator self.u_pos in AveragePrecisionLoss()), so the degraded performance is likely caused by the loss function being re-initialized every time you resume. To resume training exactly, you also need to restore those optimization parameters inside the loss function Class. For now, you can try the naive solution below, which saves the loss function object itself in the checkpoint. We will add proper support for resuming training in a future release.
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.module.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, checkpoint_path)
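To make the round trip concrete, here is a minimal resume-side sketch that matches the checkpoint format saved above. It assumes that checkpoint['loss'] holds the loss function object itself (e.g., an AveragePrecisionLoss instance) rather than a scalar loss value, and that the state dict is loaded before the model is wrapped in DataParallel (since it was saved from model.module); the variable name loss_fn is illustrative, not part of the library's API.

# Resume-side sketch, assuming the checkpoint was written by the torch.save call above.
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])  # load before wrapping in DataParallel
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss_fn = checkpoint['loss']  # reuse the saved loss object so internal buffers like self.u_pos are kept
start_epoch = checkpoint['epoch']
model.train()

Pickling the whole loss object is the simplest workaround; a cleaner long-term fix would be to expose the estimator tensors through a state_dict-style interface on the loss Class, which is essentially what built-in resume support would provide.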