[AMP] add fp16&bf16 support for flatten op #52035
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
        self.dtype = "float64"

    def init_input_data(self):
        x = np.random.random(self.in_shape).astype("float32")
Here the data for every dtype is first initialized as float32. It should use self.dtype instead, with only uint16 handled as a special case.
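A minimal sketch of the suggested fix, assuming a `self.dtype` set by `init_test_dtype` and the `convert_float_to_uint16` helper named later in this thread (its import path is an assumption and varies across Paddle versions):

```python
import numpy as np

# Assumed import path; the helper's module location differs across
# Paddle versions (op_test in older releases, eager_op_test in newer ones).
from paddle.fluid.tests.unittests.eager_op_test import convert_float_to_uint16


def init_input_data(self):
    if self.dtype == "uint16":
        # bfloat16 is stored as uint16 in Paddle op tests: generate
        # float32 data first, then pack it into uint16.
        x = convert_float_to_uint16(
            np.random.random(self.in_shape).astype("float32")
        )
    else:
        # float16/float32/float64 can be cast directly.
        x = np.random.random(self.in_shape).astype(self.dtype)
    self.inputs = {"X": x}
```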
Done
@@ -31,7 +32,8 @@ def setUp(self):
         self.stop_axis = -1
         self.skip_cinn()
         self.init_test_case()
-        self.inputs = {"X": np.random.random(self.in_shape).astype("float64")}
+        self.init_test_dtype()
+        self.init_input_data()
         self.init_attrs()
         self.outputs = {
             "Out": self.inputs["X"].reshape(self.new_shape),
The output also needs special handling for uint16: apply convert_float_to_uint16.
self.inputs["X"]已经在生成的时候转为了uint16,此处无需再转换
LGTM
PR types
Others
PR changes
Others
Describe
Add float16 & bfloat16 test cases to the flatten op unit tests.
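A hedged sketch of the kind of subclasses such a change typically adds (class names and the stand-in base class are illustrative, not the PR's actual code):

```python
class TestFlattenOp:  # stand-in for the real OpTest-based class in the diff above
    def init_test_dtype(self):
        self.dtype = "float64"


class TestFlattenFP16Op(TestFlattenOp):
    def init_test_dtype(self):
        self.dtype = "float16"


class TestFlattenBF16Op(TestFlattenOp):
    def init_test_dtype(self):
        # bfloat16 uses uint16 as its storage dtype in the op tests.
        self.dtype = "uint16"
```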