
add lamb optimizer and unittest #28772

Merged
merged 10 commits into from Nov 24, 2020

Conversation

@bjjwwang (Contributor) commented Nov 20, 2020

PR types

Function optimization

PR changes

APIs

Describe

Migrate the LAMB optimizer to paddle 2.0 and add a LAMB unit test.

The original API, paddle.fluid.optimizer.Lamb:

paddle.fluid.optimizer.Lamb(learning_rate=0.001, lamb_weight_decay=0.01, beta1=0.9, beta2=0.999, epsilon=1e-6, parameter_list=None, regularization=None, grad_clip=None, exclude_from_weight_decay_fn=None, name=None)

has been migrated to:

paddle.optimizer.Lamb(learning_rate=0.001, lamb_weight_decay=0.01, beta1=0.9, beta2=0.999, epsilon=1e-6, parameters=None, grad_clip=None, exclude_from_weight_decay_fn=None, name=None)

Reason for the migration: this is part of the 2.0 optimizer migration project; a new Optimizer base class and LRScheduler have been implemented, so the API needs to be migrated.
Specific changes: parameter renamings such as parameter_list -> parameters.
Usage: usage is unchanged, but since Lamb carries its own lamb_weight_decay parameter, the base class's weight_decay (regularization) is not allowed; users only need to specify lamb_weight_decay.

Does this cause code redundancy? I personally think paddle.optimizer.Lamb will eventually replace the fluid API, but for now there are two implementations on the Python side sharing one C++ implementation.
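
A minimal dygraph usage sketch of the migrated API (illustrative only; the layer, values, and the exclude_fn predicate below are made up for the example and are not part of the PR):

import paddle

inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10)
loss = paddle.mean(linear(inp))

def exclude_fn(param):
    # e.g. keep LAMB weight decay off the biases
    return param.name.endswith('.b_0')

lamb = paddle.optimizer.Lamb(learning_rate=0.002,
                             lamb_weight_decay=0.01,
                             parameters=linear.parameters(),
                             exclude_from_weight_decay_fn=exclude_fn)
loss.backward()
lamb.step()
lamb.clear_grad()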


@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@bjjwwang changed the title from "[WIP] add lamb optimizer and unittest" to "add lamb optimizer and unittest" on Nov 22, 2020
chalsliu previously approved these changes Nov 23, 2020
def exclude_fn(param):
    return param.name.endswith('.b_0')

optimizer = fluid.optimizer.Lamb(learning_rate=0.002,
Contributor

The fluid api is still being used here. Also, please explain in the PR description how this api differs from the fluid api, the main changes in this upgrade, and so on.

Contributor Author

OK, thanks. The changes have been made.

import paddle

paddle.enable_static()
data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
Contributor

Please write the example code in the 2.0 style: paddle.static.data, paddle.nn.functional.fc.
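
A sketch of what a 2.0-style static-graph example could look like (assuming paddle.static.nn.fc as the replacement for the fc layer; the exact example in the PR may differ):

import paddle

paddle.enable_static()
data = paddle.static.data(name='x', shape=[-1, 5], dtype='float32')
hidden = paddle.static.nn.fc(data, size=10)
cost = paddle.mean(hidden)

def exclude_fn(param):
    return param.name.endswith('.b_0')

optimizer = paddle.optimizer.Lamb(learning_rate=0.002,
                                  exclude_from_weight_decay_fn=exclude_fn)
optimizer.minimize(cost)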

optimizer = paddle.optimizer.Lamb(learning_rate=0.002,
                                  exclude_from_weight_decay_fn=exclude_fn)
optimizer.minimize(cost)
Contributor

It would be best to also provide dynamic graph (dygraph) example code.

Contributor Author

The changes have been made, thanks.

back = out.backward()
lamb.step()
lamb.clear_grad()
"""
Contributor

The PR description still does not make clear what the concrete upgrade points are compared with the fluid api: whether the parameters changed, whether the usage changed, whether there are any caveats, and so on. If it is only a change of module path, why add new code instead of using an alias? Would that cause code redundancy?
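
(For clarity, "using an alias" would mean re-exporting the existing class under the new path instead of re-implementing it, e.g. something like the following line in the new module; a hypothetical illustration, not what the PR does:)

# hypothetical: expose the fluid implementation under the new path
from paddle.fluid.optimizer import Lamb  # noqa: F401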

Contributor Author

OK, thanks. The changes have been made.

some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
exclude_from_weight_decay_fn (function|None): Exclude a parameter from weight
Contributor

The usage of this parameter is a bit tricky, and other Optimizers probably have the same need, right?
Wouldn't it be better to control this through a ParamAttr property such as regularizer=False, similar to how ParamAttr's need_clip is used?
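
For reference, a rough sketch of the kind of per-parameter control being alluded to, using the regularizer slot that ParamAttr already has (need_clip is the analogous per-parameter switch for gradient clipping). Whether this would interact with lamb_weight_decay is exactly the open question, and a boolean-style switch as suggested above would be new API:

import paddle

linear = paddle.nn.Linear(
    in_features=10, out_features=10,
    # per-parameter regularization: decay the weight, effectively disable it for the bias
    weight_attr=paddle.ParamAttr(regularizer=paddle.regularizer.L2Decay(0.01)),
    bias_attr=paddle.ParamAttr(regularizer=paddle.regularizer.L2Decay(0.0)))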

Contributor Author

OK. I verified it with the following program and the change has been made:

import paddle
import numpy as np

inp = paddle.uniform(min=-0.1, max=0.1, shape=[10, 10], dtype='float32')
# need_clip marks whether a parameter takes part in gradient clipping
linear = paddle.nn.Linear(in_features=10, out_features=10,
                          weight_attr=paddle.ParamAttr(need_clip=True),
                          bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(inp)
loss = paddle.mean(out)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.85], dtype="float32")
lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
loss.backward()
lamb.step()
lamb.clear_grad()

Note that this example program is not provided in the doc; it is just my own test.

XiaoguangHu01 previously approved these changes Nov 24, 2020

@XiaoguangHu01 (Contributor) left a comment

LGTM

Args:
learning_rate (float|Variable, optional): the learning rate used to update parameters. \
Can be a float value or a Variable with data type float32. Default 0.001.
lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Note that weight_decay should be None.
Contributor

Can the lamb_ prefix be dropped from the lamb_weight_decay parameter name?

Contributor Author

I'd rather keep it for now; it is not the same as the weight decay of the Optimizer base class, so the original API name is retained.

Contributor

Keeping it is fine, but then it should be made clear what the difference actually is; my understanding is that it plays the same role as weight_decay in AdamW.
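
For reference, the LAMB update from the original paper (bias correction omitted), where \lambda is the lamb_weight_decay rate: as in AdamW, \lambda w_{t-1} is added to the adaptive step, but in LAMB the combined term is additionally rescaled by the layer-wise trust ratio:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
r_t = \frac{m_t}{\sqrt{v_t} + \epsilon}
w_t = w_{t-1} - \eta_t \, \frac{\| w_{t-1} \|}{\| r_t + \lambda w_{t-1} \|} \, (r_t + \lambda w_{t-1})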


if param_and_grad[0].need_clip:
    weight_decay = 0.0
else:
Contributor

need_clip is the flag for gradient_clip; it must be treated separately from weight_decay.

@guoshengCS (Contributor) commented Nov 24, 2020

gradient_clip and weight_decay are two different things; it cannot be done this way here.
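
A rough sketch of the direction the reviewers are pointing at (hypothetical; the internal names _exclude_from_weight_decay_fn and _lamb_weight_decay are assumed, not quoted from the PR): key the exclusion off exclude_from_weight_decay_fn rather than the gradient-clipping flag.

param = param_and_grad[0]
if self._exclude_from_weight_decay_fn is not None \
        and self._exclude_from_weight_decay_fn(param):
    weight_decay = 0.0
else:
    weight_decay = self._lamb_weight_decay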

:ref:`api_guide_Name` . Usually the name does not need to be set and is None by default.
Examples:
.. code-block:: python
import paddle
Contributor

A blank line needs to be added here (between .. code-block:: python and import paddle), otherwise the doc preview will break.
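
That is, the fragment would presumably become (indentation as in standard rST):

Examples:
    .. code-block:: python

        import paddle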

Examples:
.. code-block:: python
import paddle
import numpy as np
Contributor

This line can be deleted.

learning rate, :math:`\\lambda` the LAMB weight decay rate.

Args:
learning_rate (float|Variable, optional): the learning rate used to update parameters. \
Contributor

Variable -> Tensor; the same applies to the other places.

@TCChenlong (Contributor) left a comment

LGTM
TODO:
Update the English docs and add the Chinese docs.

@XiaoguangHu01 (Contributor) left a comment

LGTM
