fix the bug of resume and adaptive set function for lr scheduler (Pad…
GuoxiaWang authored Dec 9, 2022
1 parent 19c36c9 commit 46334fd
Showing 12 changed files with 79 additions and 72 deletions.
6 changes: 3 additions & 3 deletions plsc/engine/classification/train.py
@@ -72,8 +72,8 @@ def defualt_train_one_epoch(engine, epoch_id):
# clear gradients
engine.optimizer.clear_grad()

-if engine.lr_scheduler is not None and engine.lr_decay_unit == 'step':
-    engine.lr_scheduler.step()
+if engine.lr_decay_unit == 'step':
+    engine.optimizer.lr_step(engine.global_step)

# below code just for logging
# update metric_for_logger
@@ -101,7 +101,7 @@ def defualt_train_one_epoch(engine, epoch_id):
io.save_checkpoint(
engine.model,
engine.optimizer,
-engine.lr_scheduler,
+engine.scaler,
engine.best_metric,
engine.output_dir,
model_name=engine.config["Model"]["name"],
14 changes: 6 additions & 8 deletions plsc/engine/classification/utils.py
@@ -41,8 +41,7 @@ def update_loss(trainer, loss_dict, batch_size):

def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = "lr: none"
-if trainer.lr_scheduler is not None:
-    lr_msg = "lr: {:.6f}".format(trainer.lr_scheduler.get_lr())
+lr_msg = "lr: {:.6f}".format(trainer.optimizer.get_lr())

metric_msg = ", ".join([
"{}: {:.5f}".format(key, trainer.output_info[key].avg)
@@ -65,12 +64,11 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
eta_msg))

-if trainer.lr_scheduler is not None:
-    logger.scaler(
-        name="lr",
-        value=trainer.lr_scheduler.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+logger.scaler(
+    name="lr",
+    value=trainer.optimizer.get_lr(),
+    step=trainer.global_step,
+    writer=trainer.vdl_writer)
for key in trainer.output_info:
logger.scaler(
name="train_{}".format(key),
12 changes: 6 additions & 6 deletions plsc/engine/engine.py
@@ -305,7 +305,7 @@ def train(self):
if self.config["Global"]["checkpoint"] is not None:
metric_info = io.load_checkpoint(
self.config["Global"]["checkpoint"], self.model,
-self.optimizer, self.lr_scheduler)
+self.optimizer, self.scaler)
if metric_info is not None:
self.best_metric.update(metric_info)
if "global_step" in metric_info:
@@ -319,8 +319,8 @@
# for one epoch train
self.train_epoch_func(self, epoch_id)

-if self.lr_scheduler is not None and self.lr_decay_unit == 'epoch':
-    self.lr_scheduler.step()
+if self.lr_decay_unit == 'epoch':
+    self.optimizer.lr_step(epoch_id)

if self.use_dali:
self.train_dataloader.reset()
@@ -348,7 +348,7 @@ def train(self):
io.save_checkpoint(
self.model,
self.optimizer,
-self.lr_scheduler,
+self.scaler,
self.best_metric,
self.output_dir,
model_name=self.config["Model"]["name"],
@@ -370,7 +370,7 @@ def train(self):
io.save_checkpoint(
self.model,
self.optimizer,
-self.lr_scheduler,
+self.scaler,
eval_metric_info,
self.output_dir,
model_name=self.config["Model"]["name"],
@@ -381,7 +381,7 @@ def train(self):
io.save_checkpoint(
self.model,
self.optimizer,
-self.lr_scheduler,
+self.scaler,
eval_metric_info,
self.output_dir,
model_name=self.config["Model"]["name"],
6 changes: 3 additions & 3 deletions plsc/engine/recognition/train.py
@@ -72,8 +72,8 @@ def defualt_train_one_epoch(engine, epoch_id):
# clear gradients
engine.optimizer.clear_grad()

-if engine.lr_scheduler is not None and engine.lr_decay_unit == 'step':
-    engine.lr_scheduler.step()
+if engine.lr_decay_unit == 'step':
+    engine.optimizer.lr_step(engine.global_step)

# below code just for logging
# update metric_for_logger
@@ -98,7 +98,7 @@ def defualt_train_one_epoch(engine, epoch_id):
io.save_checkpoint(
engine.model,
engine.optimizer,
-engine.lr_scheduler,
+engine.scaler,
engine.best_metric,
engine.output_dir,
model_name=engine.config["Model"]["name"],
14 changes: 6 additions & 8 deletions plsc/engine/recognition/utils.py
@@ -41,8 +41,7 @@ def update_loss(trainer, loss_dict, batch_size):

def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = "lr: none"
-if trainer.lr_scheduler is not None:
-    lr_msg = "lr: {:.6f}".format(trainer.lr_scheduler.get_lr())
+lr_msg = "lr: {:.6f}".format(trainer.optimizer.get_lr())

metric_msg = ", ".join([
"{}: {:.5f}".format(key, trainer.output_info[key].avg)
@@ -65,12 +64,11 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
eta_msg))

-if trainer.lr_scheduler is not None:
-    logger.scaler(
-        name="lr",
-        value=trainer.lr_scheduler.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+logger.scaler(
+    name="lr",
+    value=trainer.optimizer.get_lr(),
+    step=trainer.global_step,
+    writer=trainer.vdl_writer)
for key in trainer.output_info:
logger.scaler(
name="train_{}".format(key),
28 changes: 6 additions & 22 deletions plsc/models/iresnet.py
@@ -22,6 +22,8 @@
import paddle
import paddle.nn as nn

+from plsc.nn import init

from .layers import PartialFC
from .layers import Model

@@ -30,24 +32,6 @@
__all__ = ["IResNet18", "IResNet34", "IResNet50", "IResNet100", "IResNet200"]


-@paddle.no_grad()
-def constant_(x, val):
-    temp_value = paddle.full(x.shape, val, x.dtype)
-    if temp_value.dtype != x.dtype:
-        temp_value = temp_value.astype(x.dtype)
-    x.copy_(temp_value, False)
-    return x


-@paddle.no_grad()
-def normal_(x, mean=0., std=1.):
-    temp_value = paddle.normal(mean, std, shape=x.shape)
-    if temp_value.dtype != x.dtype:
-        temp_value = temp_value.astype(x.dtype)
-    x.copy_(temp_value, False)
-    return x


def conv3x3(in_planes,
out_planes,
stride=1,
@@ -202,14 +186,14 @@ def __init__(self,

for m in self.sublayers():
if isinstance(m, paddle.nn.Conv2D):
-normal_(m.weight, 0, 0.1)
+init.normal_(m.weight, 0, 0.1)
elif isinstance(m, (paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
-constant_(m.weight, 1)
-constant_(m.bias, 0)
+init.constant_(m.weight, 1)
+init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.sublayers():
if isinstance(m, IBasicBlock):
-constant_(m.bn2.weight, 0)
+init.constant_(m.bn2.weight, 0)

pfc_config.update({
'num_classes': class_num,
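IResNet now reuses the shared helpers in plsc.nn.init instead of carrying private constant_/normal_ copies. The same initialization loop can be applied to any Paddle model; a minimal sketch (the init values mirror the diff, the toy model is illustrative and assumes plsc is importable):

    import paddle.nn as nn
    from plsc.nn import init

    # toy model standing in for IResNet
    model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8))

    for m in model.sublayers():
        if isinstance(m, nn.Conv2D):
            init.normal_(m.weight, 0, 0.1)   # conv weights drawn with std 0.1
        elif isinstance(m, (nn.BatchNorm2D, nn.GroupNorm)):
            init.constant_(m.weight, 1)      # norm scale set to 1
            init.constant_(m.bias, 0)        # norm shift set to 0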
5 changes: 3 additions & 2 deletions plsc/nn/init.py
@@ -34,8 +34,9 @@ def constant_(x, value):

@paddle.no_grad()
def normal_(x, mean=0., std=1.):
-temp_value = paddle.normal(mean, std, shape=x.shape)
-x.set_value(temp_value)
+temp_value = paddle.tensor.random.gaussian(
+    shape=x.shape, mean=mean, std=std, dtype=x.dtype)
+x.copy_(temp_value, False)
return x


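The rewritten normal_ draws the sample directly in the target tensor's dtype and writes it back with an in-place copy_, where the old code drew with paddle.normal in the default dtype and assigned via set_value; presumably this keeps low-precision parameters (e.g. float16 under AMP) in their original dtype. The same pattern in plain Paddle, outside the helper (values illustrative):

    import paddle

    x = paddle.zeros([4, 4])                    # stand-in for a parameter tensor
    sample = paddle.tensor.random.gaussian(
        shape=x.shape, mean=0., std=0.02, dtype=x.dtype)
    x.copy_(sample, False)                      # in-place copy, keeps x's identity and dtype
    print(x.dtype, float(x.std()))              # dtype preserved, std roughly 0.02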
9 changes: 7 additions & 2 deletions plsc/optimizer/adafactor.py
@@ -21,6 +21,7 @@ class Adafactor(Optimizer):
def __init__(self,
params,
lr=None,
+lr_func=None,
eps=1e-30,
eps_scale=1e-3,
clip_threshold=1.0,
@@ -32,7 +33,9 @@ def __init__(self,
use_master_param=True,
no_weight_decay_name=[],
one_dim_param_no_weight_decay=False,
-grad_clip=None):
+grad_clip=None,
+**args):

relative_step = not lr
if warmup_init and not relative_step:
raise ValueError('warmup_init requires relative_step=True')
@@ -41,6 +44,7 @@
0] # make it compat with standard betas arg
defaults = dict(
lr=lr,
+lr_func=lr_func,
eps=eps,
eps_scale=eps_scale,
clip_threshold=clip_threshold,
@@ -53,7 +57,8 @@
use_master_param=use_master_param,
no_weight_decay_name=no_weight_decay_name,
one_dim_param_no_weight_decay=one_dim_param_no_weight_decay,
-grad_clip=grad_clip)
+grad_clip=grad_clip,
+**args)
super(Adafactor, self).__init__(params, defaults)

@staticmethod
5 changes: 4 additions & 1 deletion plsc/optimizer/adamw.py
@@ -27,6 +27,7 @@ class AdamW(Optimizer):
def __init__(self,
params,
lr=0.001,
+lr_func=None,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
@@ -37,12 +38,14 @@ def __init__(self,

defaults = dict(
lr=lr,
+lr_func=lr_func,
betas=betas,
eps=eps,
weight_decay=weight_decay,
use_master_param=use_master_param,
exp_avg_force_fp32=exp_avg_force_fp32,
-grad_clip=grad_clip, )
+grad_clip=grad_clip,
+**args)
super(AdamW, self).__init__(params, defaults)

@staticmethod
5 changes: 4 additions & 1 deletion plsc/optimizer/momentum.py
@@ -27,6 +27,7 @@ class Momentum(Optimizer):
def __init__(self,
params,
lr=0.001,
+lr_func=None,
momentum=0.9,
weight_decay=0.0,
use_master_param=True,
@@ -35,10 +36,12 @@ def __init__(self,

defaults = dict(
lr=lr,
+lr_func=lr_func,
momentum=momentum,
weight_decay=weight_decay,
use_master_param=use_master_param,
-grad_clip=grad_clip, )
+grad_clip=grad_clip,
+**args)
super(Momentum, self).__init__(params, defaults)

@staticmethod
21 changes: 20 additions & 1 deletion plsc/optimizer/optimizer.py
@@ -79,7 +79,10 @@ def add_param_group(self, param_group):
"parameter group didn't specify a value of required optimization parameter "
+ name)
else:
-param_group.setdefault(name, default)
+if name == 'lr':
+    param_group.setdefault(name, deepcopy(default))
+else:
+    param_group.setdefault(name, default)

params = param_group['params']
if len(params) != len(set(params)):
@@ -201,6 +204,22 @@ def clear_grad(self, set_to_zero=True):
if p.grad is not None:
p.clear_gradient(set_to_zero)

+@paddle.no_grad()
+def lr_step(self, step=None):
+    for group in self.param_groups:
+        lr = group['lr']
+        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
+            lr.step()
+        elif 'lr_func' in group and callable(group['lr_func']):
+            group['lr_func'](group, step)

+@paddle.no_grad()
+def get_lr(self, group_id=0):
+    lr = self.param_groups[group_id]['lr']
+    if isinstance(lr, paddle.optimizer.lr.LRScheduler):
+        lr = lr.get_lr()
+    return lr

@paddle.no_grad()
def step(self):
raise NotImplementedError
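These two methods are what the engines above now call: a param group's 'lr' entry may hold either a paddle.optimizer.lr.LRScheduler (deep-copied per group in add_param_group, presumably so that groups do not share one scheduler instance) or a plain float that a per-group 'lr_func' callback adapts, and lr_step dispatches between the two while get_lr reads the current value. A standalone sketch of that dispatch outside plsc (the group layout mirrors the diff; the decay callback is illustrative):

    import paddle

    # one group driven by a scheduler, one by a plain float plus an lr_func callback
    groups = [
        {'lr': paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=10, gamma=0.5)},
        {'lr': 0.1,
         'lr_func': lambda group, step: group.update(lr=0.1 * 0.95 ** step)},
    ]

    def lr_step(param_groups, step=None):
        # mirrors Optimizer.lr_step: scheduler groups advance via .step(),
        # float groups are rewritten by their own lr_func
        for group in param_groups:
            lr = group['lr']
            if isinstance(lr, paddle.optimizer.lr.LRScheduler):
                lr.step()
            elif 'lr_func' in group and callable(group['lr_func']):
                group['lr_func'](group, step)

    def get_lr(param_groups, group_id=0):
        # mirrors Optimizer.get_lr
        lr = param_groups[group_id]['lr']
        return lr.get_lr() if isinstance(lr, paddle.optimizer.lr.LRScheduler) else lr

    lr_step(groups, step=3)
    print(get_lr(groups, 0), get_lr(groups, 1))   # 0.1 (StepDecay decays after 10 steps) and 0.1 * 0.95**3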
26 changes: 11 additions & 15 deletions plsc/utils/io.py
@@ -51,7 +51,7 @@ def _remove_if_exist(path):
pass


-def load_checkpoint(checkpoint_path, net, optimizer, lr_scheduler):
+def load_checkpoint(checkpoint_path, net, optimizer, loss_scaler):
"""
load model from checkpoint
"""
@@ -71,6 +71,8 @@ def load_checkpoint(checkpoint_path, net, optimizer, lr_scheduler):
"Optimizer checkpoint path {} does not exists.".format(opt_path)
opt_dict = paddle.load(opt_path)

+scaler_dict = opt_dict.pop('scaler_state', {})

dist_opt_path = checkpoint_path + "_rank{}.pdopt".format(rank)
if os.path.exists(dist_opt_path):
dist_opt_dict = paddle.load(dist_opt_path)
@@ -83,13 +85,9 @@ def load_checkpoint(checkpoint_path, net, optimizer, lr_scheduler):

optimizer.set_state_dict(opt_dict)

-# load lr scheduler
-if lr_scheduler is not None:
-    lr_path = checkpoint_path + '.pdlr'
-    assert os.path.exists(lr_path), \
-        "Learning rate scheduler checkpoint path {} does not exists.".format(lr_path)
-    lr_dict = paddle.load(lr_path)
-    lr_scheduler.set_state_dict(lr_dict)
+# load loss scaler
+if len(scaler_dict) > 0 and loss_scaler is not None:
+    loss_scaler.load_state_dict(scaler_dict)

# load metric state
metric_path = checkpoint_path + '.pdstates'
@@ -116,7 +114,7 @@ def _optimizer_state_dict_split(state_dict):

def save_checkpoint(net,
optimizer,
-lr_scheduler,
+loss_scaler,
metric_info,
model_path,
model_name="",
@@ -145,16 +143,14 @@ def save_checkpoint(net,
opt_state_dict)

if local_rank == 0:
+if loss_scaler is not None:
+    opt_state_dict['scaler_state'] = loss_scaler.state_dict()
paddle.save(opt_state_dict, model_prefix + ".pdopt")
+paddle.save(metric_info, model_prefix + ".pdstates")
if len(dist_opt_state_dict['state']) > 0:
paddle.save(dist_opt_state_dict,
model_prefix + "_rank{}.pdopt".format(rank))

-if local_rank == 0:
-    if lr_scheduler is not None:
-        paddle.save(lr_scheduler.state_dict(), model_prefix + ".pdlr")
-    paddle.save(metric_info, model_prefix + ".pdstates")

logger.info("Already save {} model in {}".format(prefix, model_dir))

keep_prefixs = ['best', 'latest']
@@ -181,7 +177,7 @@ def save_checkpoint(net,
to_remove = timestamps
for timestamp in to_remove:
model_prefix = timestamp_to_path[timestamp]
-for ext in ['.pdparams', '.pdopt', '.pdlr', '.pdstates']:
+for ext in ['.pdparams', '.pdopt', '.pdstates']:
path = model_prefix + ext
_remove_if_exist(path)

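With this change the loss-scaler state rides along inside the .pdopt file instead of a separate .pdlr file, so resuming restores the AMP scaler together with the optimizer and there is no scheduler file left to clean up. A minimal round trip of that piggy-backing (paddle.amp.GradScaler stands in for whatever scaler the engine actually holds; the path and the dummy optimizer state are illustrative, only the 'scaler_state' key comes from the diff):

    import paddle

    scaler = paddle.amp.GradScaler(init_loss_scaling=2.**16)

    # save: embed the scaler state in the same dict as the optimizer state
    opt_state_dict = {'state': {}, 'param_groups': []}   # stand-in for real optimizer state
    opt_state_dict['scaler_state'] = scaler.state_dict()
    paddle.save(opt_state_dict, '/tmp/example.pdopt')

    # load: pop the scaler state back out before the rest goes to optimizer.set_state_dict()
    opt_dict = paddle.load('/tmp/example.pdopt')
    scaler_dict = opt_dict.pop('scaler_state', {})
    if len(scaler_dict) > 0:
        scaler.load_state_dict(scaler_dict)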
