-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make lars cpp code flexible #36450
Make lars cpp code flexible #36450
Changes from 10 commits
4e1fc95
2ffbbba
3da1b1f
fb89aef
eec1fc6
dc103de
7be6434
57b952a
c310c8d
3f62752
c6ef005
47ba53b
82dd12a
c5d06e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -28,8 +28,6 @@ limitations under the License. */ | |||||||||||||||||||
#define LARS_BLOCK_SIZE 512 | ||||||||||||||||||||
#endif | ||||||||||||||||||||
|
||||||||||||||||||||
#define LARS_MAX_MERGED_OPS 60 | ||||||||||||||||||||
|
||||||||||||||||||||
namespace paddle { | ||||||||||||||||||||
namespace operators { | ||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -48,31 +46,20 @@ __device__ __forceinline__ double Fma(double x, double y, double z) { | |||||||||||||||||||
template <typename T> | ||||||||||||||||||||
class LarsThreadConfig { | ||||||||||||||||||||
public: | ||||||||||||||||||||
int grid_for_norm; | ||||||||||||||||||||
int grid_for_lars; | ||||||||||||||||||||
#if CUDA_VERSION >= 11000 | ||||||||||||||||||||
|
||||||||||||||||||||
private: | ||||||||||||||||||||
int grid_stride; | ||||||||||||||||||||
|
||||||||||||||||||||
public: | ||||||||||||||||||||
explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) { | ||||||||||||||||||||
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; | ||||||||||||||||||||
grid_for_lars = | ||||||||||||||||||||
std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE); | ||||||||||||||||||||
grid_stride = LARS_BLOCK_SIZE * grid_for_lars; | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
int GetRepeatTimes(int64_t numel) { | ||||||||||||||||||||
return (numel + grid_stride - 1) / grid_stride - 1; | ||||||||||||||||||||
} | ||||||||||||||||||||
#else | ||||||||||||||||||||
int repeat_times; | ||||||||||||||||||||
int grid_for_norm; | ||||||||||||||||||||
explicit LarsThreadConfig(const int64_t numel) { | ||||||||||||||||||||
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; | ||||||||||||||||||||
grid_for_norm = std::min(grid, LARS_BLOCK_SIZE); | ||||||||||||||||||||
const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE; | ||||||||||||||||||||
repeat_times = (numel + grid_stride - 1) / grid_stride - 1; | ||||||||||||||||||||
// Determine to read 4 fp16 or float data once, but 2 double data once. | ||||||||||||||||||||
grid_for_lars = | ||||||||||||||||||||
std::is_same<double, T>::value | ||||||||||||||||||||
|
@@ -154,10 +141,8 @@ template <typename T, typename MT> | |||||||||||||||||||
__global__ void L2NormKernel( | ||||||||||||||||||||
#endif | ||||||||||||||||||||
const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer, | ||||||||||||||||||||
MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times, | ||||||||||||||||||||
const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr, | ||||||||||||||||||||
MT* __restrict__ g_n = nullptr) { | ||||||||||||||||||||
__shared__ MT s_buffer[2]; | ||||||||||||||||||||
MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad, | ||||||||||||||||||||
MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { | ||||||||||||||||||||
int tid = threadIdx.x + blockDim.x * blockIdx.x; | ||||||||||||||||||||
int grid_stride = LARS_BLOCK_SIZE * gridDim.x; | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里感觉使用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据建议修改 |
||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -178,6 +163,7 @@ __global__ void L2NormKernel( | |||||||||||||||||||
g_buffer[blockIdx.x] = g_tmp; | ||||||||||||||||||||
} | ||||||||||||||||||||
#if CUDA_VERSION >= 11000 | ||||||||||||||||||||
__shared__ MT s_buffer[2]; | ||||||||||||||||||||
cg->sync(); // Grid sync for writring partial result to gloabl memory | ||||||||||||||||||||
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
|
@@ -231,16 +217,34 @@ __forceinline__ __device__ void MomentumUpdate( | |||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
#if CUDA_VERSION >= 11000 | ||||||||||||||||||||
template <typename T, typename MT> | ||||||||||||||||||||
struct LarsParamWarpper { | ||||||||||||||||||||
int64_t numel_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
int repeat_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
MT weight_decay_arr[LARS_MAX_MERGED_OPS]; | ||||||||||||||||||||
template <typename MT, int kOpNum, typename T> | ||||||||||||||||||||
struct MergedLarsMasterParam { | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个结构能更通用一些吗?类名叫 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 根据建议修改 |
||||||||||||||||||||
DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; } | ||||||||||||||||||||
constexpr void SetMasterParam(size_t, MT*) {} | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数不用加 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||
}; | ||||||||||||||||||||
|
||||||||||||||||||||
template <typename MT, int kOpNum> | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 模板中的变量名,不要叫 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 嗯,那就之间换成 |
||||||||||||||||||||
struct MergedLarsMasterParam<MT, kOpNum, paddle::platform::float16> { | ||||||||||||||||||||
MT* master_params[kOpNum]; | ||||||||||||||||||||
|
||||||||||||||||||||
DEVICE inline MT* GetMasterParam(size_t idx) const { | ||||||||||||||||||||
return master_params[idx]; | ||||||||||||||||||||
} | ||||||||||||||||||||
void SetMasterParam(size_t idx, MT* p) { master_params[idx] = p; } | ||||||||||||||||||||
}; | ||||||||||||||||||||
|
||||||||||||||||||||
template <typename T, typename MT, | ||||||||||||||||||||
int kOpNum = | ||||||||||||||||||||
std::is_same<T, paddle::platform::float16>::value ? 80 : 100> | ||||||||||||||||||||
struct LarsParamWarpper : public MergedLarsMasterParam<MT, kOpNum, T> { | ||||||||||||||||||||
static constexpr int kNum = kOpNum; | ||||||||||||||||||||
|
||||||||||||||||||||
int numel_arr[kOpNum]; | ||||||||||||||||||||
const MT* __restrict__ lr_arr[kOpNum]; | ||||||||||||||||||||
const T* __restrict__ g_arr[kOpNum]; | ||||||||||||||||||||
T* p_arr[kOpNum]; | ||||||||||||||||||||
MT* v_arr[kOpNum]; | ||||||||||||||||||||
MT weight_decay; | ||||||||||||||||||||
}; | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里直接利用数据类型 |
||||||||||||||||||||
|
||||||||||||||||||||
template <typename T, typename MT> | ||||||||||||||||||||
|
@@ -258,16 +262,16 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper, | |||||||||||||||||||
int numel = lars_warpper.numel_arr[i]; | ||||||||||||||||||||
MT param_norm = static_cast<MT>(0); | ||||||||||||||||||||
MT grad_norm = static_cast<MT>(0); | ||||||||||||||||||||
L2NormKernel<T, MT>(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], | ||||||||||||||||||||
p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i], | ||||||||||||||||||||
rescale_grad, 0, ¶m_norm, &grad_norm); | ||||||||||||||||||||
MomentumUpdate<T, MT>( | ||||||||||||||||||||
lars_warpper.p_out_arr[i], lars_warpper.g_arr[i], | ||||||||||||||||||||
lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i], | ||||||||||||||||||||
lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i], | ||||||||||||||||||||
lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu, | ||||||||||||||||||||
lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad, | ||||||||||||||||||||
param_norm, grad_norm, tid, grid_stride, numel, is_amp); | ||||||||||||||||||||
L2NormKernel<T, MT>(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i], | ||||||||||||||||||||
p_buffer, g_buffer, numel, rescale_grad, ¶m_norm, | ||||||||||||||||||||
&grad_norm); | ||||||||||||||||||||
MomentumUpdate<T, MT>(lars_warpper.p_arr[i], lars_warpper.g_arr[i], | ||||||||||||||||||||
lars_warpper.v_arr[i], lars_warpper.p_arr[i], | ||||||||||||||||||||
lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i), | ||||||||||||||||||||
lars_warpper.GetMasterParam(i), | ||||||||||||||||||||
lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay, | ||||||||||||||||||||
lars_coeff, epsilon, rescale_grad, param_norm, | ||||||||||||||||||||
grad_norm, tid, grid_stride, numel, is_amp); | ||||||||||||||||||||
} | ||||||||||||||||||||
} | ||||||||||||||||||||
#endif | ||||||||||||||||||||
|
@@ -279,24 +283,30 @@ __global__ void MomentumLarsKernel( | |||||||||||||||||||
MT* master_param_out, const MT* __restrict__ learning_rate, | ||||||||||||||||||||
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu, | ||||||||||||||||||||
const MT lars_coeff, const MT lars_weight_decay, const MT epsilon, | ||||||||||||||||||||
const MT rescale_grad, const int repeat_times, const int thresh, | ||||||||||||||||||||
const int64_t numel, const bool is_amp) { | ||||||||||||||||||||
const MT rescale_grad, const int thresh, const int64_t numel, | ||||||||||||||||||||
const bool is_amp) { | ||||||||||||||||||||
int tid = threadIdx.x + blockIdx.x * blockDim.x; | ||||||||||||||||||||
int grid_stride = gridDim.x * LARS_BLOCK_SIZE; | ||||||||||||||||||||
#if CUDA_VERSION >= 11000 | ||||||||||||||||||||
const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); | ||||||||||||||||||||
MT param_norm = static_cast<MT>(0); | ||||||||||||||||||||
MT grad_norm = static_cast<MT>(0); | ||||||||||||||||||||
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times, | ||||||||||||||||||||
rescale_grad, gridDim.x, ¶m_norm, &grad_norm); | ||||||||||||||||||||
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, rescale_grad, | ||||||||||||||||||||
¶m_norm, &grad_norm); | ||||||||||||||||||||
#else | ||||||||||||||||||||
const MT rescale_grad_pow = rescale_grad * rescale_grad; | ||||||||||||||||||||
MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
__shared__ MT s_buffer[2]; | ||||||||||||||||||||
MT p_part_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
MT g_part_sum = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; | ||||||||||||||||||||
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK); | ||||||||||||||||||||
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK); | ||||||||||||||||||||
if (threadIdx.x == 0) { | ||||||||||||||||||||
s_buffer[0] = tmp0; | ||||||||||||||||||||
s_buffer[1] = tmp1; | ||||||||||||||||||||
} | ||||||||||||||||||||
__syncthreads(); | ||||||||||||||||||||
MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK)); | ||||||||||||||||||||
MT grad_norm = Sqrt(rescale_grad_pow * | ||||||||||||||||||||
math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK)); | ||||||||||||||||||||
MT param_norm = Sqrt(s_buffer[0]); | ||||||||||||||||||||
MT grad_norm = rescale_grad * Sqrt(s_buffer[1]); | ||||||||||||||||||||
|
||||||||||||||||||||
#endif | ||||||||||||||||||||
MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out, | ||||||||||||||||||||
master_param, master_param_out, learning_rate, mu, | ||||||||||||||||||||
|
@@ -314,18 +324,16 @@ inline void SeparatedLarsMomentumOpCUDAKernel( | |||||||||||||||||||
MT* master_out_data, const bool is_amp) { | ||||||||||||||||||||
LarsThreadConfig<T> lars_thread_config(numel); | ||||||||||||||||||||
L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0, | ||||||||||||||||||||
cuda_ctx.stream()>>>( | ||||||||||||||||||||
param_data, grad_data, p_buffer, g_buffer, numel, | ||||||||||||||||||||
lars_thread_config.repeat_times, rescale_grad); | ||||||||||||||||||||
cuda_ctx.stream()>>>(param_data, grad_data, p_buffer, | ||||||||||||||||||||
g_buffer, numel, rescale_grad); | ||||||||||||||||||||
|
||||||||||||||||||||
MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, | ||||||||||||||||||||
0, cuda_ctx.stream()>>>( | ||||||||||||||||||||
param_data, grad_data, velocity_data, param_out_data, velocity_out_data, | ||||||||||||||||||||
master_param_data, master_out_data, lr, p_buffer, g_buffer, mu, | ||||||||||||||||||||
lars_coeff, weight_decay, epsilon, rescale_grad, 0, | ||||||||||||||||||||
lars_coeff, weight_decay, epsilon, rescale_grad, | ||||||||||||||||||||
lars_thread_config.grid_for_norm, numel, is_amp); | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
template <typename DeviceContext, typename T> | ||||||||||||||||||||
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | ||||||||||||||||||||
using MT = MultiPrecisionType<T>; | ||||||||||||||||||||
|
@@ -346,8 +354,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff")); | ||||||||||||||||||||
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon")); | ||||||||||||||||||||
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad")); | ||||||||||||||||||||
|
||||||||||||||||||||
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay"); | ||||||||||||||||||||
MT lars_weight_decay = weight_decay_arr[0]; | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 考虑到目前在 Paddle/python/paddle/fluid/optimizer.py Lines 2087 to 2095 in 4e036fa
此处的处理能够避免 merged_lars 训练时,其中的每个op 都执行从global memory 中取数据的问题.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 除了ResNet50这个场景外,不会出现 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是否走入 |
||||||||||||||||||||
|
||||||||||||||||||||
auto grad = ctx.MultiInput<framework::LoDTensor>("Grad"); | ||||||||||||||||||||
auto param = ctx.MultiInput<framework::LoDTensor>("Param"); | ||||||||||||||||||||
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity"); | ||||||||||||||||||||
|
@@ -362,13 +371,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
#if CUDA_VERSION >= 11000 | ||||||||||||||||||||
if (op_num > 1) { | ||||||||||||||||||||
LarsParamWarpper<T, MT> lars_warpper; | ||||||||||||||||||||
PADDLE_ENFORCE_LT( | ||||||||||||||||||||
op_num, LARS_MAX_MERGED_OPS, | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"The maximum number of merged-ops supported is (%d), but" | ||||||||||||||||||||
"lars op required for trainning this model is (%d)\n", | ||||||||||||||||||||
LARS_MAX_MERGED_OPS, op_num)); | ||||||||||||||||||||
|
||||||||||||||||||||
VLOG(10) << "Num of ops merged in lars_warpper is " << lars_warpper.kNum; | ||||||||||||||||||||
/* Implementation of lars optimizer consists of following two steps: | ||||||||||||||||||||
1. Figure out the L2 norm statistic result of grad data and param data. | ||||||||||||||||||||
2. Update param and velocity with usage of L2 norm statistic result. | ||||||||||||||||||||
|
@@ -380,59 +383,65 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
&num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE, | ||||||||||||||||||||
sizeof(MT) << 1); | ||||||||||||||||||||
|
||||||||||||||||||||
size_t total_numel = 0; | ||||||||||||||||||||
for (int i = 0; i < op_num; ++i) { | ||||||||||||||||||||
size_t temp_numel = param[i]->numel(); | ||||||||||||||||||||
total_numel += temp_numel; | ||||||||||||||||||||
lars_warpper.numel_arr[i] = temp_numel; | ||||||||||||||||||||
lars_warpper.g_arr[i] = grad[i]->data<T>(); | ||||||||||||||||||||
lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>(); | ||||||||||||||||||||
lars_warpper.p_out_arr[i] = | ||||||||||||||||||||
param_out[i]->mutable_data<T>(ctx.GetPlace()); | ||||||||||||||||||||
lars_warpper.v_out_arr[i] = | ||||||||||||||||||||
velocity_out[i]->mutable_data<MT>(ctx.GetPlace()); | ||||||||||||||||||||
lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]); | ||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||
param[i]->data<T>(), lars_warpper.p_out_arr[i], | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Input(Param) and Output(ParamOut) must be the same Tensors.")); | ||||||||||||||||||||
PADDLE_ENFORCE_EQ(velocity[i]->data<MT>(), lars_warpper.v_out_arr[i], | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Input(Velocity) and Output(VelocityOut) must be " | ||||||||||||||||||||
"the same Tensors.")); | ||||||||||||||||||||
} | ||||||||||||||||||||
int64_t avg_numel = total_numel / op_num; | ||||||||||||||||||||
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num, | ||||||||||||||||||||
num_blocks_per_sm); | ||||||||||||||||||||
for (int i = 0; i < op_num; ++i) { | ||||||||||||||||||||
lars_warpper.repeat_arr[i] = | ||||||||||||||||||||
lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); | ||||||||||||||||||||
} | ||||||||||||||||||||
if (multi_precision) { | ||||||||||||||||||||
for (int i = 0; i < op_num; ++i) { | ||||||||||||||||||||
lars_warpper.master_p_out_arr[i] = | ||||||||||||||||||||
master_param_out[i]->mutable_data<MT>(ctx.GetPlace()); | ||||||||||||||||||||
PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(), | ||||||||||||||||||||
lars_warpper.master_p_out_arr[i], | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Input(MasterParam) and Output(MasterParamOut) " | ||||||||||||||||||||
"must be the same Tensors.")); | ||||||||||||||||||||
lars_warpper.weight_decay = lars_weight_decay; | ||||||||||||||||||||
int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum; | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果一个模型有160个参数,这个模型依然只会有一个optimzier op,只是这个optimizer op会启动2个CUDA Kernel计算,每个CUDA Kernel更新80个参数?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||
for (int j = 0; j < merge_times; ++j) { | ||||||||||||||||||||
size_t total_numel = 0; | ||||||||||||||||||||
int start_idx = j * lars_warpper.kNum; | ||||||||||||||||||||
int loop_num = std::min(lars_warpper.kNum, op_num - start_idx); | ||||||||||||||||||||
|
||||||||||||||||||||
for (int i = 0; i < loop_num; ++i) { | ||||||||||||||||||||
size_t temp_numel = param[start_idx + i]->numel(); | ||||||||||||||||||||
total_numel += temp_numel; | ||||||||||||||||||||
lars_warpper.numel_arr[i] = temp_numel; | ||||||||||||||||||||
lars_warpper.g_arr[i] = grad[start_idx + i]->data<T>(); | ||||||||||||||||||||
lars_warpper.p_arr[i] = | ||||||||||||||||||||
param_out[start_idx + i]->mutable_data<T>(ctx.GetPlace()); | ||||||||||||||||||||
lars_warpper.v_arr[i] = | ||||||||||||||||||||
velocity_out[start_idx + i]->mutable_data<MT>(ctx.GetPlace()); | ||||||||||||||||||||
lars_warpper.lr_arr[i] = learning_rate[start_idx + i]->data<MT>(); | ||||||||||||||||||||
if (multi_precision) { | ||||||||||||||||||||
auto master_param_data = | ||||||||||||||||||||
master_param_out[start_idx + i]->mutable_data<MT>( | ||||||||||||||||||||
ctx.GetPlace()); | ||||||||||||||||||||
lars_warpper.SetMasterParam(i, master_param_data); | ||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||
master_param[start_idx + i]->data<MT>(), master_param_data, | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Since Input(MasterParam) and Output(MasterParamOut) of " | ||||||||||||||||||||
"lars optimizer must be the same Tensors.")); | ||||||||||||||||||||
} | ||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||
param[start_idx + i]->data<T>(), lars_warpper.p_arr[i], | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Since Input(Param) and Output(ParamOut) of lars optimizer " | ||||||||||||||||||||
"must be the same Tensors.")); | ||||||||||||||||||||
PADDLE_ENFORCE_EQ( | ||||||||||||||||||||
velocity[start_idx + i]->data<MT>(), lars_warpper.v_arr[i], | ||||||||||||||||||||
platform::errors::InvalidArgument( | ||||||||||||||||||||
"Since Input(Velocity) and Output(VelocityOut) of " | ||||||||||||||||||||
"lars optimizer must be " | ||||||||||||||||||||
"the same Tensors.")); | ||||||||||||||||||||
} | ||||||||||||||||||||
VLOG(10) << "Op number delt in loop " << j << " is : " << loop_num; | ||||||||||||||||||||
int64_t avg_numel = total_numel / loop_num; | ||||||||||||||||||||
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num, | ||||||||||||||||||||
num_blocks_per_sm); | ||||||||||||||||||||
void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper), | ||||||||||||||||||||
reinterpret_cast<void*>(&p_buffer), | ||||||||||||||||||||
reinterpret_cast<void*>(&g_buffer), | ||||||||||||||||||||
reinterpret_cast<void*>(&loop_num), | ||||||||||||||||||||
reinterpret_cast<void*>(&mu), | ||||||||||||||||||||
reinterpret_cast<void*>(&lars_coeff), | ||||||||||||||||||||
reinterpret_cast<void*>(&epsilon), | ||||||||||||||||||||
reinterpret_cast<void*>(&rescale_grad), | ||||||||||||||||||||
reinterpret_cast<void*>(&multi_precision)}; | ||||||||||||||||||||
// Lanuch all sm theads,thead of each block synchronizedly cooperate. | ||||||||||||||||||||
cudaLaunchCooperativeKernel( | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个接口调用,确实后续可以再封装一下,可以实现在 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这种写法真的太占地方了 |
||||||||||||||||||||
reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>), | ||||||||||||||||||||
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, | ||||||||||||||||||||
cuda_ctx.stream()); | ||||||||||||||||||||
} | ||||||||||||||||||||
void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper), | ||||||||||||||||||||
reinterpret_cast<void*>(&p_buffer), | ||||||||||||||||||||
reinterpret_cast<void*>(&g_buffer), | ||||||||||||||||||||
reinterpret_cast<void*>(&op_num), | ||||||||||||||||||||
reinterpret_cast<void*>(&mu), | ||||||||||||||||||||
reinterpret_cast<void*>(&lars_coeff), | ||||||||||||||||||||
reinterpret_cast<void*>(&epsilon), | ||||||||||||||||||||
reinterpret_cast<void*>(&rescale_grad), | ||||||||||||||||||||
reinterpret_cast<void*>(&multi_precision)}; | ||||||||||||||||||||
// Lanuch all sm theads, and thead of each block synchronizedly cooperate. | ||||||||||||||||||||
cudaLaunchCooperativeKernel( | ||||||||||||||||||||
reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>), | ||||||||||||||||||||
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0, | ||||||||||||||||||||
cuda_ctx.stream()); | ||||||||||||||||||||
} else { | ||||||||||||||||||||
auto* param_data = param[0]->data<T>(); | ||||||||||||||||||||
auto* grad_data = grad[0]->data<T>(); | ||||||||||||||||||||
|
@@ -448,15 +457,13 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
? master_param_out[0]->mutable_data<MT>(ctx.GetPlace()) | ||||||||||||||||||||
: nullptr; | ||||||||||||||||||||
int64_t numel = param[0]->numel(); | ||||||||||||||||||||
MT lars_weight_decay = weight_decay_arr[0]; | ||||||||||||||||||||
|
||||||||||||||||||||
// Figure out how many blocks can be active in each sm. | ||||||||||||||||||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor( | ||||||||||||||||||||
&num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE, | ||||||||||||||||||||
sizeof(MT) << 1); | ||||||||||||||||||||
LarsThreadConfig<float> lars_thread_config(numel, sm_num, | ||||||||||||||||||||
num_blocks_per_sm); | ||||||||||||||||||||
int repeat_times = lars_thread_config.GetRepeatTimes(numel); | ||||||||||||||||||||
int thresh = 0; | ||||||||||||||||||||
void* cuda_param[] = { | ||||||||||||||||||||
reinterpret_cast<void*>(¶m_data), | ||||||||||||||||||||
|
@@ -474,7 +481,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
reinterpret_cast<void*>(&lars_weight_decay), | ||||||||||||||||||||
reinterpret_cast<void*>(&epsilon), | ||||||||||||||||||||
reinterpret_cast<void*>(&rescale_grad), | ||||||||||||||||||||
reinterpret_cast<void*>(&repeat_times), | ||||||||||||||||||||
reinterpret_cast<void*>(&thresh), // Just a placeholder | ||||||||||||||||||||
reinterpret_cast<void*>(&numel), | ||||||||||||||||||||
reinterpret_cast<void*>(&multi_precision)}; | ||||||||||||||||||||
|
@@ -498,7 +504,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> { | |||||||||||||||||||
velocity[i]->data<MT>(), | ||||||||||||||||||||
velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(), | ||||||||||||||||||||
learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff, | ||||||||||||||||||||
weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(), | ||||||||||||||||||||
lars_weight_decay, epsilon, rescale_grad, param[i]->numel(), | ||||||||||||||||||||
master_param_data, master_param_out_data, multi_precision); | ||||||||||||||||||||
} | ||||||||||||||||||||
#endif | ||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
p_buffer
、g_buffer
命名更直观一些,p_norm_for_blocks
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
准备改成
buffer_for_grad_norm
和buffer_for_param_norm