
API improvement for paddle.median (usability improvement) #62407

Merged
7 commits merged on Mar 13, 2024

Conversation

@NKNaN (Contributor) commented Mar 5, 2024

PR types

New features

PR changes

APIs

Description

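  • Add a parameter mode: the default is 'avg', and 'min' is also accepted. When the tensor being reduced has an even number of elements along axis, 'avg' returns the arithmetic mean of the two middle values, while 'min' returns the smaller of the two.
  • Return value: when mode = 'min' and axis is not None, the return value is (median_values, median_indices); in all other cases it is a single median_values tensor.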
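
To make the two modes concrete, here is a minimal pure-Python sketch of the semantics described above for a 1-D list. The helper `median_1d` is hypothetical and written only for illustration; it is not Paddle's implementation, and it returns the index alongside the value in 'min' mode simply to mirror the (median_values, median_indices) pair mentioned in the description.

```python
def median_1d(values, mode="avg"):
    """Hypothetical 1-D helper illustrating the 'avg' and 'min' semantics."""
    if mode not in ("avg", "min"):
        raise ValueError(f"Mode {mode} is not supported. Must be avg or min.")
    if not values:
        raise ValueError("values must not be empty")

    # Sort positions by value so the original index can be reported.
    order = sorted(range(len(values)), key=lambda i: values[i])
    n = len(values)
    lower = order[(n - 1) // 2]  # position of the lower middle element
    upper = order[n // 2]        # position of the upper middle element (same as lower when n is odd)

    if mode == "avg":
        # Arithmetic mean of the two middle values.
        return (values[lower] + values[upper]) / 2
    # 'min': the smaller middle value plus its index in the original list.
    return values[lower], lower


data = [4, 1, 3, 2]
avg_result = median_1d(data, mode="avg")
min_value, min_index = median_1d(data, mode="min")
```

For `[4, 1, 3, 2]` the two middle values are 2 and 3, so 'avg' gives 2.5 and 'min' gives the value 2, found at index 3 of the input.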

paddle-bot bot commented Mar 5, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 5, 2024
@zhwesky2010 zhwesky2010 changed the title from "API improvement for paddle.median" to "API improvement for paddle.median (usability improvement)" Mar 5, 2024
@@ -367,11 +367,18 @@ def median(x, axis=None, keepdim=False, name=None):
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
mode (str, optional): Whether to use mean or min operation to calculate
Contributor

Please document the default value here as well.

Contributor Author

Done.

name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor, results of median along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32.
((Tensor, Tensor), optional), results of median along ``axis`` of ``x``. If ``mode`` is
Contributor

The wording here is convoluted and hard to follow. List the cases separately instead:

mode='avg': returns one Tensor
mode='min': with axis specified, returns one Tensor; without axis, returns two Tensors

Then list the data types separately.

Contributor Author

Done.

),
dtype=dtype,
)
if inp_axis is not None:
Contributor

Why is it necessary to cast multiple times?

Contributor Author

I removed the cast on out_idx; only out_tensor is cast now.

else:
    out_tensor = paddle.cast(
        paddle.slice(
            tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]
        ),
        dtype=dtype,
    )
    if inp_axis is not None:
Contributor

Why is a cast needed here?

Contributor Author

I removed the cast on out_idx.

@@ -423,6 +442,9 @@ def median(x, axis=None, keepdim=False, name=None):
], 'when input 0-D, axis can only be [-1, 0] or default None'
is_flatten = True

if mode not in ('avg', 'min'):
    raise ValueError(f"Mode {mode} is not supported. Must be avg or min.")
inp_axis = axis
Contributor

Why is this variable needed?

Contributor Author

Later in the function I want to check whether the original axis argument was None. But when axis is None, the step below sets axis to 0, so from that point on axis itself can no longer be used for the check:

if axis is None:
    is_flatten = True

if is_flatten:
    x = paddle.flatten(x)
    axis = 0
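
The pattern described in this reply, saving the caller's argument before it is normalized, can be sketched in isolation. The helper `normalize_axis` below is hypothetical and exists only for illustration; only the names `inp_axis`, `is_flatten`, and `axis` come from the diff.

```python
def normalize_axis(axis, ndim):
    """Hypothetical helper: normalize `axis` while remembering the caller's input."""
    inp_axis = axis             # what the caller actually passed (may be None)
    is_flatten = axis is None   # None means "reduce over the flattened input"
    if is_flatten:
        axis = 0                # a flattened input has a single axis
    elif axis < 0:
        axis += ndim            # map a negative axis to its positive equivalent
    return axis, inp_axis, is_flatten


# Both calls end up with axis == 0, so only inp_axis can tell them apart.
flattened_case = normalize_axis(None, ndim=3)
explicit_case = normalize_axis(0, ndim=3)
```

Here `flattened_case` and `explicit_case` share the same normalized axis, which is why a later branch such as `if inp_axis is not None:` has to consult the saved value.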

tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]
)
out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2
else: # mode == 'min'
Contributor

This cast is not really justified either. For now, keep mode == 'avg' as it is, and do not cast in the newly added mode='min'. torch presumably behaves the same way, right?

Contributor Author

OK, I'll remove the cast here.
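
The reasoning behind casting only in 'avg' mode can be illustrated with Python's standard statistics module as a stand-in (an assumption for illustration; this is not Paddle code). Averaging two integers can produce a non-integer, so the result needs a floating-point type, whereas selecting an existing element keeps the input's type.

```python
import statistics

data = [1, 2, 3, 4]  # integer input with an even number of elements

avg_like = statistics.median(data)      # averages the two middle values
min_like = statistics.median_low(data)  # picks the lower of the two middle values

# Averaging yields a float even for integer input; selection keeps the int type.
result_types = (type(avg_like).__name__, type(min_like).__name__)
```

With this input, `avg_like` is the float 2.5 and `min_like` is the int 2.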

@@ -22,7 +23,84 @@
DELTA = 1e-6


class TestMedian(unittest.TestCase):
def np_medain_min(data, keepdims=False):
Contributor

Can the results be checked directly against np.median?

Contributor Author

np.median itself takes the average, so the TestMedianAvg class checks against np.median, while TestMedianMin checks against the hand-written np_medain_min.
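
Since the body of np_medain_min is not shown in this excerpt, the following is only a sketch of what a NumPy reference for the 'min' mode could look like. The function name `median_min_reference` and its `axis` parameter are assumptions made for illustration; it sorts along the axis and takes the lower middle element.

```python
import numpy as np


def median_min_reference(data, axis=None, keepdims=False):
    """Hypothetical NumPy reference for the lower ('min') median."""
    data = np.asarray(data)
    if axis is None:
        # Sort the flattened array and take the lower middle element.
        flat = np.sort(data, axis=None)
        result = flat[(flat.size - 1) // 2]
        if keepdims:
            result = np.reshape(result, [1] * data.ndim)
        return result

    sorted_data = np.sort(data, axis=axis)
    kth = (data.shape[axis] - 1) // 2          # index of the lower middle element
    result = np.take(sorted_data, kth, axis=axis)
    if keepdims:
        result = np.expand_dims(result, axis)  # restore the reduced axis with size 1
    return result


x = np.array([[1, 4, 3, 2], [8, 5, 7, 6]])
row_min_median = median_min_reference(x, axis=1)
row_avg_median = np.median(x, axis=1)
```

For each row of `x`, `np.median` averages the two middle values, while the sketch returns the smaller of them, which is why a separate reference is needed for the 'min' tests.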

self.assertRaises(ValueError, paddle.median, paddle.to_tensor([]))


class TestMedianMin(unittest.TestCase):
    def check_numpy_res(self, np1, np2):
Contributor

@zhwesky2010 zhwesky2010 Mar 8, 2024

There is no need to wrap this in a helper function. The current standard way to test is np.testing.assert_allclose and np.testing.assert_equal.

Contributor Author

Done.
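
As a rough illustration of the two NumPy assertions named in this thread (the arrays are made up for the example, and using DELTA as the tolerance is an assumption, not something the excerpt confirms): `np.testing.assert_allclose` fits floating-point values that should match within a tolerance, and `np.testing.assert_equal` fits values that must match exactly, such as integer indices.

```python
import numpy as np

DELTA = 1e-6  # same constant name as in the test file; its use as a tolerance is assumed

# Floating-point results: compare within a tolerance.
expected = np.median(np.array([[1.0, 4.0, 3.0, 2.0]]), axis=1)
actual = np.array([2.5 + 1e-9])  # stand-in for a computed result with rounding noise
np.testing.assert_allclose(actual, expected, rtol=DELTA, atol=DELTA)

# Exact results such as integer indices: compare for equality.
expected_indices = np.array([3])
actual_indices = np.array([3])
np.testing.assert_equal(actual_indices, expected_indices)
```

Both functions raise an AssertionError with a readable diff on mismatch, which removes the need for a custom comparison helper.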

zhwesky2010 (Contributor) previously approved these changes Mar 8, 2024

LGTM

@sunzhongkai588 (Contributor) left a comment

LGTM

@luotao1 luotao1 merged commit 5465f40 into PaddlePaddle:develop Mar 13, 2024
30 checks passed
@NKNaN NKNaN deleted the api/median branch March 20, 2024 02:57
Labels
contributor External developers

4 participants