Merge momentum ops/kernels #36380

Merged 10 commits on Oct 14, 2021

Changes from 6 commits
95 changes: 95 additions & 0 deletions paddle/fluid/operators/merged_momentum_op.cc
@@ -0,0 +1,95 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/merged_momentum_op.h"

namespace paddle {
namespace operators {

class MergedMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
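// Shape inference is intentionally a no-op here: every output aliases its
// corresponding input (the kernel checks that they are the same tensors).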
void InferShape(framework::InferShapeContext *ctx) const override {}

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(param_dtype, ctx.GetPlace());
}
};

class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(Tensor, default Tensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad` "
"before updating. Often chosen to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(Merged Momentum Optimizer.)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp,
ops::MergedMomentumOpMaker);

REGISTER_OP_CPU_KERNEL(
merged_momentum, ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);
24 changes: 24 additions & 0 deletions paddle/fluid/operators/merged_momentum_op.cu
@@ -0,0 +1,24 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/merged_momentum_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
merged_momentum,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
196 changes: 196 additions & 0 deletions paddle/fluid/operators/merged_momentum_op.h
@@ -0,0 +1,196 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
MT *master_params[kParamNum];

HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
};

template <typename MT, uint32_t kParamNum>
struct MergedMomentumMasterParams<MT, kParamNum, false> {
HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
};

template <typename T, typename MT, bool kHasMasterParams,
uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
: public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
[Comment from Collaborator (Author)]
The inheritance here exploits a C++ compiler behavior: inheriting from an empty class does not increase the derived class's sizeof (the empty-base optimization). So when MasterParams are not needed, sizeof(MergedMomentumKernelParam) is smaller and can accommodate a larger kParamNum.

[Reply from Contributor]
Learned something new!
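A minimal standalone sketch of the empty-base optimization described above (simplified stand-in names, not Paddle code). The struct is passed to the device kernel by value, and kernel argument space is limited (4 KiB on CUDA), so a smaller struct leaves room for more merged parameters; this is presumably why kParamNum is 110 without master params but only 55 with them.

#include <cstddef>
#include <cstdint>

template <typename MT, uint32_t N, bool kHasMaster>
struct MasterSlots {
  MT *master_params[N];  // N extra pointers when master weights are present
};

template <typename MT, uint32_t N>
struct MasterSlots<MT, N, false> {};  // empty base: contributes no storage

template <typename T, typename MT, uint32_t N, bool kHasMaster>
struct KernelParam : MasterSlots<MT, N, kHasMaster> {
  std::size_t sizes[N];
  T *params[N];
  const T *grads[N];
  MT *velocitys[N];
};

// The no-master variant pays nothing for its empty base class:
static_assert(sizeof(KernelParam<float, float, 8, false>) ==
                  8 * (sizeof(std::size_t) + 3 * sizeof(void *)),
              "empty base must add no storage");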

static constexpr auto N = kParamNum;
size_t sizes[N];
T *params[N];
const T *grads[N];
[Comment from Contributor]
Adding the __restrict__ keyword to read-only data such as these const T * pointers can give a small performance gain, though the improvement may not be noticeable: https://developer.nvidia.com/blog/cuda-pro-tip-optimize-pointer-aliasing/

[Reply from Collaborator (Author)]
Done.
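A minimal sketch of the suggested change (the actual follow-up commit is not part of this diff). A restrict-qualified pointer promises the compiler that its data is not aliased through any other pointer, which lets loads be kept in registers and, on CUDA, routed through the read-only data cache:

#include <cstdint>

// Illustrative only: each array element becomes a restrict-qualified
// pointer, so data read through grads is promised never to be written
// through params or velocitys within the kernel.
template <typename T, typename MT, uint32_t N>
struct KernelPointers {
  T *__restrict__ params[N];
  const T *__restrict__ grads[N];  // read-only, non-aliasing input
  MT *__restrict__ velocitys[N];
};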

MT *velocitys[N];
const MT *lr;
MT mu;
MT rescale_grad;
uint32_t param_num;

HOSTDEVICE void operator()(size_t i) const {
const auto lr_val = *lr;
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;

auto param_p = params[idx];
auto grad_p = grads[idx];
auto velocity_p = velocitys[idx];
auto master_param_p = this->MasterParam(idx);

const MT param =
master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
const MT velocity = velocity_p[i];
const MT velocity_out = velocity * mu + grad;
const MT param_out = param - lr_val * velocity_out;
velocity_p[i] = velocity_out;
param_p[i] = static_cast<T>(param_out);
if (master_param_p) {
master_param_p[i] = param_out;
}
}
}
};
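The per-element update in operator() above is the classic heavy-ball momentum rule. The launch code below invokes this functor once for every index i in [0, max_size), and each merged tensor skips indices beyond its own length through the if (i >= size) continue guard. In equation form, with r = rescale_grad and \eta the shared learning rate:

$$v_{t+1} = \mu \, v_t + r \, g_t, \qquad \theta_{t+1} = \theta_t - \eta \, v_{t+1}$$

In multi-precision mode, \theta is read from and written back to the FP32 master copy, and a downcast copy is stored into the low-precision parameter.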

template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(
n, params_out.size(),
platform::errors::InvalidArgument(
"Output(ParamOut) number must be equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
}

auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));

auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"Input(Velocity) number must be equal to "
"Input(Param) number."));

auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}

auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
[Comment from Contributor]
Does this check imply that the Python side separately collects AMP and non-AMP parameters into distinct optimizers, and then passes each group in for its own optimizer computation?

[Reply from Collaborator (Author)]
Yes. Separating AMP from non-AMP parameters was chosen for several reasons:

  • OperatorWithKernel::GetExpectedKernelType is much easier to write. Otherwise it would have to traverse the types of all Params: with mixed types (FP16 and FP32) it would have to dispatch to the FP16 kernel, and with a single type to that type's kernel. That would be very cumbersome to implement.
  • sizeof(MergedMomentumKernelParam) stays smaller, since the struct does not need to carry a bool multi_precision[N], and its implementation avoids if-else checks for mixed types scattered everywhere.
  • With mixed types, the params and grads members of MergedMomentumKernelParam could only be declared as void *params[N] and void *grads[N], which makes the code much less readable (see the sketch after this list).
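To make the last point concrete, a hypothetical sketch (not code from this PR) of the rejected mixed-type layout:

#include <cstddef>
#include <cstdint>

// Hypothetical: if FP16 and FP32 params could share one op instance, the
// element type is no longer a single template parameter, so the pointers
// must be type-erased and every slot needs a runtime type tag.
template <uint32_t N>
struct MixedTypeKernelParam {
  void *params[N];       // float16* or float*, unknown at compile time
  const void *grads[N];
  bool is_fp16[N];       // runtime tag; also adds N bytes to sizeof
  std::size_t sizes[N];
};

Splitting AMP and non-AMP parameters on the Python side keeps the element type a single compile-time template parameter, so the tag array and the per-element branches disappear.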

if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params.clear();
master_params_out.clear();
}

auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
using MPType = typename operators::details::MPTypeTrait<T>::Type;

auto &dev_ctx = ctx.template device_context<DeviceContext>();

#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}

if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}

#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
}
};

} // namespace operators
} // namespace paddle
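The launch macro in merged_momentum_op.h splits the n merged parameters into launches of at most kMaxMergedNum tensors via ceiling division. A small standalone example of that arithmetic (the values are illustrative, not from the PR):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Worked example of the macro's chunking: n = 250 merged params with a
// capacity of kMaxMergedNum = 110 (the no-master-param case) needs 3 launches.
int main() {
  const std::size_t n = 250, kMaxMergedNum = 110;
  const std::size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;  // ceil(250/110) = 3
  for (std::size_t i = 0; i < kernel_num; ++i) {
    const std::size_t start = i * kMaxMergedNum;
    const std::size_t end = std::min((i + 1) * kMaxMergedNum, n);
    std::printf("launch %zu covers params [%zu, %zu)\n", i, start, end);  // [0,110), [110,220), [220,250)
  }
  return 0;
}

Each launch then sizes its grid by the largest tensor in the chunk (max_size) and relies on the per-tensor bounds check inside the functor to skip out-of-range indices.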