[AMP OP&Test] arange op support fp16/bf16 #51106
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
❌ The PR is not created using PR's template. You can refer to this Demo.
self.case = (0, 5, 1)

def test_check_output(self):
    self.check_output(atol=1e-3)
The 1e-3 here can be removed now; the default atol has been updated internally.
namespace phi {

template <typename T>
__global__ void Range(T start, T step, int64_t size, T* out) {
-  CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  CUDA_KERNEL_LOOP(index, size) {
+    out[index] = static_cast<T>(static_cast<MPType>(start) +
+                                static_cast<MPType>(step) * index);
+  }
start and step are fixed values during the computation, so they can be static_cast into temporary variables once before the loop.
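The reviewer's point is that start and step are loop-invariant, so the cast to the higher-precision MPType should happen once, before the loop, rather than per element. A numpy sketch of the equivalent logic (names are illustrative, not Paddle's API; float32 stands in for the MPType of float16):

```python
import numpy as np

def range_fp16(start, step, size):
    """Mimic the kernel: cast the loop-invariant start/step to
    float32 once, compute each element in float32, and cast only
    the final result back down to float16."""
    start32 = np.float32(start)  # cast once, outside the loop
    step32 = np.float32(step)
    out = np.empty(size, dtype=np.float16)
    for i in range(size):
        out[i] = np.float16(start32 + step32 * np.float32(i))
    return out
```

The design point is the same as in the CUDA kernel: the per-element work only does one multiply-add in the wide type and one narrowing cast, with no repeated widening of the scalars.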
@@ -39,7 +45,8 @@ void ArangeKernel(const Context& dev_ctx,
T step_value = GetValue<T, Context>(dev_ctx, step);
Suggestion: convert start, end, and step to MPType right here where they are fetched, and use the MPType values for all subsequent computation.
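To illustrate why converting start/end/step to MPType at fetch time matters on the host side too: computing the element count (end - start) / step in float16 can round to a different size than the same computation carried out in float32. A numpy sketch with illustrative values (not taken from the PR):

```python
import numpy as np

start = np.float16(0.0)
end = np.float16(1000.0)
step = np.float16(0.1)  # actually stored as 0.0999755859375 in float16

# Count computed directly in float16: the quotient ~10002.44 must be
# rounded to a float16 (whose spacing near 10000 is 8), so the
# fractional part, and even some integer precision, is lost.
size_fp16 = int(np.ceil((end - start) / step))

# Count computed in float32 (the MPType for float16): the quotient
# survives with enough precision for ceil to give the right answer.
size_fp32 = int(np.ceil(
    (np.float32(end) - np.float32(start)) / np.float32(step)))
```

With these values the two paths disagree by three elements, which is exactly the kind of off-by-N bug the reviewer's suggestion avoids.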
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
This line didn't actually need to change, but since it was modified, please update 2020 to 2023.
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
def test_check_output(self):
Modifying the inputs/outputs shouldn't be done inside test_check_output; instead, subclass OpTest directly and override setUp, following TestArangeOp for the rest.
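A sketch of the structure the reviewer is asking for. Paddle's OpTest base class is assumed but not available here, so this self-contained version subclasses unittest.TestCase and only demonstrates the shape of setUp (attribute names mirror the snippets in this thread; in the real test the base class is OpTest and test_check_output calls self.check_output()):

```python
import unittest
import numpy as np

class TestFloat16ArangeOp(unittest.TestCase):
    """Sketch: define op_type, inputs, and outputs in setUp, as the
    reviewer suggests, instead of inside test_check_output."""

    def setUp(self):
        self.op_type = "range"
        self.case = (0, 5, 1)  # start, end, step
        self.dtype = np.float16
        start, end, step = self.case
        self.inputs = {
            'Start': np.array([start], dtype=self.dtype),
            'End': np.array([end], dtype=self.dtype),
            'Step': np.array([step], dtype=self.dtype),
        }
        self.outputs = {
            'Out': np.arange(start, end, step).astype(self.dtype)
        }

    def test_check_output(self):
        # Stand-in for OpTest.check_output(): only verify the
        # reference output built in setUp.
        np.testing.assert_array_equal(
            self.outputs['Out'], np.float16([0, 1, 2, 3, 4]))
```

This keeps test_check_output a thin check, so adding a new dtype variant only means overriding setUp (or an init_dtype hook), as TestArangeOp does.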
self.outputs = {
    'Out': convert_float_to_uint16(
        np.arange(self.case[0], self.case[1], self.case[2]))
}
self.check_output(atol=1e-2)
This will also become the default value later; it can be omitted here.
@xiegegege Performance of model ResNet50_bs32_dygraph has increased from 248.402 to 480.483; (480.483-248.402)/248.402 equals 0.9343, which is greater than the threshold 0.06.
python/paddle/tensor/creation.py
Outdated
@@ -1234,7 +1234,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
    check_dtype(
        dtype,
        'dtype',
-       ['float32', 'float64', 'int32', 'int64'],
+       ['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
bfloat16 is not supported here yet; it needs to be changed to uint16.
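Background for the uint16 requirement: numpy has no bfloat16 dtype, so Paddle tests carry bfloat16 data as the raw upper 16 bits of the float32 encoding, stored in a uint16 array. A simplified numpy sketch of that conversion (Paddle's convert_float_to_uint16 helper does this job; this version truncates without rounding, so treat it as illustrative only):

```python
import numpy as np

def float_to_uint16(x):
    """bfloat16 is the upper 16 bits of the float32 encoding, so a
    float32 array converts by viewing its bits as uint32 and dropping
    the low 16 bits. Simplified: truncation, no round-to-nearest."""
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)
```

This is why the check_dtype list (and the test's inputs/outputs) must say 'uint16' rather than 'bfloat16': the bits travel through the framework in a uint16 container.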
class TestBFloat16ArangeOp(OpTest):
    def setUp(self):
        self.op_type = "range"
        self.__class__.op_type = self.op_type
This line can be deleted.
LGTM
LGTM for @unittest.skip
PR types
Others
PR changes
OPs
Describe
Added unit tests.