
【Hackathon 7th PPSCI No.12】Support amsgrad in the Adam and AdamW optimizers #68079

Merged
luotao1 merged 52 commits into PaddlePaddle:develop on Dec 4, 2024

Conversation

megemini (Contributor) commented on Sep 8, 2024

PR Category

User Experience

PR Types

New features

Description

【Hackathon 7th No.12】Support amsgrad in the Adam and AdamW optimizers

Related:

Local comparison against pytorch shows that the two produce identical results:

Comparison code:
import numpy as np

import torch
import paddle


def func(t, x):
    # Synthetic objective: a steep positive slope once every 101 steps,
    # otherwise a gentle negative slope.
    if t % 101 == 1:
        return 1010 * x
    else:
        return -10 * x


np.random.seed(2024)
data = np.array(0).astype("float64")
epoch = 500
lr = 0.1

# Compare both optimizers, with and without amsgrad, on CPU and GPU.
for amsgrad in [True, False]:
    for opt_name, opt_torch, opt_paddle in [
        ["Adam", torch.optim.Adam, paddle.optimizer.Adam],
        ["AdamW", torch.optim.AdamW, paddle.optimizer.AdamW],
    ]:
        for torch_device, paddle_device in [["cpu", "cpu"], ["cuda", "gpu"]]:
            print(f"------ optimizer is : {opt_name} ; compare : {paddle_device}------")
            print("------ pytorch ------")
            x = torch.tensor(data, device=torch.device(torch_device))
            x.requires_grad = True

            optimizer = opt_torch([x], lr=lr, amsgrad=amsgrad)
            for i in range(epoch):
                y = func(i, x)
                optimizer.zero_grad()
                y.backward()
                optimizer.step()

            if torch_device == "cuda":
                x_torch = x.cpu().detach().numpy()
                y_torch = y.cpu().detach().numpy()
            else:
                x_torch = x.detach().numpy()
                y_torch = y.detach().numpy()

            print(f"------ paddle ------")
            paddle.set_device(paddle_device)
            x = paddle.to_tensor(data)
            x.stop_gradient = False

            optimizer = opt_paddle(parameters=[x], learning_rate=lr, amsgrad=amsgrad)
            for i in range(epoch):
                y = func(i, x)
                optimizer.clear_grad()
                y.backward()
                optimizer.step()

            x_paddle = x.numpy()
            y_paddle = y.numpy()

            np.testing.assert_allclose(x_torch, x_paddle, atol=1e-06, rtol=1e-06)
            print(x_torch, x_paddle)
            print(y_torch, y_paddle)
            print(f"------- compare finish ---------")

Output:

------ optimizer is : Adam ; compare : cpu------
------ pytorch ------
------ paddle ------
0.382819332566745 0.3828193325667452
-3.7319234136114865 -3.7319234136114887
------- compare finish ---------
------ optimizer is : Adam ; compare : gpu------
------ pytorch ------
------ paddle ------
0.3828193325667449 0.38281933256674533
-3.7319234136114856 -3.73192341361149
------- compare finish ---------
------ optimizer is : AdamW ; compare : cpu------
------ pytorch ------
------ paddle ------
0.38940724227589385 0.389407242265435
-3.801604114817793 -3.8016041146280424
------- compare finish ---------
------ optimizer is : AdamW ; compare : gpu------
------ pytorch ------
------ paddle ------
0.38940724227589385 0.3894072422654346
-3.801604114817793 -3.801604114628038
------- compare finish ---------
------ optimizer is : Adam ; compare : cpu------
------ pytorch ------
------ paddle ------
0.47233193956960806 0.47233193956960845
-4.62253146676283 -4.622531466762833
------- compare finish ---------
------ optimizer is : Adam ; compare : gpu------
------ pytorch ------
------ paddle ------
0.472331939569608 0.4723319395696082
-4.62253146676283 -4.6225314667628306
------- compare finish ---------
------ optimizer is : AdamW ; compare : cpu------
------ pytorch ------
------ paddle ------
0.462192080569021 0.46219208087997216
-4.525658535292251 -4.525658538303618
------- compare finish ---------
------ optimizer is : AdamW ; compare : gpu------
------ pytorch ------
------ paddle ------
0.46219208056902106 0.46219208087997266
-4.525658535292251 -4.525658538303623
------- compare finish ---------

Update 20240908

  • Completed the following tests locally:

    • test_adam_op.py
    • test_adamw_op.py
    • test_merged_adam_op.py
    • test_fused_adam_op.py

  • The distributed test suites still need to be verified in the CI environment.

  • The other test suites still need to be verified in the CI environment.

In addition, since the underlying xpu interface does not yet support the amsgrad variant, only the relevant xpu input/output parameter lists are changed here.
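
For reference, enabling the new variant only requires the extra amsgrad keyword argument on either optimizer. A minimal usage sketch (the toy linear model and shapes below are illustrative, not part of this PR):

import paddle

model = paddle.nn.Linear(10, 1)
opt = paddle.optimizer.AdamW(
    parameters=model.parameters(),
    learning_rate=0.1,
    amsgrad=True,  # new flag added by this PR; defaults to False
)

x = paddle.randn([4, 10])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()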

paddle-bot commented on Sep 8, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot added the contributor (External developers) label on Sep 8, 2024
@PaddlePaddle locked and limited conversation to collaborators on Sep 9, 2024
@PaddlePaddle unlocked this conversation on Sep 9, 2024
HydrogenSulfate (Contributor) left a comment

Does the added amsgrad affect the original execution logic and memory footprint? As written, the PR allocates an extra mom2_max buffer and introduces some extra variables regardless of whether amsgrad is enabled, compared to the original code without amsgrad.


inline HOSTDEVICE void operator()(size_t i) const {
  // Merge all memory access together.
  T g = grad_[i];
  T mom1 = moment1_[i];
  T mom2 = moment2_[i];
  T mom2_max = moment2_max_[i];
Contributor: Does this value have to be recorded here?

Comment on lines 236 to 248
T mom2_max_;
if (amsgrad_) {
  mom2_max_ = std::max(mom2, mom2_max);
  p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow)));
} else {
  mom2_max_ = mom2_max;
  p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
}

// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
moment2_max_out_[i] = mom2_max_;
Contributor: Likewise, if amsgrad is not enabled, please do not introduce any extra variables or related computation; keeping the original logic unchanged is enough.
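
For readers following the thread, here is a pure-Python sketch of the per-element update performed by the excerpt above. The moment recurrences and the bias-corrected step size are not shown in the excerpt and are written out here from the standard Adam form as assumptions; only the denominator and the mom2_max handling mirror the quoted code.

import math

def adam_element_update(p, g, mom1, mom2, mom2_max, lr, beta1, beta2,
                        epsilon, beta1_pow, beta2_pow, amsgrad=False):
    # Standard Adam moment recurrences (assumed; not part of the excerpt).
    mom1 = beta1 * mom1 + (1 - beta1) * g
    mom2 = beta2 * mom2 + (1 - beta2) * g * g

    # Bias-corrected step size (assumed form for this sketch).
    lr_t = lr * math.sqrt(1 - beta2_pow) / (1 - beta1_pow)

    if amsgrad:
        # AMSGrad: keep the running maximum of the second moment and use it in
        # the denominator, so the effective step never grows when mom2 shrinks.
        mom2_max = max(mom2, mom2_max)
        denom = math.sqrt(mom2_max) + epsilon * math.sqrt(1 - beta2_pow)
    else:
        # Plain Adam: mom2_max is passed through untouched, which is exactly
        # the redundant state the review comments above are concerned about.
        denom = math.sqrt(mom2) + epsilon * math.sqrt(1 - beta2_pow)

    p -= lr_t * mom1 / denom
    return p, mom1, mom2, mom2_max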

Comment on lines 326 to 327
Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment2_max_out{
    moment2_max_out_, static_cast<Eigen::Index>(numel)};
Contributor: Same as above: if amsgrad is not enabled, will there be redundant computation and storage for mom2_max?


inline HOSTDEVICE void adam_update(size_t i, T g) const {
  // The following code is the same as dense
  T mom1 = moment1_[i];
  T mom2 = moment2_[i];
  T mom2_max = moment2_max_[i];
Contributor: Same as above.

@@ -14,6 +14,7 @@

#pragma once

#include <stdio.h>
Contributor: What is this header for? Is there any code that depends on it?

megemini (Author): I forgot to remove it after debugging, sorry ~

@@ -117,6 +117,7 @@ class Adam(Optimizer):
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . Default is false.
amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false.

@@ -104,6 +104,7 @@ class AdamW(Optimizer):
different semantics with the original Adam algorithm and may lead to different result.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false.
Contributor: Same as above.

megemini (Contributor, Author) commented on Sep 9, 2024

> Does the added amsgrad affect the original execution logic and memory footprint? As written, the PR allocates an extra mom2_max buffer and introduces some extra variables regardless of whether amsgrad is enabled, compared to the original code without amsgrad.

I did consider this earlier; mainly because amsgrad currently touches so many places, I wanted to defer the optimization work for a bit ~

Let me try changing it now ~

HydrogenSulfate (Contributor) commented on Sep 9, 2024

> I did consider this earlier; mainly because amsgrad currently touches so many places, I wanted to defer the optimization work for a bit ~
>
> Let me try changing it now ~

  1. The impact here is significant. Optimizers generally track parameter state element-wise, so every statistic an optimizer keeps has as many entries as the model has parameters, and momentum-based optimizers like Adam(W) keep even more. The top three consumers of GPU memory during training are intermediate activations, optimizer state, and model parameters; without this optimization, CV/NLP models that previously fit on a 16 GB card could easily OOM, to say nothing of large models with billions of parameters.

  2. The computation logic itself looks largely fine, and the current unoptimized version is useful for quickly verifying correctness, but the final version must include this basic yet necessary optimization.

HydrogenSulfate (Contributor)

Also, once the changes are done, you could run a comparison with ResNet50 (or another model) on fake data to confirm that GPU memory usage is unchanged when amsgrad is off, and that the increase when it is on roughly matches the parameter count. A rough sketch of such a check follows.
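
The sketch below is illustrative only, not the exact script used for this PR: the model, fake-data shapes, and the use of paddle.device.cuda.memory_allocated are choices made here for the example, and each configuration is best run in a separate process to avoid allocator carry-over between measurements.

import numpy as np
import paddle
from paddle.vision.models import resnet50


def train_and_report(amsgrad: bool, steps: int = 3) -> None:
    """Run a few fake-data steps and report the currently allocated GPU bytes."""
    paddle.set_device("gpu")
    model = resnet50()
    opt = paddle.optimizer.Adam(parameters=model.parameters(), amsgrad=amsgrad)

    x = paddle.randn([8, 3, 224, 224])       # fake images
    label = paddle.randint(0, 1000, [8, 1])  # fake labels

    for _ in range(steps):
        loss = paddle.nn.functional.cross_entropy(model(x), label)
        loss.backward()
        opt.step()
        opt.clear_grad()

    n_params = sum(int(np.prod(p.shape)) for p in model.parameters())
    # With amsgrad=True, the extra fp32 mom2_max state should add roughly
    # 4 * n_params bytes compared with the amsgrad=False run.
    print("params:", n_params,
          "allocated bytes:", paddle.device.cuda.memory_allocated())


# Run once with amsgrad=False and once with amsgrad=True and compare the numbers.
train_and_report(amsgrad=False)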

@megemini megemini dismissed stale reviews from phlrain, heavengate, and zyfncg via af27337 November 8, 2024 05:39
HydrogenSulfate (Contributor)

@megemini Hello! We have tested the latest version of this PR internally and it looks good now. Could you please resolve the merge conflicts?

megemini (Contributor, Author)

> @megemini Hello! We have tested the latest version of this PR internally and it looks good now. Could you please resolve the merge conflicts?

Thank you very much!!! I heard the process was quite an ordeal 😂😂😂 Thanks ~~~

The conflicts are resolved ~ As for the PR-CI-NPU-910B-Paddle failure, it looks like the npu side does not handle paddle::optional correctly and passes a null pointer through? ~

@phlrain phlrain self-requested a review December 2, 2024 07:34
HydrogenSulfate (Contributor) left a comment

LGTM

HydrogenSulfate (Contributor)

@megemini Could you merge the develop branch once more? The windows and Hygon-DCU-Test jobs failed, which does not appear to be caused by this PR.

luotao1 (Contributor) commented on Dec 3, 2024

> The windows and Hygon-DCU-Test jobs failed, which does not appear to be caused by this PR.

@HydrogenSulfate The DCU job has been exempted; it is a randomly failing unit test. I reran the windows job; according to feedback from other developers, the windows pipeline has had network problems recently.

XiaoguangHu01 (Contributor) left a comment

LGTM

@luotao1 luotao1 merged commit d774b83 into PaddlePaddle:develop Dec 4, 2024
28 checks passed
HydrogenSulfate (Contributor) left a comment

LGTM

@luotao1 luotao1 changed the title from 【Hackathon 7th PPSCI No.12】Support amsgrad in the Adam and AdamW optimizers -part to 【Hackathon 7th PPSCI No.12】Support amsgrad in the Adam and AdamW optimizers on Dec 4, 2024