Merge momentum ops/kernels #36380
Conversation
Thanks for your contribution!
```cpp
template <typename T, typename MT, bool kHasMasterParams,
          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
```
The inheritance here exploits a C++ compiler feature: inheriting from an empty class does not increase the derived class's sizeof (the empty base optimization). So when MasterParams is not needed, sizeof(MergedMomentumKernelParam) stays smaller, leaving room for a larger kParamNum.
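A minimal standalone sketch of the empty-base trick described above (the names here are illustrative stand-ins, not the PR's actual types): the specialization for `kHasMasterParams == false` is an empty struct, so inheriting it adds zero bytes to the derived struct.

```cpp
#include <cstdio>

// Illustrative stand-in for MergedMomentumMasterParams: the
// false-specialization is empty, so inheriting it costs no space.
template <typename MT, int N, bool kHasMasterParams>
struct MasterParamsSketch {
  MT *master_params[N];
};

template <typename MT, int N>
struct MasterParamsSketch<MT, N, false> {};  // empty base, zero bytes

template <bool kHasMasterParams, int N = kHasMasterParams ? 2 : 4>
struct KernelParamSketch
    : public MasterParamsSketch<float, N, kHasMasterParams> {
  float *params[N];
};

int main() {
  // Both print 32 on a typical 64-bit target: dropping the master
  // params lets the struct hold twice as many parameter slots for
  // the same sizeof, which is exactly why kParamNum can grow.
  printf("%zu\n", sizeof(KernelParamSketch<true>));   // 2 masters + 2 params
  printf("%zu\n", sizeof(KernelParamSketch<false>));  // 4 params, empty base
}
```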
Learned something new~
```cpp
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
    ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
```
Does this check mean the Python side will collect the AMP and non-AMP optimizer parameters separately, and then pass each group in for its own optimizer computation?
Yes. AMP and non-AMP are kept separate for a few reasons:

1. OperatorWithKernel::GetExpectedKernelType becomes much easier to write (see the sketch after this list). Otherwise it would have to walk the types of every Param: if the types are mixed (FP16 and FP32) it must return the FP16 kernel, and if they are uniform it must return the single-type kernel, which would be very tedious to implement.
2. sizeof(MergedMomentumKernelParam) is smaller, since the struct no longer has to carry bool multi_precision[N], and the implementation avoids scattering if-else checks for mixed types everywhere.
3. With mixed types, the params and grads members of MergedMomentumKernelParam could only be declared as void *params[N] and void *grads[N], which hurts readability.
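A hedged sketch of reason 1 (Paddle-style code; the op class in the PR may differ): when AMP and non-AMP parameters arrive in separate op instances, every Param shares one dtype, so kernel selection can inspect a single input instead of scanning all Params for FP16/FP32 mixes.

```cpp
// Sketch only: assumes Paddle's OperatorWithKernel API; not the PR's
// verbatim implementation.
class MergedMomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    // All Param inputs share one dtype by construction, so the first
    // one is enough to pick the kernel -- no mixed-type scan needed.
    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};
```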
```cpp
static constexpr auto N = kParamNum;
size_t sizes[N];
T *params[N];
const T *grads[N];
```
Adding the __restrict__ qualifier to read-only data like these const T * pointers gives a slight performance improvement, though the gain may not be noticeable: https://developer.nvidia.com/blog/cuda-pro-tip-optimize-pointer-aliasing/
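For reference, a minimal CUDA kernel (illustrative, not from this PR) showing the suggested qualifier: marking the read-only inputs const T *__restrict__ promises the compiler they never alias the output, which can enable read-only cache loads and freer instruction scheduling.

```cpp
// Illustrative SAXPY-style kernel; __restrict__ on x and y tells nvcc
// the loads cannot alias out[i], so they may be hoisted or cached.
template <typename T>
__global__ void AxpySketch(const T *__restrict__ x,
                           const T *__restrict__ y,
                           T *__restrict__ out, T alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = alpha * x[i] + y[i];
  }
}
```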
Done.
Force-pushed bd25e1b to ca2ecae
LGTM COOL!
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
Merge multiple momentum ops/kernels into a single momentum op/kernel.
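As a rough illustration of what merging buys (a hedged sketch, not the PR's actual kernel; the velocitys member and the launch shape are assumptions for illustration): a single launch walks all N parameter tensors packed into the MergedMomentumKernelParam-style struct quoted above, replacing N separate momentum kernel launches.

```cpp
// Sketch: one grid updates every merged tensor with the classic
// momentum rule v = mu * v + g; p -= lr * v.  The sizes/params/grads
// fields match the struct quoted above; velocitys is assumed here.
template <typename T, typename KernelParamT>
__global__ void MergedMomentumSketch(KernelParamT kp, T lr, T mu) {
  for (uint32_t idx = 0; idx < KernelParamT::N; ++idx) {
    size_t n = kp.sizes[idx];
    T *param = kp.params[idx];
    const T *grad = kp.grads[idx];
    T *velocity = kp.velocitys[idx];  // assumed member, see lead-in
    // Grid-stride loop so every tensor reuses the same launch config.
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += (size_t)gridDim.x * blockDim.x) {
      T v = mu * velocity[i] + grad[i];
      velocity[i] = v;
      param[i] -= lr * v;
    }
  }
}
```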