xpu support amp (#33809)
taixiurong authored Jun 29, 2021
1 parent 0d3de8d commit 4d4fb66
Showing 12 changed files with 143 additions and 99 deletions.
12 changes: 5 additions & 7 deletions cmake/external/xpu.cmake
@@ -27,19 +27,17 @@ ELSEIF(WITH_CENTOS)
SET(XPU_XRE_DIR_NAME "xre-centos7_x86_64")
SET(XPU_XDNN_DIR_NAME "xdnn-centos7_x86_64")
SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")

ELSE ()
SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64")
SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
ENDIF()

IF(NOT XPU_BASE_URL)
SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527")
ENDIF()

SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_XRE_URL "${XPU_BASE_URL}/20210625/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/20210625/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_PACK_DEPENCE_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh" CACHE STRING "" FORCE)

SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
6 changes: 4 additions & 2 deletions paddle/fluid/imperative/amp_auto_cast.cc
@@ -33,7 +33,8 @@ AmpOperators::AmpOperators()
for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
bool supported = false;
for (auto& kernel_type : it->second) {
if (platform::is_gpu_place(kernel_type.first.place_) &&
if ((platform::is_gpu_place(kernel_type.first.place_) ||
platform::is_xpu_place(kernel_type.first.place_)) &&
kernel_type.first.data_type_ == fp16_dtype) {
supported = true;
}
@@ -91,7 +92,8 @@ inline std::string GetDtypeStr(

inline bool NeedCast(const std::shared_ptr<VarBase>& var) {
if (platform::is_gpu_place(var->Place()) ||
platform::is_cuda_pinned_place(var->Place())) {
platform::is_cuda_pinned_place(var->Place()) ||
platform::is_xpu_place(var->Place())) {
// CudaPinnedPlace is added for varbase created by dataloader
if (var->DataType() == framework::proto::VarType::FP32 ||
var->DataType() == framework::proto::VarType::FP16) {
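
The two hunks above widen AMP's device check from CUDA-only to CUDA or XPU, both when scanning registered kernels for FP16 support and when deciding whether a VarBase needs casting. Below is a minimal standalone sketch of that decision, using hypothetical Place/DType enums rather than the real Paddle types:

#include <iostream>

// Hypothetical stand-ins for paddle::platform places and VarType dtypes.
enum class Place { kCPU, kCUDA, kCUDAPinned, kXPU };
enum class DType { kFP16, kFP32, kFP64 };

// A variable is a casting candidate only if it lives on a device the AMP pass
// handles -- now including XPU -- and already holds an FP16 or FP32 payload.
bool NeedCastSketch(Place place, DType dtype) {
  const bool amp_place = place == Place::kCUDA ||
                         place == Place::kCUDAPinned || place == Place::kXPU;
  const bool amp_dtype = dtype == DType::kFP16 || dtype == DType::kFP32;
  return amp_place && amp_dtype;
}

int main() {
  std::cout << std::boolalpha
            << NeedCastSketch(Place::kXPU, DType::kFP32) << "\n"   // true after this commit
            << NeedCastSketch(Place::kCPU, DType::kFP32) << "\n";  // false: CPU vars stay untouched
}
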
15 changes: 1 addition & 14 deletions paddle/fluid/operators/cast_op_xpu.cc
@@ -23,21 +23,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
class XPUFPTypeTrait {
public:
using Type = T;
};

template <>
class XPUFPTypeTrait<platform::float16> {
public:
using Type = float16;
};

template <typename DeviceContext, typename InT>
class CastXPUKernel : public framework::OpKernel<InT> {
using XPUInTDType = typename XPUFPTypeTrait<InT>::Type;
using XPUInTDType = typename XPUTypeTrait<InT>::Type;

public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -49,7 +37,6 @@ class CastXPUKernel : public framework::OpKernel<InT> {
context.Attr<int>("out_dtype"));
auto* in_data = in->data<InT>();

// using XPUOutTDType = typename XPUFPTypeTrait<InT>::Type;
auto numel = in->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = -1;
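
This file previously carried its own XPUFPTypeTrait; the diff switches it to the shared XPUTypeTrait so every XPU kernel maps paddle::platform::float16 to the device-side fp16 type in the same way. A self-contained sketch of that trait pattern, with placeholder struct names (FrameworkFloat16 / DeviceFloat16 are assumptions, not the real definitions):

#include <cstdint>
#include <type_traits>

struct FrameworkFloat16 { uint16_t bits; };  // stand-in for platform::float16
struct DeviceFloat16    { uint16_t bits; };  // stand-in for the XPU runtime fp16 type

// Default: a type maps to itself.
template <typename T>
struct XPUTypeTraitSketch { using Type = T; };

// Specialization: the framework fp16 wrapper maps to the device fp16 type.
template <>
struct XPUTypeTraitSketch<FrameworkFloat16> { using Type = DeviceFloat16; };

static_assert(std::is_same<XPUTypeTraitSketch<float>::Type, float>::value,
              "float passes through unchanged");
static_assert(std::is_same<XPUTypeTraitSketch<FrameworkFloat16>::Type,
                           DeviceFloat16>::value,
              "fp16 is rewritten to the device type");

int main() { return 0; }
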
81 changes: 49 additions & 32 deletions paddle/fluid/operators/matmul_op_xpu.cc
@@ -102,6 +102,7 @@ template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
bool trans_x, bool trans_y,
const paddle::framework::ExecutionContext &ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto &x_dims = x->dims();
const auto &y_dims = y->dims();
auto &dev_ctx =
@@ -162,34 +163,36 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
int ldout = n;
if (batch_size <= 1) {
int r = 0;
r = xpu::fc_fusion<T, T, T, FCT>(
dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy,
ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
reinterpret_cast<const XPUType *>(y->data<T>()),
reinterpret_cast<XPUType *>(data_c), m, n, k, mat_dim_a.trans_,
mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha, 0,
nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU fc_fusion kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
} else {
// batch matmul
int r = xpu::fc_batched<T, T, T, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const T *>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const T *>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<T *>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
alpha, // float alpha,
reinterpret_cast<const XPUType *>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType *>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType *>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr

PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
Expand All @@ -210,10 +213,14 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace());
bool trans_x = context.Attr<bool>("transpose_X");
bool trans_y = context.Attr<bool>("transpose_Y");
if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
} else {
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
} else {
if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
} else {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
}
}
}
};
@@ -224,6 +231,7 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext &context, const framework::Tensor &input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
@@ -236,8 +244,9 @@ static framework::Tensor XPUFoldHeadAndLastDims(
static_cast<int>(in_dims[1]),
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};
int r = xpu::transpose(context.x_context(), input.data<T>(), output.data<T>(),
in_shape_host, axis_host);
int r = xpu::transpose(
context.x_context(), reinterpret_cast<const XPUType *>(input.data<T>()),
reinterpret_cast<XPUType *>(output.data<T>()), in_shape_host, axis_host);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]", r,
@@ -280,10 +289,14 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
const framework::Tensor &b, bool trans_b,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
} else {
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
} else {
if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
}
}
}

@@ -370,10 +383,14 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_XPU_KERNEL(
matmul, ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, float>);
matmul, ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
REGISTER_OP_XPU_KERNEL(
matmul_grad,
ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext,
plat::float16>);
#endif
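
The kernel entry points above now branch on the element type before consulting the XPU_PADDLE_MAT_MUL_FCINT32 environment variable: float16 inputs always take the int16 FC accumulation path, while float inputs can still opt into int32. A hedged sketch of that dispatch, where float16 and Run<> are placeholders for platform::float16 and MatMulXPUFunction<T, FCT>:

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <type_traits>

struct float16 { uint16_t bits; };  // placeholder for paddle::platform::float16

template <typename T, typename FCT>
void Run() {  // placeholder for MatMulXPUFunction<T, FCT>
  std::cout << "FC compute type width: " << sizeof(FCT) << " bytes\n";
}

template <typename T>
void Dispatch() {
  if (std::is_same<float16, T>::value) {
    Run<T, int16_t>();  // FP16 always accumulates through the int16 path
  } else if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
    Run<T, int32_t>();  // explicit opt-in for FP32 inputs
  } else {
    Run<T, int16_t>();  // default for FP32 inputs
  }
}

int main() {
  Dispatch<float>();    // int16 unless XPU_PADDLE_MAT_MUL_FCINT32 is set
  Dispatch<float16>();  // always int16
}
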
79 changes: 47 additions & 32 deletions paddle/fluid/operators/matmul_v2_op_xpu.cc
@@ -25,6 +25,7 @@ template <typename T, typename FCT>
static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
bool trans_x, bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x->dims();
const auto& y_dims = y->dims();
auto& dev_ctx =
@@ -75,9 +76,11 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
int batch_size = mat_dim_a.batch_size_;
if (batch_size <= 1) {
int r = 0;
r = xpu::fc<T, T, T, FCT>(dev_ctx.x_context(), x->data<T>(), y->data<T>(),
data_c, m, n, k, mat_dim_a.trans_,
mat_dim_b.trans_, nullptr, nullptr, nullptr);
r = xpu::fc<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x->data<T>()),
reinterpret_cast<const XPUType*>(y->data<T>()),
reinterpret_cast<XPUType*>(data_c), m, n, k, mat_dim_a.trans_,
mat_dim_b.trans_, nullptr, nullptr, nullptr);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
@@ -87,24 +90,24 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_));
} else {
// batch matmul
int r = xpu::fc_batched<T, T, T, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
1.0, // float alpha,
reinterpret_cast<const T*>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const T*>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<T*>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr
int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
dev_ctx.x_context(), // Context* ctx,
batch_size, // int batch_size,
mat_dim_a.trans_, // bool x_trans,
mat_dim_b.trans_, // bool w_trans,
m, // int m,
n, // int n,
k, // int k,
1.0, // float alpha,
reinterpret_cast<const XPUType*>(x->data<T>()), // const TX* x,
mat_dim_a.stride_, // int stride_a,
reinterpret_cast<const XPUType*>(y->data<T>()), // const TW* w,
mat_dim_b.stride_, // int stride_b,
0.0, // float beta,
reinterpret_cast<XPUType*>(data_c), // TY* y,
m * n, // int stride_c,
nullptr, // const float* x_maxptr,
nullptr); // const float* w_maxptr

PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
Expand All @@ -123,17 +126,22 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
out->mutable_data<T>(ctx.GetPlace());
if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
} else {
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
} else {
if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
} else {
MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
}
}
}
};

template <typename DeviceContext, typename T>
static framework::Tensor XPUFoldHeadAndLastDims(
const DeviceContext& context, const framework::Tensor& input) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
@@ -147,8 +155,9 @@ static framework::Tensor XPUFoldHeadAndLastDims(
static_cast<int>(in_dims[2])};
std::vector<int> axis_host = {1, 0, 2};

int r = xpu::transpose(context.x_context(), input.data<T>(), output.data<T>(),
in_shape_host, axis_host);
int r = xpu::transpose(
context.x_context(), reinterpret_cast<const XPUType*>(input.data<T>()),
reinterpret_cast<XPUType*>(output.data<T>()), in_shape_host, axis_host);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU transpose kernel return wrong value[%d %s]", r,
@@ -166,10 +175,14 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(ctx.GetPlace());
if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
} else {
if (std::is_same<paddle::platform::float16, T>::value) {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
} else {
if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) {
MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
} else {
MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
}
}
}

@@ -261,8 +274,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel<float>);
REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel<float>);
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel<float>,
ops::MatMulV2XPUKernel<plat::float16>);
REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel<float>,
ops::MatMulV2XPUGradKernel<plat::float16>);

#endif
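
Both matmul files pass FP16 buffers to XDNN by reinterpret_cast-ing the framework pointer to the device type produced by XPUTypeTrait. That only works because the two fp16 representations are identically laid-out 2-byte objects; the sketch below illustrates the idea with stand-in structs (HostFloat16 / DeviceFloat16 are assumptions, not the real Paddle or XDNN types):

#include <cassert>
#include <cstdint>
#include <cstring>

struct HostFloat16   { uint16_t bits; };  // stand-in for platform::float16
struct DeviceFloat16 { uint16_t bits; };  // stand-in for the device fp16 type

static_assert(sizeof(HostFloat16) == sizeof(DeviceFloat16), "same size");
static_assert(alignof(HostFloat16) == alignof(DeviceFloat16), "same alignment");

int main() {
  HostFloat16 buf[4] = {{1}, {2}, {3}, {4}};
  // Same pointer reinterpretation the kernels perform before calling into XDNN.
  const DeviceFloat16* view = reinterpret_cast<const DeviceFloat16*>(buf);
  DeviceFloat16 third;
  std::memcpy(&third, &view[2], sizeof(third));  // copy out to sidestep aliasing pitfalls
  assert(third.bits == 3);
  return 0;
}
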
4 changes: 2 additions & 2 deletions paddle/fluid/operators/softmax_op_xpu.cc
@@ -47,8 +47,8 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
int len = x->numel();
T* clip_x_data =
clip_x.mutable_data<T>(context.GetPlace(), len * sizeof(T));
r = xpu::clip(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
-1e30, 1e30);
r = xpu::clip_v2(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
static_cast<float>(-1e20), static_cast<float>(1e20));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(clip) return wrong "
"value[%d %s]",
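
The switch from xpu::clip to xpu::clip_v2 adds explicit static_cast<float> on the bounds. If the clip entry point is templated on the element type (which the casts suggest), a bare -1e20 literal is a double and would make template argument deduction conflict with the float* data pointer. A small sketch of that pitfall, where ClipV2Sketch is a stand-in and not the real XDNN signature:

#include <iostream>

template <typename T>
int ClipV2Sketch(const T* x, T* y, int len, T min_val, T max_val) {
  for (int i = 0; i < len; ++i) {
    y[i] = x[i] < min_val ? min_val : (x[i] > max_val ? max_val : x[i]);
  }
  return 0;  // mirrors the XPU convention of returning a status code
}

int main() {
  float in[3] = {-1e25f, 0.5f, 1e25f};
  float out[3];
  // ClipV2Sketch(in, out, 3, -1e20, 1e20);  // error: T deduced as both float and double
  ClipV2Sketch(in, out, 3, static_cast<float>(-1e20),
               static_cast<float>(1e20));    // deduction agrees on float
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";
}
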
5 changes: 3 additions & 2 deletions paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
@@ -54,8 +54,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
int len = logits->numel();
T* clip_logits_data =
clip_logits.mutable_data<T>(context.GetPlace(), len * sizeof(T));
r = xpu::clip(dev_ctx.x_context(), logits->data<float>(), clip_logits_data,
len, -1e30, 1e30);
r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(),
clip_logits_data, len, static_cast<float>(-1e20),
static_cast<float>(1e20));
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. clip "