Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API improvement for paddle.median 易用性提升 #62407

Merged
merged 7 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
return out


def median(x, axis=None, keepdim=False, name=None):
def median(x, axis=None, keepdim=False, mode='avg', name=None):
"""
Compute the median along the specified axis.

Expand All @@ -367,11 +367,23 @@ def median(x, axis=None, keepdim=False, name=None):
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
mode (str, optional): Whether to use mean or min operation to calculate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认值要写下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

the median values when the input tensor has an even number of elements
in the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor, results of median along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32.
((Tensor, Tensor), optional)
sunzhongkai588 marked this conversation as resolved.
Show resolved Hide resolved
If ``mode`` == 'avg', the result will be the tensor of median values;
If ``mode`` == 'min' and ``axis`` is None, the result will be the tensor of median values;
If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors
containing median values and their indices.

When ``mode`` == 'avg', if data type of ``x`` is float64, data type of median values will be float64,
otherwise data type of median values will be float32.
When ``mode`` == 'min', the data type of median values will be the same as ``x``. The data type of
indices will be int64.

Examples:
.. code-block:: python
Expand Down Expand Up @@ -405,6 +417,18 @@ def median(x, axis=None, keepdim=False, name=None):
Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[4., 5., 6., 7.]])

>>> y5 = paddle.median(x, mode='min')
>>> print(y5)
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
5)

>>> median_value, median_indices = paddle.median(x, axis=1, mode='min')
>>> print(median_value)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 5, 9])
>>> print(median_indices)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 1, 1])
"""
if not isinstance(x, (Variable, paddle.pir.Value)):
raise TypeError("In median, the input x should be a Tensor.")
Expand All @@ -423,6 +447,9 @@ def median(x, axis=None, keepdim=False, name=None):
], 'when input 0-D, axis can only be [-1, 0] or default None'
is_flatten = True

if mode not in ('avg', 'min'):
raise ValueError(f"Mode {mode} is not supported. Must be avg or min.")
need_idx = axis is not None
if axis is None:
is_flatten = True

Expand All @@ -445,18 +472,39 @@ def median(x, axis=None, keepdim=False, name=None):
in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64]
else 'float32'
)
if sz & 1 == 0:
out_tensor = paddle.slice(
tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth]
) + paddle.slice(tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1])
out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2
else:
out_tensor = paddle.cast(
paddle.slice(
if mode == 'avg':
if sz & 1 == 0:
out_tensor = paddle.slice(
tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth]
) + paddle.slice(
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]
),
dtype=dtype,
)
)
out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2
else:
out_tensor = paddle.cast(
paddle.slice(
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]
),
dtype=dtype,
)
else: # mode == 'min'
if sz & 1 == 0:
out_tensor = paddle.slice(
tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth]
)
if need_idx:
out_idx = paddle.slice(
idx, axes=[axis], starts=[kth - 1], ends=[kth]
)
else:
out_tensor = paddle.slice(
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]
)
if need_idx:
out_idx = paddle.slice(
idx, axes=[axis], starts=[kth], ends=[kth + 1]
)

out_tensor = out_tensor + paddle.sum(
paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True
)
Expand All @@ -468,6 +516,11 @@ def median(x, axis=None, keepdim=False, name=None):
else:
if not keepdim:
out_tensor = out_tensor.squeeze(axis)

if mode == 'min' and need_idx:
if not keepdim:
out_idx = out_idx.squeeze(axis)
return out_tensor, out_idx
return out_tensor


Expand Down
152 changes: 151 additions & 1 deletion test/legacy_test/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np
Expand All @@ -22,7 +23,84 @@
DELTA = 1e-6


class TestMedian(unittest.TestCase):
def np_medain_min(data, keepdims=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用 np.median 能直接对齐吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.median 本身是取平均的,所以在 TestMedianAvg 这个类里面就用 np.median 对齐, 在 TestMedianMin 里面用自己写的 np_medain_min

shape = data.shape
data_flat = data.flatten()
data_cnt = len(data_flat)

data_flat[np.isnan(data_flat)] = np.inf
data_sort = np.sort(data_flat)
data_sort[np.isinf(data_sort)] = np.nan

if data_cnt % 2:
is_odd = False
else:
is_odd = True

i = int(data_cnt / 2)
if is_odd:
np_res = min(data_sort[i - 1], data_sort[i])
else:
np_res = data_sort[i]
if keepdims:
new_shape = [1] * len(shape)
return np_res.reshape(new_shape)
else:
return np_res


def np_medain_min_axis(data, axis=None, keepdims=False):
    """NumPy reference for the 'min'-mode median along a single axis.

    For an odd number of elements along ``axis`` the middle element of the
    sorted values is returned; for an even number, the smaller of the two
    middle elements is returned. NaNs are ordered last by ``np.sort``, so
    they are only selected when they reach the median position.

    Args:
        data (np.ndarray): Input array. It is never modified.
        axis (int, optional): Axis to reduce. Negative values count from the
            last axis. ``None`` delegates to ``np_medain_min`` on the whole
            array.
        keepdims (bool, optional): If True, the reduced axis is kept with
            size 1; otherwise it is removed from the result shape.

    Returns:
        np.ndarray: Median values with the same dtype as ``data``.

    Raises:
        ValueError: If ``axis`` is out of range or ``data`` has no elements
            along ``axis``.
    """
    if axis is None:
        # Hand the helper a private copy so the caller's array stays intact.
        return np_medain_min(copy.deepcopy(data), keepdims)

    data = np.asarray(data)
    # moveaxis validates ``axis`` itself (its AxisError derives from both
    # ValueError and IndexError) and puts the reduced axis last.
    moved = np.moveaxis(data, axis, -1)
    count = moved.shape[-1]
    if count == 0:
        raise ValueError(
            "np_medain_min_axis requires at least one element along axis."
        )

    # np.sort returns a new array and already places NaN after every other
    # value, so no NaN<->inf substitution is needed; real +/-inf entries are
    # therefore preserved and integer inputs work as well.
    ordered = np.sort(moved.reshape(-1, count), axis=-1)

    # In a sorted row the lower middle element sits at (count - 1) // 2 for
    # both parities: the exact middle when odd, the smaller of the two
    # middles when even.
    picked = ordered[:, (count - 1) // 2]

    if keepdims:
        out_shape = list(data.shape)
        out_shape[axis] = 1
    else:
        out_shape = moved.shape[:-1]
    return np.reshape(picked, out_shape)


class TestMedianAvg(unittest.TestCase):
def check_numpy_res(self, np1, np2):
self.assertEqual(np1.shape, np2.shape)
mismatch = np.sum((np1 - np2) * (np1 - np2))
Expand Down Expand Up @@ -83,8 +161,80 @@ def test_median_exception(self):
x = paddle.arange(12).reshape([3, 4])
self.assertRaises(ValueError, paddle.median, x, 1.0)
self.assertRaises(ValueError, paddle.median, x, 2)
self.assertRaises(ValueError, paddle.median, x, 2, False, 'max')
self.assertRaises(ValueError, paddle.median, paddle.to_tensor([]))


class TestMedianMin(unittest.TestCase):
def check_numpy_res(self, np1, np2):
Copy link
Contributor

@zhwesky2010 zhwesky2010 Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个就不用再封函数了,目前的标准测试方式是:np.testing.assert_allclose、np.testing.assert_equal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

self.assertEqual(np1.shape, np2.shape)
mismatch = np.sum((np1 - np2) * (np1 - np2))
self.assertAlmostEqual(mismatch, 0, DELTA)

def static_single_test_median(self, lis_test):
    """Compare ``paddle.median(mode='min')`` with the NumPy reference in static mode.

    Args:
        lis_test (list): ``[x, axis, keepdims]`` where ``x`` is the NumPy
            input and the other two items are forwarded to both
            implementations.
    """
    paddle.enable_static()
    try:
        x, axis, keepdims = lis_test
        res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims)
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        exe = paddle.static.Executor()
        with paddle.static.program_guard(main_program, startup_program):
            x_in = paddle.static.data(shape=x.shape, dtype=x.dtype, name='x')
            y = paddle.median(x_in, axis, keepdims, mode='min')
            # run() is given no program argument, so it is kept inside the
            # guard where main_program is the default program.
            [res_pd, _] = exe.run(feed={'x': x}, fetch_list=[y])
            self.check_numpy_res(res_pd, res_np)
    finally:
        # Restore dygraph mode even when the run or the comparison fails, so
        # a failure here cannot leak static mode into later tests.
        paddle.disable_static()

def dygraph_single_test_median(self, lis_test):
    """Compare ``paddle.median(mode='min')`` with the NumPy reference in dygraph mode.

    Args:
        lis_test (list): ``[x, axis, keepdims]`` where ``x`` is the NumPy
            input and the other two items are forwarded to both
            implementations.
    """
    data, axis, keepdims = lis_test
    expected = np_medain_min_axis(data, axis=axis, keepdims=keepdims)
    tensor = paddle.to_tensor(data)
    # Only the median values are compared; the returned indices are ignored.
    values, _ = paddle.median(tensor, axis, keepdims, mode='min')
    self.check_numpy_res(values.numpy(False), expected)

@test_with_pir_api
def test_median_static(self):
    """Run the static-graph comparison for every axis/keepdims combination."""
    height, width, depth = 3, 4, 2
    x = (
        np.arange(height * width * depth)
        .reshape([height, width, depth])
        .astype("float32")
    )
    # Same order as before: axis varies slowest, keepdims fastest.
    for axis in (-1, 0, 1, 2):
        for keepdims in (False, True):
            self.static_single_test_median([x, axis, keepdims])

def test_median_dygraph(self):
    """Run the dygraph comparison for every axis/keepdims combination."""
    paddle.disable_static()
    height, width, depth = 3, 4, 2
    x = (
        np.arange(height * width * depth)
        .reshape([height, width, depth])
        .astype("float32")
    )
    # Same order as before: axis varies slowest, keepdims fastest.
    for axis in (-1, 0, 1, 2):
        for keepdims in (False, True):
            self.dygraph_single_test_median([x, axis, keepdims])

def test_index_even_case(self):
    """Check values and indices when the reduced axis has an even length (100)."""
    paddle.disable_static()
    x = paddle.arange(2 * 100).reshape((2, 100)).astype(paddle.float32)
    values, indices = paddle.median(x, axis=1, mode='min')
    # Each row is increasing, so the lower of the two middle elements is at
    # position 49.
    expected_values = [49.0, 149.0]
    expected_indices = [49, 49]
    np.testing.assert_allclose(values.numpy(), expected_values)
    np.testing.assert_equal(indices.numpy(), expected_indices)

def test_index_odd_case(self):
    """Check the values and indices returned for a (3, 10) input reduced along axis 1."""
    paddle.disable_static()
    x = paddle.arange(30).reshape((3, 10)).astype(paddle.float32)
    values, indices = paddle.median(x, axis=1, mode='min')
    expected_values = [4.0, 14.0, 24.0]
    expected_indices = [4, 4, 4]
    np.testing.assert_allclose(values.numpy(), expected_values)
    np.testing.assert_equal(indices.numpy(), expected_indices)


if __name__ == '__main__':
unittest.main()