From 9f042ef77be382f4a4e5d237f4089be779b0e457 Mon Sep 17 00:00:00 2001 From: houj04 Date: Fri, 17 May 2024 17:18:06 +0800 Subject: [PATCH] [XPU] fix device version in unittests --- test/xpu/get_test_cover_info.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/test/xpu/get_test_cover_info.py b/test/xpu/get_test_cover_info.py index c6f3756a69456..2f32c37597580 100644 --- a/test/xpu/get_test_cover_info.py +++ b/test/xpu/get_test_cover_info.py @@ -106,6 +106,16 @@ def create_classes(self): return base_class, classes +def get_version_str(xpu_version): + if xpu_version == core.XPUVersion.XPU1: + return "xpu1" + if xpu_version == core.XPUVersion.XPU2: + return "xpu2" + if xpu_version == core.XPUVersion.XPU3: + return "xpu3" + raise ValueError("unknown xpu version, not 1, 2, or 3") + + def get_op_white_list(): op_white_list = xpu_test_op_white_list if os.getenv('XPU_TEST_OP_WHITE_LIST') is not None: @@ -117,19 +127,25 @@ def get_op_white_list(): def get_type_white_list(): xpu_version = core.get_xpu_device_version(0) - version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1" + version_str = get_version_str(xpu_version) xpu1_type_white_list = [] xpu2_type_white_list = [] + xpu3_type_white_list = [] for device_type in xpu_test_device_type_white_list: device, t_type = device_type.split("_") if "xpu1" == device: xpu1_type_white_list.append(t_type) - else: + elif "xpu2" == device: xpu2_type_white_list.append(t_type) + elif "xpu3" == device: + xpu3_type_white_list.append(t_type) + if version_str == "xpu1": + type_white_list = xpu1_type_white_list + elif version_str == "xpu2": + type_white_list = xpu2_type_white_list + elif version_str == "xpu3": + type_white_list = xpu3_type_white_list - type_white_list = ( - xpu1_type_white_list if version_str == "xpu1" else xpu2_type_white_list - ) if os.getenv('XPU_TEST_TYPE_WHITE_LIST') is not None: type_white_list.extend( os.getenv('XPU_TEST_TYPE_WHITE_LIST').strip().split(',') @@ -167,7 +183,7 @@ def get_device_op_type_white_list(): def make_xpu_op_list(xpu_version): ops = [] raw_op_list = core.get_xpu_device_op_list(xpu_version) - version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1" + version_str = get_version_str(xpu_version) op_white_list = get_op_white_list() type_white_list = get_type_white_list() op_type_white_list = get_op_type_white_list() @@ -310,7 +326,7 @@ def create_test_class( def get_test_cover_info(): xpu_version = core.get_xpu_device_version(0) - version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1" + version_str = get_version_str(xpu_version) xpu_op_list = make_xpu_op_list(xpu_version) xpu_op_covered = []