Fix lars_weight_decay: change it from a scalar to a vector.
JamesLim-sy committed Oct 20, 2021
1 parent 82dd12a commit c5d06e0
Showing 2 changed files with 31 additions and 22 deletions.
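
In short, the lars_weight_decay attribute now carries one value per parameter instead of a single shared scalar, and the CUDA kernel checks that its length matches the number of Grad inputs. A minimal NumPy sketch of the shape change (the parameter count and decay value below are illustrative, not taken from the diff):

import numpy as np

# Before this commit the test attached one shared scalar, roughly:
#     attrs['lars_weight_decay'] = weight_decay.tolist()
# After it, there is one decay entry per parameter, so the list length
# can be validated against the number of Grad inputs.
num_params = 3                                    # hypothetical parameter count
weight_decays = 0.05 * np.ones(num_params, dtype=np.float32)
lars_weight_decay_attr = [float(d) for d in weight_decays]
assert len(lars_weight_decay_attr) == num_params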
4 changes: 4 additions & 0 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -375,6 +375,10 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");

int op_num = grad.size();
PADDLE_ENFORCE_EQ(weight_decay_arr.size(), op_num,
                  platform::errors::InvalidArgument(
                      "Input(lars_weight_decay) and Input(Grad) of lars "
                      "optimizer must have the same size."));
#if CUDA_VERSION >= 11000
if (op_num > 1) {
LarsParamWarpper<T, MT> lars_warpper;
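
For context, the per-parameter decay enters the LARS rule roughly as in the original LARS formulation (this is the textbook update, not a transcription of the kernel above; \lambda_i is the i-th entry of lars_weight_decay, \eta the global learning rate, \mu the momentum):

\eta_i^{\mathrm{local}} = \eta \cdot \frac{\lVert w_i \rVert}{\lVert \nabla w_i \rVert + \lambda_i \lVert w_i \rVert}
v_i \leftarrow \mu\, v_i + \eta_i^{\mathrm{local}} \left( \nabla w_i + \lambda_i w_i \right)
w_i \leftarrow w_i - v_i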
49 changes: 27 additions & 22 deletions python/paddle/fluid/tests/unittests/test_merged_lars_optimizer.py
@@ -19,17 +19,17 @@
from collections import OrderedDict


def run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rates,
place,
multi_precision,
weight_decay,
mu=0.9,
rescale_grad=0.01,
use_merged=False):
def run_lars_momentum_op(params,
grads,
velocitys,
master_params,
learning_rates,
place,
multi_precision,
weight_decays,
mu=0.9,
rescale_grad=0.01,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
@@ -44,9 +44,7 @@ def run_momentum_op(params,
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'lars_weight_decay': weight_decay.tolist()
}

param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
@@ -111,19 +109,25 @@ def run_momentum_op(params,
'Velocity': v,
'LearningRate': lr,
}
attrs['lars_weight_decay'] = [float(weight_decays[i])]
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
lars_weight_decay = []
for decay in weight_decays:
lars_weight_decay.append(float(decay))

inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_vars,
}
attrs['lars_weight_decay'] = lars_weight_decay
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
@@ -146,9 +150,11 @@ def setUp(self):
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]

def gen_rand_lr(self, shapes, dtype):
lr = np.random.random(1).astype(dtype)
return [np.ones(1).astype(dtype) * lr for s in range(len(shapes))]
def gen_lr_and_decay(self, shapes, dtype):
data = np.random.random(1).astype(dtype)
lr_rates = [np.ones(1).astype(dtype) * data for s in range(len(shapes))]
weight_decays = data * np.ones(len(shapes), dtype=np.float32)
return lr_rates, weight_decays

def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
@@ -158,30 +164,29 @@ def prepare_data(self, shapes, multi_precision, seed, place):
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
weight_decay = self.gen_rand_data([[1]], mp_dtype)[0]
learning_rates = self.gen_rand_lr(shapes, mp_dtype)
learning_rates, weight_decays = self.gen_lr_and_decay(shapes, mp_dtype)
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rates, weight_decay
return params, grads, velocitys, master_params, learning_rates, weight_decays

def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rates, weight_decay = self.prepare_data(
params, grads, velocitys, master_params, learning_rates, weight_decays = self.prepare_data(
self.shapes, multi_precision, self.seed, place)

def run_op(merge_option):
# CPU Momentum Op does not support rescale_grad
rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
return run_momentum_op(
return run_lars_momentum_op(
params,
grads,
velocitys,
master_params,
learning_rates,
place,
multi_precision,
weight_decay,
weight_decays,
rescale_grad=rescale_grad,
use_merged=merge_option)

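For reference, a self-contained NumPy restatement of what the updated gen_lr_and_decay helper produces (independent of Paddle; the function name, seed, and shapes below are illustrative):

import numpy as np

def gen_lr_and_decay_sketch(shapes, dtype=np.float32, seed=10):
    # One single-element learning-rate tensor per parameter (all equal) and
    # one weight-decay scalar per parameter, mirroring gen_lr_and_decay.
    np.random.seed(seed)
    data = np.random.random(1).astype(dtype)
    lr_rates = [np.ones(1, dtype=dtype) * data for _ in shapes]
    weight_decays = data * np.ones(len(shapes), dtype=np.float32)
    return lr_rates, weight_decays

lrs, decays = gen_lr_and_decay_sketch([[2, 3], [4], [5, 6]])
assert len(lrs) == len(decays) == 3  # length matches the new size check in the CUDA kernel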
