[XPU] fix performance problem of stage1_overlap (PaddlePaddle#63920)
lj970926 authored and co63oc committed May 6, 2024
1 parent 488b0eb commit 39daa11
Showing 4 changed files with 64 additions and 87 deletions.
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -386,12 +386,14 @@ XPUOpMap& get_kl2_ops() {
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT64,
                      phi::DataType::FLOAT32})},
       {"fill_any",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT64,
                      phi::DataType::FLOAT32})},
       {"fill_any_like",
@@ -430,12 +432,14 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT32})},
       {"flatten_contiguous_range",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT32})},
       {"flatten_grad",
        XPUKernelSet({phi::DataType::INT64,
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -390,12 +390,14 @@ XPUOpMap& get_kl3_ops() {
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT64,
                      phi::DataType::FLOAT32})},
       {"fill_any",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT64,
                      phi::DataType::FLOAT32})},
       {"fill_any_like",
@@ -448,12 +450,14 @@ XPUOpMap& get_kl3_ops() {
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT32})},
       {"flatten_contiguous_range",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT16,
+                     phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT32})},
       {"flatten_grad",
        XPUKernelSet({phi::DataType::INT64,
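Together, the two op-list edits register BFLOAT16 for the kernel sets shown — including fill_any and flatten_contiguous_range — on both KL2 and KL3 devices, so bf16 tensors no longer fall into an unsupported-dtype path. A minimal usage sketch in Python (hypothetical: it assumes an XPU build of Paddle with device xpu:0 visible and bf16 tensor creation available in your version):

import paddle

paddle.set_device("xpu:0")  # assumption: an XPU device is available

# With BFLOAT16 in the fill/fill_any kernel sets, an in-place fill on a bf16
# tensor can dispatch to the native XPU kernel.
x = paddle.full([20, 30], -1.0, dtype="bfloat16")
x.fill_(1.0)

# flatten_contiguous_range (paddle.flatten) also accepts bf16 after this change.
y = paddle.flatten(x, start_axis=0, stop_axis=-1)
print(y.shape, y.dtype)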
46 changes: 12 additions & 34 deletions paddle/phi/kernels/funcs/math_function.h
@@ -121,50 +121,28 @@ struct TensorSetConstantXPU {
     auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
     auto begin = ctx->Alloc<T>(tensor_);
     int numel = tensor_->numel();
-    std::unique_ptr<T[]> data_cpu(new T[numel]);
-    std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
-    memory_utils::Copy(place_,
-                       begin,
-                       phi::CPUPlace(),
-                       static_cast<void*>(data_cpu.get()),
-                       numel * sizeof(T));
-  }
-  phi::DenseTensor* tensor_;
-  U value_;
-  phi::Place place_;
-};
-
-template <>
-struct TensorSetConstantXPU<float> {
-  TensorSetConstantXPU(phi::DenseTensor* tensor, float value, phi::Place place)
-      : tensor_(tensor), value_(value), place_(place) {}
-  template <typename T>
-  void apply() const {
-    auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
-    auto begin = ctx->Alloc<T>(tensor_);
-    int numel = tensor_->numel();
-    if ((std::is_same<T, float>::value) ||
-        (std::is_same<T, phi::dtype::bfloat16>::value) ||
-        (std::is_same<T, phi::dtype::float16>::value)) {
-      using XPUType = typename XPUTypeTrait<T>::Type;
-      auto* dev_ctx = static_cast<phi::XPUContext*>(ctx);
-      int r = xpu::constant<XPUType>(dev_ctx->x_context(),
-                                     reinterpret_cast<XPUType*>(begin),
-                                     numel,
-                                     static_cast<XPUType>(value_));
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-    } else {
+    if (std::is_same<T, phi::dtype::complex<float>>::value ||
+        std::is_same<T, phi::dtype::complex<double>>::value) {
       std::unique_ptr<T[]> data_cpu(new T[numel]);
       std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
       memory_utils::Copy(place_,
                          begin,
                          phi::CPUPlace(),
                          static_cast<void*>(data_cpu.get()),
                          numel * sizeof(T));
+    } else {
+      auto* dev_ctx = static_cast<phi::XPUContext*>(ctx);
+      using XPUType = typename XPUTypeTrait<T>::Type;
+      T val = static_cast<T>(value_);
+      int r = xpu::constant<XPUType>(dev_ctx->x_context(),
+                                     reinterpret_cast<XPUType*>(begin),
+                                     numel,
+                                     static_cast<XPUType>(val));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
     }
   }
   phi::DenseTensor* tensor_;
-  float value_;
+  U value_;
   phi::Place place_;
 };
 #endif
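This math_function.h rewrite is where the performance fix lives: the old generic TensorSetConstantXPU filled a host-side buffer and issued a host-to-device copy for every constant fill (only the deleted float specialization used xpu::constant), while the merged version fills on device via xpu::constant for every dtype except the complex types. A rough micro-benchmark sketch of the affected pattern (hypothetical: it assumes an XPU build, an xpu:0 device, and that repeated in-place fills route through this constant-fill path):

import time

import paddle

paddle.set_device("xpu:0")  # assumption: an XPU device is available
x = paddle.empty([4096, 4096], dtype="float32")

start = time.perf_counter()
for _ in range(100):
    x.fill_(0.0)  # now a device-side constant fill rather than a CPU buffer + H2D copy
float(x.sum())  # read one value back so queued XPU work finishes before timing
print(f"100 constant fills: {time.perf_counter() - start:.4f}s")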
97 changes: 44 additions & 53 deletions test/xpu/test_fill_any_op_xpu.py
@@ -20,6 +20,7 @@
     create_test_class,
     get_xpu_op_support_types,
 )
+from op_test import convert_float_to_uint16
 from op_test_xpu import XPUOpTest

 import paddle
@@ -35,15 +36,17 @@ def __init__(self):
     class TestFillAnyOp(XPUOpTest):
         def setUp(self):
             self.op_type = "fill_any"
-            self.dtype = 'float64'
+            self.dtype = self.in_type
             self.value = 0.0
             self.init()
             self.inputs = {'X': np.random.random((20, 30)).astype(self.dtype)}
             self.attrs = {'value': float(self.value)}
-            self.outputs = {
-                'Out': self.value
-                * np.ones_like(self.inputs["X"]).astype(self.dtype)
-            }
+            out_np = self.value * np.ones_like(self.inputs["X"])
+            if self.dtype == np.uint16:
+                out_np = convert_float_to_uint16(out_np)
+            else:
+                out_np = out_np.astype(self.dtype)
+            self.outputs = {'Out': out_np}

         def init(self):
             pass
@@ -54,61 +57,49 @@ def test_check_output(self):
     def test_check_grad(self):
         self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')

-    class TestFillAnyOpFloat32(TestFillAnyOp):
-        def init(self):
-            self.dtype = np.float32
-            self.value = 0.0
-
-    class TestFillAnyOpFloat16(TestFillAnyOp):
-        def init(self):
-            self.dtype = np.float16
-
     class TestFillAnyOpvalue1(TestFillAnyOp):
         def init(self):
-            self.dtype = np.float32
-            self.value = 111111555
+            self.value = 11555

     class TestFillAnyOpvalue2(TestFillAnyOp):
         def init(self):
-            self.dtype = np.float32
             self.value = 11111.1111

-    class TestFillAnyInplace(unittest.TestCase):
-        def test_fill_any_version(self):
-            with paddle.base.dygraph.guard():
-                var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
-                self.assertEqual(var.inplace_version, 0)
-
-                var.fill_(0)
-                self.assertEqual(var.inplace_version, 1)
-
-                var.fill_(0)
-                self.assertEqual(var.inplace_version, 2)
-
-                var.fill_(0)
-                self.assertEqual(var.inplace_version, 3)
-
-        def test_fill_any_equal(self):
-            with paddle.base.dygraph.guard():
-                tensor = paddle.to_tensor(
-                    np.random.random((20, 30)).astype(np.float32)
-                )
-                target = tensor.numpy()
-                target[...] = 1
-
-                tensor.fill_(1)
-                self.assertEqual((tensor.numpy() == target).all().item(), True)
-
-        def test_backward(self):
-            with paddle.base.dygraph.guard():
-                x = paddle.full([10, 10], -1.0, dtype='float32')
-                x.stop_gradient = False
-                y = 2 * x
-                y.fill_(1)
-                y.backward()
-                np.testing.assert_array_equal(
-                    x.grad.numpy(), np.zeros([10, 10])
-                )
+
+class TestFillAnyInplace(unittest.TestCase):
+    def test_fill_any_version(self):
+        with paddle.base.dygraph.guard():
+            var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
+            self.assertEqual(var.inplace_version, 0)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 1)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 2)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 3)
+
+    def test_fill_any_equal(self):
+        with paddle.base.dygraph.guard():
+            tensor = paddle.to_tensor(
+                np.random.random((20, 30)).astype(np.float32)
+            )
+            target = tensor.numpy()
+            target[...] = 1
+
+            tensor.fill_(1)
+            self.assertEqual((tensor.numpy() == target).all().item(), True)
+
+    def test_backward(self):
+        with paddle.base.dygraph.guard():
+            x = paddle.full([10, 10], -1.0, dtype='float32')
+            x.stop_gradient = False
+            y = 2 * x
+            y.fill_(1)
+            y.backward()
+            np.testing.assert_array_equal(x.grad.numpy(), np.zeros([10, 10]))


 class TestFillAnyLikeOpSpecialValue(unittest.TestCase):
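On the test side, TestFillAnyOp now takes its dtype from the wrapper's in_type instead of a hard-coded float64, so the generated cases cover the newly registered types. Because OpTest represents bfloat16 arrays as uint16 bit patterns, the expected output is built with convert_float_to_uint16 rather than a plain astype. A small sketch of that helper's role (hypothetical values; assumes Paddle's op_test helpers are importable):

import numpy as np

from op_test import convert_float_to_uint16

value = 11555.0
x = np.random.random((20, 30)).astype(np.float32)
expected = value * np.ones_like(x)

# bf16 expectations are stored as uint16 in OpTest, so convert instead of astype.
expected_bf16 = convert_float_to_uint16(expected)
print(expected_bf16.dtype)  # uint16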
