-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OpTest] support prim test in OpTest #50509
[OpTest] support prim test in OpTest #50509
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
… prim_test_frame
return kernel_sig | ||
|
||
def is_only_check_prim(self): | ||
return self.only_prim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why only test prim
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用户在新增一些组合单测会触发Op测试,但是部分Op规定不可以测试fp32类型,导致单测挂掉。对于组合测试并没有这个要求,这个开关用来跳过非组合的测试。
) | ||
if not kernel_sig: | ||
return None | ||
assert hasattr(self, "python_api"), ( | ||
"Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_dygraph = True" | ||
% self.op_type | ||
) | ||
args = prepare_python_api_arguments( | ||
args = OpTestUtils.prepare_python_api_arguments( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add todo: those code change will recover after delete legacy dygraph
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
inplace_atol=None, | ||
): | ||
core._set_prim_all_enabled(False) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_set_prim_forward_enabled enough?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在跑这个前向case中会有反向inplace的测试,这儿是修复当某个单测挂掉导致前反向开关没有关,走到组合inplace逻辑中导致段错误。
"rev_comp": {"rtol": 1e-2, "atol": 1e-2}, | ||
"cinn": {"rtol": 1e-1, "atol": 1e-1}, | ||
}, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否需要支持bfloat16?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy没有bfloat16数据类型,这儿用unit16来表示bfloat16,目前python api都是这样做的
if check_prim: | ||
prim_checker = PrimForwardChecker(self, place) | ||
prim_checker.check() | ||
# Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用户会感知这里的dtype逻辑么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不会
@@ -66,6 +55,8 @@ def setUp(self): | |||
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) | |||
} | |||
self.gradient = self.calc_gradient() | |||
# error occurred in cinn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add todo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
未来我们会用代码扫一遍看看哪些case没开cinn,最后会集中处理。
@@ -1265,6 +1265,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): | |||
'x', | |||
[ | |||
'bool', | |||
'uint16', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么需要新增这个数据类型呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为这个在python端拿来表示bfloat16数据类型(numpy没有bfloat16,所以用这个来表示),静态图目前bfloat16数据类型流程没有完全测试过,这儿新增uint16来测试sum算子静态图下bfloat16数据类型。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
LGTM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some comments for now
|
||
import numpy as np | ||
|
||
TOLERANCE = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dose this config used only for op test?
@@ -401,6 +401,7 @@ def is_custom_device_op_test(): | |||
and not is_npu_op_test() | |||
and not is_mlu_op_test() | |||
and not is_custom_device_op_test() | |||
and not cls.check_prim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this?
|
||
sys.path.append(os.path.abspath(os.path.dirname(__file__))) | ||
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make OpTestUtils in an independent file instead of in prim_op_test, since it's not only for prim op test
program = Program() | ||
block = program.global_block() | ||
op = self._append_ops(block) | ||
with paddle.fluid.framework._dygraph_guard(None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use fluid api?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why dygraph_guard
, there are all static operations below
grad_program | ||
).with_data_parallel( | ||
loss_name="", build_strategy=build_strategy, places=place | ||
with paddle.fluid.framework._dygraph_guard(None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same...using program_guard
self.checker_name = "PrimForwardChecker" | ||
self.place = place | ||
self.op_test = op_test | ||
self.save_eager_or_static_status() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this?
def init_checker(self): | ||
assert hasattr( | ||
self.op_test, 'prim_op_type' | ||
), "if you want to test comp op, please set prim_op_type in setUp function." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
more comments for prim_op_type
, what is it?
assert hasattr( | ||
self.op_test, 'dtype' | ||
), "Please set dtype in setUp function." | ||
self.op_type = self.op_test.op_type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not use self.op_test.xxx directly
PR types
New features
PR changes
Others
Describe
任务背景:
组合算子+编译器协同任务需要对算子精度进行保障,为了复用框架原有的算子单测case,本PR在
OpTest
中进行改造,新增了组合测试功能对算子精度进行保障。PR改动:
1.对
OpTest
新增组合测试功能。2.修复
OpTest
框架如果某个单测挂了导致动态图静态图运行模式不正确的问题,对静态图代码利用guard进行保护。3.对
softmax、expand、reduce_sum
三个算子添加组合单测。