-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API improvement for paddle.median 易用性提升 #62407
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -367,11 +367,18 @@ def median(x, axis=None, keepdim=False, name=None): | |||
the output Tensor is the same as ``x`` except in the reduced | |||
dimensions(it is of size 1 in this case). Otherwise, the shape of | |||
the output Tensor is squeezed in ``axis`` . Default is False. | |||
mode (str, optional): Whether to use mean or min operation to calculate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认值要写下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/tensor/stat.py
Outdated
name (str, optional): Name for the operation (optional, default is None). | ||
For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, results of median along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. | ||
((Tensor, Tensor), optional), results of median along ``axis`` of ``x``. If ``mode`` is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里语句写的有些绕了,不直白,分情况列举吧:
mode='avg',返回一个Tensor
mode='min',指定axis,返回一个Tensor,不指定axis,返回两个Tensor
数据类型再单独列举
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/tensor/stat.py
Outdated
), | ||
dtype=dtype, | ||
) | ||
if inp_axis is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么需要cast多次?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除对 out_idx 的 cast ,仅对 out_tensor 做 cast
python/paddle/tensor/stat.py
Outdated
else: | ||
out_tensor = paddle.cast( | ||
paddle.slice( | ||
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] | ||
), | ||
dtype=dtype, | ||
) | ||
if inp_axis is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么需要cast
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除对 out_idx 的 cast
python/paddle/tensor/stat.py
Outdated
@@ -423,6 +442,9 @@ def median(x, axis=None, keepdim=False, name=None): | |||
], 'when input 0-D, axis can only be [-1, 0] or default None' | |||
is_flatten = True | |||
|
|||
if mode not in ('avg', 'min'): | |||
raise ValueError(f"Mode {mode} is not supported. Must be avg or min.") | |||
inp_axis = axis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么需要这个变量
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想在后面处理的时候用 axis
的原始输入是否为 None 做一个判断,但是 axis
为 None 的话下面那一步会把 axis
的值改为 0, 之后就没办法用 axis
来判断了
if axis is None:
is_flatten = True
if is_flatten:
x = paddle.flatten(x)
axis = 0
python/paddle/tensor/stat.py
Outdated
tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] | ||
) | ||
out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2 | ||
else: # mode == 'min' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个cast其实也不是特别合理,目前就mode == 'avg'维持现状吧,新增的mode='min'就不cast了,torch应该也是这样吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,我把这里的 cast 去掉
@@ -22,7 +23,84 @@ | |||
DELTA = 1e-6 | |||
|
|||
|
|||
class TestMedian(unittest.TestCase): | |||
def np_medain_min(data, keepdims=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用 np.median
能直接对齐吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.median
本身是取平均的,所以在 TestMedianAvg
这个类里面就用 np.median
对齐, 在 TestMedianMin
里面用自己写的 np_medain_min
test/legacy_test/test_median.py
Outdated
self.assertRaises(ValueError, paddle.median, paddle.to_tensor([])) | ||
|
||
|
||
class TestMedianMin(unittest.TestCase): | ||
def check_numpy_res(self, np1, np2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个就不用再封函数了,目前的标准测试方式是:np.testing.assert_allclose
、np.testing.assert_equal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Description
mode
:默认值 'avg' ,另外可取 'min' 。当所需要计算的 tensor 在axis
轴上有偶数个元素时, 'avg' 表示计算结果为中间两个数的算术平均值;'min' 则为二者的最小值。mode = 'min'
且axis
不为 None 时,返回值为 (median_values, median_indices) ;其他情况下返回值为 median_values 的 tensor 。