Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

Commit

Permalink
move ema model to hook (#73)
Browse files Browse the repository at this point in the history
* move ema model to hook

* delete cls_runner
  • Loading branch information
kprokofi authored Oct 27, 2022
1 parent 738239e commit 12861a9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
22 changes: 11 additions & 11 deletions mpa/modules/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ def __init__(self,
def after_train_epoch(self, runner):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
if hasattr(runner, 'save_ckpt'):
if runner.save_ckpt:
if runner.save_ema_model:
backup_model = runner.model
runner.model = runner.ema_model
runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
if runner.save_ema_model:
runner.model = backup_model
if hasattr(runner, 'save_ckpt') and runner.save_ckpt:
if hasattr(runner, 'save_ema_model') and runner.save_ema_model:
backup_model = runner.model
runner.model = runner.ema_model
runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
if hasattr(runner, 'save_ema_model') and runner.save_ema_model:
runner.model = backup_model
runner.save_ema_model = False
runner.save_ckpt = False

@master_only
Expand Down
4 changes: 0 additions & 4 deletions recipes/stages/_base_/runners/cls_runner.py

This file was deleted.

2 changes: 1 addition & 1 deletion recipes/stages/classification/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ _base_: [
'../_base_/default.py',
'../_base_/logs/tensorboard_logger.py',
'../_base_/optimizers/sgd.py',
'../_base_/runners/cls_runner.py',
'../_base_/runners/epoch_runner_cancel.py',
'../_base_/schedules/cos_anneal.py',
]

Expand Down
2 changes: 1 addition & 1 deletion recipes/stages/classification/train_multilabel.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
_base_: [
'../_base_/logs/tensorboard_logger.py',
'../_base_/optimizers/sgd.py',
'../_base_/runners/cls_runner.py',
'../_base_/runners/epoch_runner_cancel.py',
'../_base_/schedules/1cycle.py',
]

Expand Down

0 comments on commit 12861a9

Please sign in to comment.