[Hackathon No.48] Implement float16 data type support for the Paddle assign_value, meshgrid, kthvalue, and determinant operators #52046

Closed
wants to merge 10 commits
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/static_ops.yaml
@@ -48,7 +48,7 @@
backward : assign_grad

- op : assign_value
args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, float[] fp16_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
output : Tensor(out)
infer_meta :
func : AssignValueInferMeta
1 change: 1 addition & 0 deletions paddle/phi/kernels/assign_kernel.cc
@@ -157,6 +157,7 @@ PD_REGISTER_KERNEL(assign_value,
phi::AssignValueKernel,
bool,
int,
phi::dtype::float16,
float,
int64_t) {}
#endif
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/determinant_grad_kernel.cu
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(determinant_grad,
GPU,
ALL_LAYOUT,
phi::DeterminantGradKernel,
phi::dtype::float16,
float,
double) {}
9 changes: 7 additions & 2 deletions paddle/phi/kernels/gpu/determinant_kernel.cu
@@ -17,5 +17,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"

PD_REGISTER_KERNEL(
determinant, GPU, ALL_LAYOUT, phi::DeterminantKernel, float, double) {}
PD_REGISTER_KERNEL(determinant,
GPU,
ALL_LAYOUT,
phi::DeterminantKernel,
phi::dtype::float16,
float,
double) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid_grad,
GPU,
ALL_LAYOUT,
phi::MeshgridGradKernel,
phi::dtype::float16,
float,
double,
int,
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid,
GPU,
ALL_LAYOUT,
phi::MeshgridKernel,
phi::dtype::float16,
float,
double,
int,
32 changes: 26 additions & 6 deletions paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
@@ -15,8 +15,10 @@
#pragma once

#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"

#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/determinant_grad_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
@@ -26,7 +28,6 @@
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
namespace detail {

@@ -113,6 +114,11 @@ void DeterminantGradKernel(const Context& dev_ctx,
return;
}

using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto origin_dt = std::is_same<phi::dtype::float16, T>::value
? DataType::FLOAT16
: DataType::BFLOAT16;

// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
@@ -123,16 +129,22 @@
DenseTensor inverse_A;
// A must be square matrices!
inverse_A.Resize(x.dims());
dev_ctx.template Alloc<T>(&inverse_A);
dev_ctx.template Alloc<MPType>(&inverse_A);

phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
mat_inv(dev_ctx, x, &inverse_A);
phi::funcs::MatrixInverseFunctor<Context, MPType> mat_inv;
if (!std::is_same<MPType, T>::value) {
mat_inv(dev_ctx,
phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32),
&inverse_A);
} else {
mat_inv(dev_ctx, x, &inverse_A);
}

VLOG(3) << "inverse(A) dims: " << inverse_A.dims();

// Second: inverse(A).transpose(-2, -1)
DenseTensor transpose_inverse_A =
phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
phi::TransposeLast2Dim<MPType>(dev_ctx, inverse_A);

VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
@@ -147,7 +159,15 @@
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();

// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
DenseTensor res;
if (!std::is_same<MPType, T>::value) {
res = phi::Multiply<T>(
dev_ctx,
unsqueeze2,
phi::Cast<MPType, Context>(dev_ctx, transpose_inverse_A, origin_dt));
} else {
res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
}

VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();

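For reference, the identity this kernel implements, d|A|/dA = |A| * inverse(A)^T, is easy to check numerically. A minimal numpy sketch (illustration only, not part of this diff; float64 is used here for the same stability reason the kernel promotes float16 to float32 via MPTypeTrait):

import numpy as np

np.random.seed(0)
A = np.random.rand(4, 4)
det_A = np.linalg.det(A)
analytic = det_A * np.linalg.inv(A).T  # the closed form the kernel uses

# Compare against a forward finite difference on each entry of A.
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(4):
    for j in range(4):
        A_eps = A.copy()
        A_eps[i, j] += eps
        numeric[i, j] = (np.linalg.det(A_eps) - det_A) / eps

assert np.allclose(analytic, numeric, atol=1e-4)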
12 changes: 11 additions & 1 deletion paddle/phi/kernels/impl/determinant_kernel_impl.h
@@ -21,6 +21,7 @@
#include <vector>

#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_utils.h"
@@ -31,6 +32,13 @@ namespace detail {
template <typename T>
class EigenMatrix {};

template <>
class EigenMatrix<phi::dtype::float16> {
public:
using MatrixType =
Eigen::Matrix<phi::dtype::float16, Eigen::Dynamic, Eigen::Dynamic>;
};

template <>
class EigenMatrix<float> {
public:
@@ -74,6 +82,7 @@ struct DeterminantFunctor {
std::vector<T> input_vec;
std::vector<T> output_vec;
phi::TensorToVector(input, dev_ctx, &input_vec);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
@@ -85,7 +94,8 @@
matrix(i, j) = sub_vec[rank * i + j];
}
}
output_vec.push_back(matrix.determinant());
output_vec.push_back(
static_cast<T>(matrix.template cast<MPType>().determinant()));
}
phi::TensorFromVector(output_vec, dev_ctx, output);
}
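The cast to MPType before calling determinant() follows the same widen-compute-narrow pattern the new Python tests use. A rough numpy illustration (numpy's linalg routines likewise have no float16 path, which is why the tests call astype(np.float32) first):

import numpy as np

np.random.seed(0)
a = np.random.rand(10, 10).astype(np.float16)
# Widen to float32 for the computation, then narrow the result,
# mirroring matrix.template cast<MPType>().determinant() above.
det16 = np.float16(np.linalg.det(a.astype(np.float32)))
ref = np.linalg.det(a.astype(np.float64))
print(det16, ref)  # equal up to float16 rounding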
3 changes: 3 additions & 0 deletions paddle/phi/ops/compat/assign_value_sig.cc
@@ -36,6 +36,9 @@ KernelSignature AssignValueOpArgumentMapping(
} else if (dtype == /*INT64*/ 3) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"});
} else if (dtype == /*FP16*/ 4) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "fp16_values"}, {"Out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
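Taken together with the yaml signature and kernel registration above, this new FP16 branch completes the plumbing for float16 assign. A minimal usage sketch (illustration only, assuming a build that includes this PR):

import numpy as np
import paddle

# A float16 numpy input now maps onto the new fp16_values attribute
# and the float16 assign_value kernel registered above.
data = np.random.rand(2, 5).astype(np.float16)
out = paddle.assign(data)
print(out.dtype)  # paddle.float16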
20 changes: 19 additions & 1 deletion python/paddle/fluid/tests/unittests/test_assign_value_op.py
@@ -19,7 +19,7 @@

import paddle
from paddle import fluid
from paddle.fluid import framework
from paddle.fluid import core, framework


def assign_value_wrapper(
@@ -72,6 +72,16 @@ def init_data(self):
self.attrs["bool_values"] = [int(v) for v in self.value.flat]


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
Contributor:
Can this skip be removed? The unit-test framework can now automatically skip fp16 tests on devices that don't support them, and the other fp16 unit tests that follow don't add this decorator.

Contributor (Author):
Got it. You're right that the other fp16 unit tests didn't add this decorator, but this one didn't seem to skip properly and failed. I'll remove it and re-run CI to check.

class TestAssignValueOpFp16(TestAssignValueOp):
def init_data(self):
self.dtype = np.float16
self.value = np.random.random(size=(2, 5)).astype(self.dtype)
self.attrs["fp16_values"] = [float(v) for v in self.value.flat]


class TestAssignApi(unittest.TestCase):
def setUp(self):
with eager_op_test.paddle_static_guard():
@@ -128,5 +138,13 @@ def init_dtype(self):
self.dtype = "bool"


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
Contributor:
The decorator here should not be removed, because the base class inherits from unittest.TestCase, which has no automatic skip mechanism.

Contributor (Author):
done

class TestAssignApiFp16(TestAssignApi):
def init_dtype(self):
self.dtype = np.float16


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions python/paddle/fluid/tests/unittests/test_determinant_op.py
@@ -50,6 +50,14 @@ def init_data(self):
self.target = np.linalg.det(self.case)


class TestDeterminantOpCase1FP16(TestDeterminantOp):
def init_data(self):
np.random.seed(0)
self.case = np.random.rand(10, 10).astype(np.float16)
self.inputs = {'Input': self.case}
self.target = np.linalg.det(self.case.astype(np.float32))


class TestDeterminantOpCase2(TestDeterminantOp):
def init_data(self):
np.random.seed(0)
@@ -59,6 +67,17 @@ def init_data(self):
self.target = np.linalg.det(self.case)


class TestDeterminantOpCase2FP16(TestDeterminantOp):
def init_data(self):
np.random.seed(0)
# not invertible matrix
self.case = np.ones([4, 2, 4, 4]).astype(np.float16)
self.inputs = {'Input': self.case}
self.target = np.linalg.det(self.case.astype(np.float32)).astype(
np.float16
)


class TestDeterminantAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
24 changes: 20 additions & 4 deletions python/paddle/fluid/tests/unittests/test_kthvalue_op.py
@@ -40,11 +40,14 @@ def init_args(self):
self.k = 5
self.axis = -1

def init_dtype(self):
self.dtype = np.float64

def setUp(self):
self.op_type = "kthvalue"
self.python_api = paddle.kthvalue
self.dtype = np.float64
self.input_data = np.random.random((2, 1, 2, 4, 10))
self.init_dtype()
self.input_data = np.random.random((2, 1, 2, 4, 10)).astype(self.dtype)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
@@ -62,17 +65,25 @@
self.check_grad({'X'}, 'Out')


class TestKthvalueOpFp16(TestKthvalueOp):
def init_dtype(self):
self.dtype = np.float16


class TestKthvalueOpWithKeepdim(OpTest):
def init_args(self):
self.k = 2
self.axis = 1

def init_dtype(self):
self.dtype = np.float64

def setUp(self):
self.init_args()
self.init_dtype()
self.op_type = "kthvalue"
self.python_api = paddle.kthvalue
self.dtype = np.float64
self.input_data = np.random.random((1, 3, 2, 4, 10))
self.input_data = np.random.random((1, 3, 2, 4, 10)).astype(self.dtype)
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
output, indices = cal_kthvalue(
@@ -89,6 +100,11 @@
self.check_grad({'X'}, 'Out')


class TestKthvalueOpWithKeepdimFp16(TestKthvalueOpWithKeepdim):
def init_dtype(self):
self.dtype = np.float16


class TestKthvalueOpKernels(unittest.TestCase):
def setUp(self):
self.axises = [2, -1]
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_meshgrid_op.py
@@ -75,6 +75,14 @@ def get_x_shape(self):
return [100, 300]


class TestMeshgridOp2Fp16(TestMeshgridOp):
def get_x_shape(self):
return [100, 300]

def get_dtype(self):
return np.float16


class TestMeshgridOp3(unittest.TestCase):
def test_api(self):
x = paddle.static.data(shape=[100], dtype='int32', name='x')
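A short usage sketch for the newly registered dtype (illustration only, assuming a device with float16 support):

import numpy as np
import paddle

x = paddle.to_tensor(np.arange(3, dtype=np.float16))
y = paddle.to_tensor(np.arange(4, dtype=np.float16))
grid_x, grid_y = paddle.meshgrid(x, y)
print(grid_x.shape, grid_x.dtype)  # [3, 4], paddle.float16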
7 changes: 5 additions & 2 deletions python/paddle/tensor/creation.py
@@ -1513,7 +1513,7 @@ def meshgrid(*args, **kwargs):

Args:
*args(Tensor|list of Tensor) : tensors (tuple(list) of tensor): the shapes of input k tensors are (N1,),
(N2,),..., (Nk,). Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
(N2,),..., (Nk,). Support data types: ``float64``, ``float32``, ``float16``, ``int32``, ``int64``.
**kwargs (optional): Currently, only accept name in **kwargs
The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
@@ -2124,6 +2124,9 @@ def convert_scalar(x):
if dtype == core.VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [int(v) for v in input.flat]
elif dtype == core.VarDesc.VarType.FP16:
value_name = "fp16_values"
values = [float(v) for v in input.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in input.flat]
@@ -2136,7 +2139,7 @@
else:
raise TypeError(
"When the type of 'input' in assign is numpy.ndarray, "
"the data type of 'input' must be bool, float32, int32 or int64, but "
"the data type of 'input' must be bool, float16, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype)
)
if input.size > 1024 * 1024:
2 changes: 1 addition & 1 deletion python/paddle/tensor/linalg.py
@@ -1800,7 +1800,7 @@ def det(x, name=None):
if in_dygraph_mode():
return _C_ops.det(x)
else:
check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det')
check_dtype(x.dtype, 'Input', ['float16', 'float32', 'float64'], 'det')

input_shape = list(x.shape)
assert len(input_shape) >= 2, (
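With the relaxed check_dtype plus the GPU registrations above, float16 determinants become reachable from Python. A minimal sketch (assuming a CUDA device, since this PR registers the fp16 determinant kernels for GPU only):

import numpy as np
import paddle

paddle.set_device('gpu')
x = paddle.to_tensor(np.random.rand(3, 3).astype(np.float16))
print(paddle.linalg.det(x))  # computed through the float32 MPType path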
2 changes: 1 addition & 1 deletion python/paddle/tensor/search.py
@@ -1074,7 +1074,7 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
Find values and indices of the k-th smallest at the axis.

Args:
x(Tensor): A N-D Tensor with type float32, float64, int32, int64.
x(Tensor): A N-D Tensor with type float16, float32, float64, int32, int64.
k(int): The k for the k-th smallest number to look for along the axis.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
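A matching usage sketch for kthvalue (illustration only, assuming a device with float16 support):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(2, 4, 10).astype(np.float16))
values, indices = paddle.kthvalue(x, k=3, axis=-1)
print(values.dtype, indices.dtype)  # paddle.float16, paddle.int64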