Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim] reduce_as op support uint8, in8, complex64 and complex128 #63782

Merged
merged 9 commits into from
May 10, 2024
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/reduce_as_kernel.h"

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/reduce_grad.h"

namespace phi {
Expand Down Expand Up @@ -55,6 +55,8 @@ PD_REGISTER_KERNEL(reduce_as_grad,
int,
int64_t,
uint8_t,
int8_t) {
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/reduce_as_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

namespace phi {

Expand Down Expand Up @@ -48,4 +48,6 @@ PD_REGISTER_KERNEL(reduce_as,
int,
int64_t,
uint8_t,
int8_t) {}
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/reduce_as_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_grad_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce_grad.h"

Expand Down Expand Up @@ -65,6 +64,8 @@ PD_REGISTER_KERNEL(reduce_as_grad,
int,
int64_t,
uint8_t,
int8_t) {
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/reduce_as_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
Expand Down Expand Up @@ -47,4 +46,6 @@ PD_REGISTER_KERNEL(reduce_as,
int,
int64_t,
uint8_t,
int8_t) {}
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 12 additions & 4 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,8 +1583,8 @@ def reduce_as(x, target, name=None):
Computes the sum of tensor elements make the shape of its result equal to the shape of target.

Args:
x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int32 or int64.
target (Tensor): An N-D Tensor, the length of x shape must greater than or equal to the length of target shape. The data type is bool, float16, float32, float64, int32 or int64.
x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int8, uint8, int16, uint16, int32, int64, complex64 or complex128.
target (Tensor): An N-D Tensor, the length of x shape must greater than or equal to the length of target shape. The data type is bool, float16, float32, float64, int8, uint8, int16, uint16, int32, int64, complex64 or complex128.

Returns:
Tensor: The sum of the input tensor x along some axis has the same shape as the shape of the input tensor target, if `x.dtype='bool'`, `x.dtype='int32'`, it's data type is `'int64'`, otherwise it's data type is the same as `x`.
Expand Down Expand Up @@ -1617,13 +1617,17 @@ def reduce_as(x, target, name=None):
'x',
[
'bool',
'uint16',
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'uint16',
'int32',
'int64',
'complex64',
'complex128',
],
'reduce_as',
)
Expand All @@ -1632,13 +1636,17 @@ def reduce_as(x, target, name=None):
'target',
[
'bool',
'uint16',
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'uint16',
'int32',
'int64',
'complex64',
'complex128',
],
'reduce_as',
)
Expand Down
55 changes: 40 additions & 15 deletions test/deprecated/legacy_test/test_reduce_as_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,20 @@ def apply_to_static(net, use_cinn, input_spec=None):
)


class TestSumAsOp(OpTest):
class TestReduceAsOp(OpTest):
def setUp(self):
self.init_dtype()
self.init_shape()
self.init_input()
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.x = np.random.random(self.shape_x) + 1j * np.random.random(
self.shape_y
)
self.y = np.random.random(self.shape_x) + 1j * np.random.random(
self.shape_y
)
else:
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.y = np.random.random(self.shape_y).astype(self.dtype)
self.init_attrs()
self.calc_output()

Expand All @@ -60,10 +69,6 @@ def init_shape(self):
self.shape_x = [10, 10, 6]
self.shape_y = [10, 6]

def init_input(self):
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.y = np.random.random(self.shape_y).astype(self.dtype)

def init_attrs(self):
self.attrs = {'dim': [0]}

Expand All @@ -84,42 +89,62 @@ def test_check_grad(self):
)


class TestSumAsOp2(TestSumAsOp):
class TestReduceAsOp2(TestReduceAsOp):
def init_type(self):
self.dtype = 'float32'


class TestSumAsOp3(TestSumAsOp):
class TestReduceAsOp3(TestReduceAsOp):
def init_type(self):
self.dtype = 'float16'


class TestSumAsOp4(TestSumAsOp):
class TestReduceAsOp4(TestReduceAsOp):
def init_type(self):
self.dtype = 'uint16'


class TestSumAsOp5(TestSumAsOp):
class TestReduceAsOp5(TestReduceAsOp):
def init_type(self):
self.dtype = 'int16'


class TestSumAsOp6(TestSumAsOp):
class TestReduceAsOp6(TestReduceAsOp):
def init_type(self):
self.dtype = 'int64'


class TestSumAsOp7(TestSumAsOp):
class TestReduceAsOp7(TestReduceAsOp):
def init_type(self):
self.dtype = 'bool'


class TestSumAsOp8(TestSumAsOp):
class TestReduceAsOp8(TestReduceAsOp):
def init_type(self):
self.dtype = 'int32'


class TestSumAsOp9(TestSumAsOp):
class TestReduceAsOp9(TestReduceAsOp):
def init_type(self):
self.dtype = 'int8'


class TestReduceAsOp10(TestReduceAsOp):
def init_type(self):
self.dtype = 'uint8'


class TestReduceAs_Complex64(TestReduceAsOp):
def init_type(self):
self.dtype = np.complex64


class TestReduceAs_Complex128(TestReduceAsOp):
def init_type(self):
self.dtype = np.complex128


class TestReduceAsOp13(TestReduceAsOp):
def init_shape(self):
self.shape_x = [10, 10, 6]
self.shape_y = [6]
Expand All @@ -128,7 +153,7 @@ def init_attrs(self):
self.attrs = {'dim': [0, 1]}


class TestSumAsDynamicShape(unittest.TestCase):
class TestReduceAsDynamicShape(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.shape_x = [300, 20, 100]
Expand Down