multi precision for lars op and lars optimizer (#33280)
FeixLiu authored Jun 3, 2021
1 parent fc5b3a9 commit 4d805e6
Showing 6 changed files with 271 additions and 55 deletions.
14 changes: 14 additions & 0 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
100755 → 100644
@@ -34,13 +34,18 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
@@ -51,6 +56,15 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);

AddComment(R"DOC(
Lars Momentum Optimizer.
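For orientation, here is a minimal sketch of how the new attributes are meant to be driven from Python, mirroring the LarsMomentumOptimizer call added to the test further down. Treat it as an assumption about the eventual API surface: only multi_precision appears in the diff shown here, and the rescale_grad keyword in particular is inferred from the op attribute of the same name.

# Hedged sketch: exercising the new attrs through the Python optimizer.
# `multi_precision` is taken from the test change below; `rescale_grad`
# as a constructor argument is an assumption based on the op attribute.
import paddle.fluid as fluid

optimizer = fluid.optimizer.LarsMomentumOptimizer(
    learning_rate=0.001,
    momentum=0.9,                # op attr `mu`
    lars_coeff=0.001,            # op attr `lars_coeff`
    lars_weight_decay=0.0005,    # op attr `lars_weight_decay`
    multi_precision=True,        # keep FP32 master weights for FP16 params
    rescale_grad=1.0 / 32,       # typically 1.0 / batch_size
)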
119 changes: 89 additions & 30 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -13,55 +13,105 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, const T lars_coeff,
const T lars_weight_decay, const T* p_norm,
const T* g_norm, T* p_out, T* v_out,
const T epsilon) {
T lr = learning_rate[0];
T local_lr = learning_rate[0];
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;

template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* p, const T* g, const MT* v,
const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
const MT lars_coeff, const MT lars_weight_decay,
const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
const MultiPrecisionType<T> rescale_grad) {
const MT lr = static_cast<MT>(learning_rate[0]);
MT local_lr = lr;
const MT p_n = static_cast<MT>(p_norm[0]);
const MT g_n = static_cast<MT>(g_norm[0]);

if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
g_n > static_cast<MT>(0)) {
local_lr =
lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
}
CUDA_KERNEL_LOOP(i, num) {
if (lars_weight_decay > 0 && p_norm[0] > 0 && g_norm[0] > 0) {
local_lr = lr * lars_coeff * p_norm[0] /
(g_norm[0] + lars_weight_decay * p_norm[0] + epsilon);
}
MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);

MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
MT p_new = param - v_new;

T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
v_out[i] = v_new;
p_out[i] = p[i] - v_new;
p_out[i] = static_cast<T>(p_new);
if (master_p_out) master_p_out[i] = p_new;
}
}
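To make the per-element arithmetic of the new kernel easier to follow, here is a NumPy reference of what a single launch computes, with the precomputed norms passed in the way the launch below passes p_norm_data and g_norm_data. It is an illustrative sketch, not Paddle code; argument names follow the kernel parameters.

# NumPy reference for one MomentumLarsKernel launch (illustrative sketch).
# `v` and `master_p` are expected in the compute dtype (FP32 when
# multi_precision is on); `p` and `g` may be FP16.
import numpy as np

def lars_momentum_reference(p, g, v, lr, mu, lars_coeff, lars_weight_decay,
                            p_norm, g_norm, epsilon, rescale_grad,
                            master_p=None):
    compute_dtype = np.float32 if master_p is not None else p.dtype

    # local_lr is the LARS trust ratio, identical for every element.
    local_lr = lr
    if lars_weight_decay > 0 and p_norm > 0 and g_norm > 0:
        local_lr = lr * lars_coeff * p_norm / (
            g_norm + lars_weight_decay * p_norm + epsilon)

    grad = g.astype(compute_dtype) * rescale_grad
    param = master_p if master_p is not None else p.astype(compute_dtype)

    v_new = v * mu + local_lr * (grad + lars_weight_decay * param)
    p_new = param - v_new

    p_out = p_new.astype(p.dtype)      # ParamOut is written in the param dtype
    master_p_out = p_new if master_p is not None else None
    return p_out, v_new, master_p_out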

template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;

public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPDType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
}

private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision) const {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");

const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}

const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
MT* master_p_out = multi_precision
? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;

T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());

T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
T epsilon = ctx.Attr<float>("epsilon");
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT lars_weight_decay =
static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MPDType rescale_grad =
static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));

auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
auto* v = velocity->data<MT>();
auto* lr = learning_rate->data<MPDType>();

int block = 512;
int grid = (param->numel() + block - 1) / block;
@@ -72,17 +122,24 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);

auto* place = ctx.template device_context<DeviceContext>().eigen_device();
ep_norm.device(*place) = eigen_p.square().sum().sqrt();
eg_norm.device(*place) = eigen_g.square().sum().sqrt();
MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(

// Eigen does not support an FP16 l2-norm, so compute the norms in MPDType.
ep_norm.device(*place) =
eigen_p.template cast<MPDType>().square().sum().sqrt();
eg_norm.device(*place) =
(eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();

MomentumLarsKernel<
T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out, epsilon);
p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
rescale_grad);
}
};
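The host-side flow of InnerCompute can be summarised the same way: the two L2 norms are reduced in the FP32 MPDType (per the comment above about Eigen and FP16), the gradient norm already includes rescale_grad, and the element-wise update sketched after the kernel is then applied. The wrapper below is hypothetical and reuses the lars_momentum_reference sketch from above.

# Hypothetical wrapper mirroring InnerCompute (sketch, not Paddle API).
import numpy as np

def lars_momentum_update(p, g, v, lr, mu, lars_coeff, lars_weight_decay,
                         epsilon, rescale_grad, multi_precision,
                         master_p=None):
    # Norms are always reduced in FP32, matching the MPDType casts above.
    p32 = p.astype(np.float32)
    g32 = g.astype(np.float32) * rescale_grad
    p_norm = np.sqrt(np.sum(np.square(p32)))
    g_norm = np.sqrt(np.sum(np.square(g32)))

    return lars_momentum_reference(
        p, g, v, lr, mu, lars_coeff, lars_weight_decay,
        p_norm, g_norm, epsilon, rescale_grad,
        master_p=master_p if multi_precision else None)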

@@ -93,4 +150,6 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lars_momentum,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
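With float16 added to the registration above, the kernel covers three parameter dtypes. Combined with the multi_precision dispatch in LarsMomentumOpCUDAKernel, the effective compute type MT works out roughly as in this illustrative mapping (a sketch, not generated from the code):

# (param dtype, multi_precision) -> compute dtype MT used by the kernel
kernel_compute_dtype = {
    ("float32", False): "float32",
    ("float32", True):  "float32",  # MPDType of float is float
    ("float64", False): "float64",
    ("float16", False): "float16",  # velocity and update stay in FP16
    ("float16", True):  "float32",  # FP32 master weights, velocity, update
}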
3 changes: 3 additions & 0 deletions paddle/fluid/operators/optimizers/momentum_op.h
@@ -135,6 +135,9 @@ class MomentumOp : public framework::OperatorWithKernel {

ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
if (ctx->HasOutput("MasterParamOut")) {
ctx->SetOutputDim("MasterParamOut", param_dim);
}
}

framework::OpKernelType GetExpectedKernelType(
@@ -73,7 +73,7 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride):
return pool


def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 32
@@ -96,12 +96,17 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
# Test program
test_program = train_program.clone(for_test=True)

if use_adam:
if optimizer == "Adam":
optimizer = paddle.optimizer.AdamW(
learning_rate=0.001,
epsilon=1e-8,
weight_decay=0.0,
multi_precision=True)
elif optimizer == "Lars":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001,
momentum=0.9,
multi_precision=use_pure_fp16)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
@@ -169,9 +174,11 @@ def test_resnet_pure_fp16(self):
if not fluid.core.is_compiled_with_cuda():
return

def do_test(use_nesterov=False, use_adam=False):
if use_adam:
def do_test(use_nesterov=False, optimizer=""):
if optimizer == "Adam":
suffix = "use Adam"
elif optimizer == "Lars":
suffix = "use Lars"
else:
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
@@ -180,14 +187,14 @@ def do_test(use_nesterov=False, use_adam=False):
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
optimizer=optimizer)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)
optimizer=optimizer)

self.assertTrue(
np.allclose(
@@ -208,7 +215,8 @@ def do_test(use_nesterov=False, use_adam=False):

do_test(use_nesterov=False)
do_test(use_nesterov=True)
do_test(use_adam=True)
do_test(optimizer="Adam")
do_test(optimizer="Lars")

@contextlib.contextmanager
def scope_prog_guard(self):