Make lars cpp code flexible #36450

Closed
9 changes: 0 additions & 9 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -44,8 +44,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");

PADDLE_ENFORCE_EQ(
@@ -61,13 +59,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(), velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(), grad_dim.size()));

if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
238 changes: 122 additions & 116 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -28,8 +28,6 @@ limitations under the License. */
#define LARS_BLOCK_SIZE 512
#endif

#define LARS_MAX_MERGED_OPS 60

namespace paddle {
namespace operators {

@@ -48,31 +46,20 @@ __device__ __forceinline__ double Fma(double x, double y, double z) {
template <typename T>
class LarsThreadConfig {
public:
int grid_for_norm;
int grid_for_lars;
#if CUDA_VERSION >= 11000

private:
int grid_stride;

public:
explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_lars =
std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
grid_stride = LARS_BLOCK_SIZE * grid_for_lars;
}

int GetRepeatTimes(int64_t numel) {
return (numel + grid_stride - 1) / grid_stride - 1;
}
#else
int repeat_times;
int grid_for_norm;
explicit LarsThreadConfig(const int64_t numel) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_norm = std::min(grid, LARS_BLOCK_SIZE);
const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE;
repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
// Read 4 fp16/float elements at a time, but only 2 double elements at a time.
grid_for_lars =
std::is_same<double, T>::value
@@ -154,10 +141,8 @@ template <typename T, typename MT>
__global__ void L2NormKernel(
#endif
const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times,
const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr,
MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
MT* __restrict__ g_buffer, const int64_t numel, const MT rescale_grad,
Contributor:

The names p_buffer and g_buffer could be more intuitive, e.g. p_norm_for_blocks.

Contributor Author:

Planning to rename them to buffer_for_grad_norm and buffer_for_param_norm.

MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
Contributor:

Using blockDim.x here feels safer than using LARS_BLOCK_SIZE.

Contributor Author:

Changed as suggested.
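
A minimal sketch of the grid-stride pattern the reviewer is suggesting (illustrative only, not the final code of this PR; numel and the loop body are placeholders):

int tid = threadIdx.x + blockDim.x * blockIdx.x;
// Derive the stride from the actual launch configuration (blockDim.x) instead of
// the LARS_BLOCK_SIZE macro, so a changed block size cannot break the stride.
int grid_stride = blockDim.x * gridDim.x;
for (int64_t i = tid; i < numel; i += grid_stride) {
  // accumulate the partial L2 norms of param/grad here
}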


@@ -178,6 +163,7 @@ __global__ void L2NormKernel(
g_buffer[blockIdx.x] = g_tmp;
}
#if CUDA_VERSION >= 11000
__shared__ MT s_buffer[2];
cg->sync();  // Grid sync for writing partial result to global memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
@@ -231,16 +217,34 @@ __forceinline__ __device__ void MomentumUpdate(
}

#if CUDA_VERSION >= 11000
template <typename T, typename MT>
struct LarsParamWarpper {
int64_t numel_arr[LARS_MAX_MERGED_OPS];
int repeat_arr[LARS_MAX_MERGED_OPS];
const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
MT weight_decay_arr[LARS_MAX_MERGED_OPS];
template <typename MT, int kOpNum, typename T>
struct MergedLarsMasterParam {
Contributor:

Could this struct be made more generic? Maybe name the class MasterParamHelper.

Contributor Author:

Changed as suggested.

DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; }
constexpr void SetMasterParam(size_t, MT*) {}
Contributor:

Doesn't this function need the DEVICE qualifier?

Contributor Author:

SetMasterParam runs on the host side, so the DEVICE qualifier was not added.

};

template <typename MT, int kOpNum>
Contributor:

The template parameter names shouldn't use the kXxx style, should they?

Contributor Author:

OK, I'll change it to OpNum directly.

struct MergedLarsMasterParam<MT, kOpNum, paddle::platform::float16> {
MT* master_params[kOpNum];

DEVICE inline MT* GetMasterParam(size_t idx) const {
return master_params[idx];
}
void SetMasterParam(size_t idx, MT* p) { master_params[idx] = p; }
};

template <typename T, typename MT,
int kOpNum =
std::is_same<T, paddle::platform::float16>::value ? 80 : 100>
struct LarsParamWarpper : public MergedLarsMasterParam<MT, kOpNum, T> {
static constexpr int kNum = kOpNum;

int numel_arr[kOpNum];
const MT* __restrict__ lr_arr[kOpNum];
const T* __restrict__ g_arr[kOpNum];
T* p_arr[kOpNum];
MT* v_arr[kOpNum];
MT weight_decay;
};
Contributor Author:

Here the data type T alone decides which LarsParamWarpper specialization is generated, i.e. it assumes that using fp16 always implies using master_param. This change therefore does not cover pure-fp16 computation that does not rely on master_param.
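
One hypothetical way to decouple the two cases mentioned above (a sketch only, not part of this PR; the struct name and the boolean flag are assumptions) is to key the specialization on an explicit flag instead of the element type, so that AMP fp16 (with master params) and pure fp16 (without) can both be expressed:

// Primary template: no master-param storage.
template <typename MT, int OpNum, bool HasMasterParam>
struct MasterParamHelper {
  DEVICE inline MT* GetMasterParam(size_t) const { return nullptr; }
  void SetMasterParam(size_t, MT*) {}  // host-side no-op
};

// Specialization: carries one master-param pointer per merged op.
template <typename MT, int OpNum>
struct MasterParamHelper<MT, OpNum, true> {
  MT* master_params[OpNum];
  DEVICE inline MT* GetMasterParam(size_t idx) const {
    return master_params[idx];
  }
  void SetMasterParam(size_t idx, MT* p) { master_params[idx] = p; }
};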


template <typename T, typename MT>
Expand All @@ -258,16 +262,16 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
int numel = lars_warpper.numel_arr[i];
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i],
rescale_grad, 0, &param_norm, &grad_norm);
MomentumUpdate<T, MT>(
lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i],
lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i],
lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu,
lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
param_norm, grad_norm, tid, grid_stride, numel, is_amp);
L2NormKernel<T, MT>(&cg, lars_warpper.p_arr[i], lars_warpper.g_arr[i],
p_buffer, g_buffer, numel, rescale_grad, &param_norm,
&grad_norm);
MomentumUpdate<T, MT>(lars_warpper.p_arr[i], lars_warpper.g_arr[i],
lars_warpper.v_arr[i], lars_warpper.p_arr[i],
lars_warpper.v_arr[i], lars_warpper.GetMasterParam(i),
lars_warpper.GetMasterParam(i),
lars_warpper.lr_arr[i], mu, lars_warpper.weight_decay,
lars_coeff, epsilon, rescale_grad, param_norm,
grad_norm, tid, grid_stride, numel, is_amp);
}
}
#endif
@@ -279,24 +283,30 @@ __global__ void MomentumLarsKernel(
MT* master_param_out, const MT* __restrict__ learning_rate,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu,
const MT lars_coeff, const MT lars_weight_decay, const MT epsilon,
const MT rescale_grad, const int repeat_times, const int thresh,
const int64_t numel, const bool is_amp) {
const MT rescale_grad, const int thresh, const int64_t numel,
const bool is_amp) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times,
rescale_grad, gridDim.x, &param_norm, &grad_norm);
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, rescale_grad,
&param_norm, &grad_norm);
#else
const MT rescale_grad_pow = rescale_grad * rescale_grad;
MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
__shared__ MT s_buffer[2];
MT p_part_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
}
__syncthreads();
MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
MT param_norm = Sqrt(s_buffer[0]);
MT grad_norm = rescale_grad * Sqrt(s_buffer[1]);

#endif
MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out,
master_param, master_param_out, learning_rate, mu,
@@ -314,18 +324,16 @@ inline void SeparatedLarsMomentumOpCUDAKernel(
MT* master_out_data, const bool is_amp) {
LarsThreadConfig<T> lars_thread_config(numel);
L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0,
cuda_ctx.stream()>>>(
param_data, grad_data, p_buffer, g_buffer, numel,
lars_thread_config.repeat_times, rescale_grad);
cuda_ctx.stream()>>>(param_data, grad_data, p_buffer,
g_buffer, numel, rescale_grad);

MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE,
0, cuda_ctx.stream()>>>(
param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
master_param_data, master_out_data, lr, p_buffer, g_buffer, mu,
lars_coeff, weight_decay, epsilon, rescale_grad, 0,
lars_coeff, weight_decay, epsilon, rescale_grad,
lars_thread_config.grid_for_norm, numel, is_amp);
}

template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MT = MultiPrecisionType<T>;
@@ -346,8 +354,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));

auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
MT lars_weight_decay = weight_decay_arr[0];
Contributor Author (@JamesLim-sy, Oct 17, 2021):

Since optimizer.py already singles out the weight_decay == 0 special case and routes it away from lars_momentum, the merged LarsMomentum optimizer branch is adjusted so that all merged ops share the same weight_decay value.

# create the momentum optimize op
momentum_op = block.append_op(
    type=self.type if _lars_weight_decay != 0.0 else 'momentum',
    inputs=inputs,
    outputs=outputs,
    attrs=attrs,
    stop_gradient=True)
return momentum_op

This avoids every op in a merged_lars training run fetching its own weight_decay from global memory.

Contributor:

Apart from the ResNet50 scenario, is there really no case where weight_decay is non-zero and differs across parameters?

Contributor Author (@JamesLim-sy, Oct 18, 2021):

Whether a parameter goes through the lars computation depends on whether its op is in the 'self._exclude_from_weight_decay' list; the ResNet50 model passes exclude_from_weight_decay=['bn', 'batch_norm', '.b_0'].


auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
auto param = ctx.MultiInput<framework::LoDTensor>("Param");
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
@@ -362,13 +371,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
#if CUDA_VERSION >= 11000
if (op_num > 1) {
LarsParamWarpper<T, MT> lars_warpper;
PADDLE_ENFORCE_LT(
op_num, LARS_MAX_MERGED_OPS,
platform::errors::InvalidArgument(
"The maximum number of merged-ops supported is (%d), but"
"lars op required for trainning this model is (%d)\n",
LARS_MAX_MERGED_OPS, op_num));

VLOG(10) << "Max number of ops merged per launch in lars_warpper is " << lars_warpper.kNum;
/* Implementation of the lars optimizer consists of the following two steps:
1. Figure out the L2 norm statistic result of grad data and param data.
2. Update param and velocity with usage of L2 norm statistic result.
@@ -380,59 +383,65 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
&num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);

size_t total_numel = 0;
for (int i = 0; i < op_num; ++i) {
size_t temp_numel = param[i]->numel();
total_numel += temp_numel;
lars_warpper.numel_arr[i] = temp_numel;
lars_warpper.g_arr[i] = grad[i]->data<T>();
lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
lars_warpper.p_out_arr[i] =
param_out[i]->mutable_data<T>(ctx.GetPlace());
lars_warpper.v_out_arr[i] =
velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
PADDLE_ENFORCE_EQ(
param[i]->data<T>(), lars_warpper.p_out_arr[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
PADDLE_ENFORCE_EQ(velocity[i]->data<MT>(), lars_warpper.v_out_arr[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
int64_t avg_numel = total_numel / op_num;
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
num_blocks_per_sm);
for (int i = 0; i < op_num; ++i) {
lars_warpper.repeat_arr[i] =
lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
}
if (multi_precision) {
for (int i = 0; i < op_num; ++i) {
lars_warpper.master_p_out_arr[i] =
master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(),
lars_warpper.master_p_out_arr[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
lars_warpper.weight_decay = lars_weight_decay;
int merge_times = (op_num + lars_warpper.kNum - 1) / lars_warpper.kNum;
Contributor:

If a model has 160 parameters, will it still have only one optimizer op, with that op just launching 2 CUDA kernels, each updating 80 parameters?

Also, about the variable name merge_times...

Contributor Author:

  • AMP LarsMomentum and non-AMP LarsMomentum still need to be separated at the Python level, and the AMP lars group and the non-AMP lars group are then passed in for computation one after the other. If too many ops are passed in at once, they are processed in groups of at most 80.
  • The variable name will be changed to loop.
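
A condensed restatement of the grouping logic below, with the reviewer's 160-parameter question worked through (kNum is 80 for fp16, 100 otherwise):

// op_num = 160, kNum = 80  ->  merge_times = 2, two launches of 80 params each;
// op_num = 100, kNum = 80  ->  merge_times = 2, launches of 80 and 20 params.
int merge_times = (op_num + kNum - 1) / kNum;  // ceil(op_num / kNum)
for (int j = 0; j < merge_times; ++j) {
  int start_idx = j * kNum;
  int loop_num = std::min(kNum, op_num - start_idx);  // ops handled by this launch
  // fill lars_warpper with ops [start_idx, start_idx + loop_num) and launch once
}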

for (int j = 0; j < merge_times; ++j) {
size_t total_numel = 0;
int start_idx = j * lars_warpper.kNum;
int loop_num = std::min(lars_warpper.kNum, op_num - start_idx);

for (int i = 0; i < loop_num; ++i) {
size_t temp_numel = param[start_idx + i]->numel();
total_numel += temp_numel;
lars_warpper.numel_arr[i] = temp_numel;
lars_warpper.g_arr[i] = grad[start_idx + i]->data<T>();
lars_warpper.p_arr[i] =
param_out[start_idx + i]->mutable_data<T>(ctx.GetPlace());
lars_warpper.v_arr[i] =
velocity_out[start_idx + i]->mutable_data<MT>(ctx.GetPlace());
lars_warpper.lr_arr[i] = learning_rate[start_idx + i]->data<MT>();
if (multi_precision) {
auto master_param_data =
master_param_out[start_idx + i]->mutable_data<MT>(
ctx.GetPlace());
lars_warpper.SetMasterParam(i, master_param_data);
PADDLE_ENFORCE_EQ(
master_param[start_idx + i]->data<MT>(), master_param_data,
platform::errors::InvalidArgument(
"Since Input(MasterParam) and Output(MasterParamOut) of "
"lars optimizer must be the same Tensors."));
}
PADDLE_ENFORCE_EQ(
param[start_idx + i]->data<T>(), lars_warpper.p_arr[i],
platform::errors::InvalidArgument(
"Since Input(Param) and Output(ParamOut) of lars optimizer "
"must be the same Tensors."));
PADDLE_ENFORCE_EQ(
velocity[start_idx + i]->data<MT>(), lars_warpper.v_arr[i],
platform::errors::InvalidArgument(
"Since Input(Velocity) and Output(VelocityOut) of "
"lars optimizer must be "
"the same Tensors."));
}
VLOG(10) << "Op number delt in loop " << j << " is : " << loop_num;
int64_t avg_numel = total_numel / loop_num;
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
num_blocks_per_sm);
void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
reinterpret_cast<void*>(&p_buffer),
reinterpret_cast<void*>(&g_buffer),
reinterpret_cast<void*>(&loop_num),
reinterpret_cast<void*>(&mu),
reinterpret_cast<void*>(&lars_coeff),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&multi_precision)};
// Launch threads on all SMs; threads within each block cooperate synchronously.
cudaLaunchCooperativeKernel(
Contributor:

This call could indeed be wrapped later, perhaps in gpu_launch_config.h, although that file would be better named gpu_launch_helper.h.

Contributor Author:

Agreed, this calling style really takes up too much space.

reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
}
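
Following up on the wrapper suggestion above, a hypothetical helper (its name, header location, and signature are assumptions, not an existing Paddle API) could hide the void* argument packing:

// Hypothetical cooperative-launch helper, e.g. for a gpu_launch_helper.h.
template <typename Kernel, typename... Args>
void LaunchCooperativeKernel(const platform::CUDADeviceContext& ctx,
                             Kernel kernel, int grid, int block, Args*... args) {
  void* kernel_args[] = {reinterpret_cast<void*>(args)...};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(kernel), grid, block,
                              kernel_args, 0, ctx.stream());
}

// Usage sketch, replacing the manual cuda_param[] array above:
// LaunchCooperativeKernel(cuda_ctx, MergedMomentumLarsKernel<T, MT>,
//                         lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE,
//                         &lars_warpper, &p_buffer, &g_buffer, &loop_num, &mu,
//                         &lars_coeff, &epsilon, &rescale_grad, &multi_precision);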
void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
reinterpret_cast<void*>(&p_buffer),
reinterpret_cast<void*>(&g_buffer),
reinterpret_cast<void*>(&op_num),
reinterpret_cast<void*>(&mu),
reinterpret_cast<void*>(&lars_coeff),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&multi_precision)};
// Launch threads on all SMs; threads within each block cooperate synchronously.
cudaLaunchCooperativeKernel(
reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
} else {
auto* param_data = param[0]->data<T>();
auto* grad_data = grad[0]->data<T>();
@@ -448,15 +457,13 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
: nullptr;
int64_t numel = param[0]->numel();
MT lars_weight_decay = weight_decay_arr[0];

// Figure out how many blocks can be active in each sm.
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);
LarsThreadConfig<float> lars_thread_config(numel, sm_num,
num_blocks_per_sm);
int repeat_times = lars_thread_config.GetRepeatTimes(numel);
int thresh = 0;
void* cuda_param[] = {
reinterpret_cast<void*>(&param_data),
@@ -474,7 +481,6 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
reinterpret_cast<void*>(&lars_weight_decay),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&repeat_times),
reinterpret_cast<void*>(&thresh), // Just a placeholder
reinterpret_cast<void*>(&numel),
reinterpret_cast<void*>(&multi_precision)};
@@ -498,7 +504,7 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
velocity[i]->data<MT>(),
velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(),
learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff,
weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(),
lars_weight_decay, epsilon, rescale_grad, param[i]->numel(),
master_param_data, master_param_out_data, multi_precision);
}
#endif