Make lars cpp code flexible #36450

Conversation

@JamesLim-sy (Contributor) commented Oct 14, 2021

PR types

Function optimization

PR changes

OPs

Describe

Features:

  • Following the debugging of the L2-norm computation bug in PR36428, the code path for lower CUDA versions is fixed in the same way;
  • Following the ParamMerge strategy adopted in PR36380, LarsParam is adjusted accordingly;
  • Since master_param and master_param_out, velocity_param and velocity_param_out, param and param_out are three pairs of mutually inplace tensors, only master_param_out, velocity_param_out, and param_out are used as the computation tensors;
  • Inside the C++ code, the merged LarsMomentum ops are automatically packed and grouped for computation (see the sketch after this list);
  • Since #36409 already distinguishes the weight_decay == 0 case, all LarsMomentum ops on the Merged_larsMomentum computation path are adjusted to share the same weight_decay;
  • Add a unit test for Merged_larsMomentum.
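
For illustration, a minimal sketch of the packing idea above (the field names loosely follow the diff and the kOpNum capacity is an assumption, not the PR's actual code): one wrapper carries per-op device pointers for up to kOpNum merged LarsMomentum ops, so a single kernel launch can update many parameters at once.

template <typename T, typename MT, int kOpNum>
struct MergedLarsParamSketch {
  int numel_arr[kOpNum];                  // element count of each parameter
  const T* __restrict__ g_arr[kOpNum];    // gradient of each op
  T* __restrict__ p_arr[kOpNum];          // param_out (inplace with param)
  MT* __restrict__ v_arr[kOpNum];         // velocity_out (inplace with velocity)
  const MT* __restrict__ lr_arr[kOpNum];  // learning rate of each op
  MT weight_decay;                        // shared by all merged ops (see #36409)
};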

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@JamesLim-sy changed the title from "first commit" to "Make lars cpp code more flexible" (Oct 14, 2021)
T* __restrict__ p_arr[kOpNum];
MT* __restrict__ v_arr[kOpNum];
MT weight_decay_arr[kOpNum];
};
Contributor Author:

Here the data type T is used directly to decide which LarsParamWarpper type is generated, which effectively assumes that fp16 must always be used together with master_param. This change does not cover pure fp16 computation that does not rely on master_param.

MT grad_norm = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
MT param_norm = Sqrt(s_buffer[0]);
MT grad_norm = Sqrt(rescale_pow * s_buffer[1]);
Contributor Author:

For lower CUDA versions, fix how the result of the L2-norm computation is retrieved.
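
For illustration, a minimal sketch (not the PR's actual kernel; kernel name, block size, and buffer layout are assumptions) of the two-stage pattern involved: each block accumulates a partial sum of squares into a global per-block buffer, and a second pass over that buffer produces the final square root. On CUDA versions without grid-wide synchronization, that second pass is where the result has to be retrieved, which is the step being corrected here.

template <typename MT>
__global__ void PartialSquareSum(const MT* __restrict__ x,
                                 MT* __restrict__ block_buffer, int64_t numel) {
  // Assumes the kernel is launched with blockDim.x == 256 (a power of two).
  __shared__ MT smem[256];
  MT local = static_cast<MT>(0);
  // Grid-stride accumulation of x[i] * x[i] into a per-thread partial sum.
  for (int64_t i = threadIdx.x + static_cast<int64_t>(blockDim.x) * blockIdx.x;
       i < numel; i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    local += x[i] * x[i];
  }
  smem[threadIdx.x] = local;
  __syncthreads();
  // Tree reduction within the block.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) smem[threadIdx.x] += smem[threadIdx.x + stride];
    __syncthreads();
  }
  // One partial sum per block; a second kernel (or the next stage of the lars
  // kernel) reduces block_buffer and takes the square root to get the L2 norm.
  if (threadIdx.x == 0) block_buffer[blockIdx.x] = smem[0];
}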

@JamesLim-sy force-pushed the Make_lars_cpp_code_more_flexible branch from d32884c to 3da1b1f (October 14, 2021 15:06)
@JamesLim-sy force-pushed the Make_lars_cpp_code_more_flexible branch from 9ba1d75 to dc103de (October 16, 2021 08:43)
lars_warpper.g_arr[i] = grad[start_idx + i]->data<T>();
lars_warpper.p_arr[i] = param_out[start_idx + i]->data<T>();
lars_warpper.v_arr[i] = velocity_out[start_idx + i]->data<MT>();
lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
Contributor Author:

If the learning_rate can be shared, this part can be optimized away, reducing the number of global memory accesses in the subsequent CUDA kernel.

@JamesLim-sy force-pushed the Make_lars_cpp_code_more_flexible branch from 7bb3c4b to 7be6434 (October 17, 2021 14:30)
@JamesLim-sy changed the title from "Make lars cpp code more flexible" to "Make lars cpp code flexible" (Oct 17, 2021)
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
MT lars_weight_decay = weight_decay_arr[0];
Contributor Author (@JamesLim-sy, Oct 17, 2021):

Since the optimizer.py file already explicitly separates the weight_decay == 0 special case out of lars_momentum, the Merged LarsMomentum optimizer computation branch is adjusted so that all ops share the same weight_decay value.

# create the momentum optimize op
momentum_op = block.append_op(
    type=self.type if _lars_weight_decay != 0.0 else 'momentum',
    inputs=inputs,
    outputs=outputs,
    attrs=attrs,
    stop_gradient=True)
return momentum_op

This handling avoids the problem that, during merged_lars training, every op in the merged group fetches its own data from global memory.

Contributor:

Apart from the ResNet50 scenario, can't there be cases where the weight_decay values are non-zero and differ from each other?

Contributor Author (@JamesLim-sy, Oct 18, 2021):

Whether an op goes through the LARS computation depends on whether it is in the 'self._exclude_from_weight_decay' list; the resnet50 model passes in exclude_from_weight_decay=['bn', 'batch_norm', '.b_0'].

@@ -1961,6 +1961,7 @@ def __init__(self,
exclude_from_weight_decay=None,
epsilon=0,
multi_precision=False,
merge_option=False,
Contributor Author:

This was left in by mistake; it will be deleted in the next commit.

MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad,
MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
Contributor:

Using blockDim.x here feels a bit safer than using LARS_BLOCK_SIZE.

Contributor Author:

Will change as suggested.
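
A minimal sketch of the suggested change, under the assumption that nothing else in the loop depends on the compile-time constant: deriving the stride from the launch configuration keeps the loop correct even if the kernel is ever launched with a block size other than LARS_BLOCK_SIZE.

template <typename T>
__global__ void GridStrideLoopSketch(const T* __restrict__ in,
                                     T* __restrict__ out, int64_t numel) {
  int64_t tid = threadIdx.x + static_cast<int64_t>(blockDim.x) * blockIdx.x;
  // Stride comes from blockDim.x, not from the LARS_BLOCK_SIZE macro.
  int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = tid; i < numel; i += grid_stride) {
    out[i] = in[i];
  }
}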

const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr,
MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad,
Contributor:

Could p_buffer and g_buffer be given more intuitive names, e.g. p_norm_for_blocks?

Contributor Author:

Planning to rename them to buffer_for_grad_norm and buffer_for_param_norm.

template <typename MT, int kOpNum, typename T>
struct MergedLarsMasterParam {
DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; }
constexpr void SetMasterParam(size_t, MT*) {}
Contributor:

Doesn't this function need the DEVICE qualifier?

Contributor Author:

SetMasterParam is executed on the host side, so the DEVICE qualifier was not added.
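
A minimal sketch, under assumptions, of the split described here: GetMasterParam is read on the device inside the kernel and so keeps the DEVICE (__device__) qualifier, while SetMasterParam is only called on the host while packing the merged op, so it needs none. The boolean specialization key and the struct name are illustrative, not the PR's actual code.

template <typename MT, int kOpNum, bool NeedMasterParam>
struct MasterParamSketch {
  MT* __restrict__ master_p_arr[kOpNum];
  __device__ __forceinline__ MT* GetMasterParam(size_t idx) const {
    return master_p_arr[idx];  // read inside the CUDA kernel
  }
  void SetMasterParam(size_t idx, MT* ptr) {  // host-side packing only
    master_p_arr[idx] = ptr;
  }
};

template <typename MT, int kOpNum>
struct MasterParamSketch<MT, kOpNum, false> {
  // Path without master parameters: nothing is stored and nothing is updated.
  __device__ __forceinline__ MT* GetMasterParam(size_t) const { return nullptr; }
  void SetMasterParam(size_t, MT*) {}
};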

MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
MT weight_decay_arr[LARS_MAX_MERGED_OPS];
template <typename MT, int kOpNum, typename T>
struct MergedLarsMasterParam {
Contributor:

Could this structure be made more generic? For example, name the class MasterParamHelper.

Contributor Author:

Will change as suggested.

constexpr void SetMasterParam(size_t, MT*) {}
};

template <typename MT, int kOpNum>
Contributor:

The template parameters probably shouldn't be named kXxx, right?

Contributor Author:

OK, then I'll just change it to OpNum.


"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
lars_warpper.weight_decay = lars_weight_decay;
int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum;
Contributor:

If a model has 160 parameters, will it still have only one optimizer op, with that op launching 2 CUDA kernels and each kernel updating 80 parameters?

Also, the variable name merge_times...

Contributor Author:

  • We still need to distinguish AMP LarsMomentum from non-AMP LarsMomentum at the Python level, and then pass the AMP LARS ops and the non-AMP LARS ops in for computation separately. If too many ops are passed in at once, they are computed in groups of at most 80 (see the sketch after this list).
  • The variable name will be changed to loop.
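
A small self-contained sketch of the grouping described above, assuming a per-launch capacity of 80 tensors (the actual capacity comes from the wrapper's kNum): a merged optimizer op holding 160 parameters still maps to one op, which issues two kernel launches of 80 parameters each.

#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kNum = 80;                      // assumed per-launch capacity
  const int op_num = 160;                       // e.g. a model with 160 parameters
  const int loop = (op_num + kNum - 1) / kNum;  // ceiling division -> 2 launches
  for (int i = 0; i < loop; ++i) {
    const int start_idx = i * kNum;
    const int cur_num = std::min(kNum, op_num - start_idx);
    // In the real code path, the wrapper is refilled with the pointers of these
    // cur_num tensors and one CUDA kernel is launched to update the whole group.
    std::printf("launch %d updates params [%d, %d)\n", i, start_idx,
                start_idx + cur_num);
  }
  return 0;
}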

reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&multi_precision)};
// Launch across all SMs; the threads of each block cooperate synchronously.
cudaLaunchCooperativeKernel(
Contributor:

This API call could indeed be wrapped up later; it could be implemented in gpu_launch_config.h, though that file would be better named gpu_launch_helper.h.

Contributor Author:

OK. Writing it out this way really does take up too much space.
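
A minimal sketch of the wrapper idea discussed here, not Paddle's actual API: a small helper hides the void** argument packing that makes the cudaLaunchCooperativeKernel call site so verbose. The helper name and its placement in a gpu_launch_config.h / gpu_launch_helper.h style header are assumptions.

#include <cuda_runtime.h>

template <typename KernelT>
inline cudaError_t LaunchCooperativeKernel(KernelT kernel, dim3 grid, dim3 block,
                                           void** kernel_args, size_t shared_mem,
                                           cudaStream_t stream) {
  // cudaLaunchCooperativeKernel takes the kernel as an untyped pointer and all
  // arguments packed into an array of pointers; grid-wide sync inside the
  // kernel is only valid when the device supports cooperative launch.
  return cudaLaunchCooperativeKernel(reinterpret_cast<void*>(kernel), grid,
                                     block, kernel_args, shared_mem, stream);
}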

@@ -0,0 +1,210 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

How about renaming the file to test_merged_optimizer.py? If other optimizers are merged later, their unit tests could also be built on top of this one.

Contributor Author:

Sure. This test was written by adapting Jinle's unit-test framework anyway.

@paddle-bot-old

Sorry to inform you that c5d06e0's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@paddle-bot

paddle-bot bot commented Nov 1, 2022

Sorry to inform you that, after our repeated discussion, this PR has not met the merging standard (reference: the Paddle native operator development specification / Paddle Custom Operator Design Doc). You may submit a new PR; we are closing this one for now. Thank you for your contribution.
