-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMP OP&Test]add fp16 and bf16 OpTest for index_select #51159
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
self.index_size = 100 | ||
|
||
|
||
class TestIndexSelectBF16(OpTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把类名按照规范改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把类名按照规范改一下
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM,还有些小问题,下个PR修改下
@@ -58,5 +58,6 @@ PD_REGISTER_KERNEL(index_select, | |||
phi::IndexSelectKernel, | |||
float, | |||
double, | |||
phi::dtype::bfloat16, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cpu不要求修改,不过你这里加了前向的,下个PR需要再补下反向的吧
@@ -342,7 +342,7 @@ def index_select(x, index, axis=0, name=None): | |||
check_variable_and_dtype( | |||
x, | |||
'x', | |||
['float16', 'float32', 'float64', 'int32', 'int64'], | |||
['bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要改成uint16,下个PR改下吧
PR types
Others
PR changes
Others
Describe
add bf16 support of index_select
add fp16 and bf16 test of index_select