diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 1e1ef410b4e1..3543f798c36e 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np + import tvm import tvm.testing from tvm.script import tir as T +import pytest +import numpy as np + def count_cp_async(stmt): num_alloc = [0] @@ -351,36 +354,54 @@ def test_inject_async_copy_shared_dyn(): """ -generated_code = "" -support_async = True +@pytest.fixture +def postproc_if_missing_async_support(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + support_async = major >= 8 + + func_name = "tvm_callback_cuda_postproc" + prev_postproc = tvm.get_global_func(func_name, allow_missing=True) + + # Store the generated code prior to the post-processing. This + # way, even though the generated code doesn't compile on platforms + # that do not support async, the comparison against an expected + # output can still be performed. We cannot use + # `mod.get_source()`, as that contains the source after all + # post-processing. + original_code = None + + def get_original_code(): + nonlocal original_code + return original_code + + @tvm.register_func(func_name, override=True) + def tvm_callback_cuda_postproc(code, _): + nonlocal original_code + original_code = code + if support_async: + return code + else: + ret = [] + for line in code.split("\n"): + ret.append(line) + ret.append("\n") + if line.startswith('extern "C" __global__') and line.endswith("{"): + break + ret.append("}") + return "".join(ret) + yield get_original_code -@tvm.register_func -def tvm_callback_cuda_postproc(code, _): - global generated_code - global support_async - generated_code = code - # return a dummy code so that device < sm80 could build correctly - if not support_async: - ret = "" - for line in code.split("\n"): - ret += line + "\n" - if line.startswith('extern "C" __global__'): - break - ret += "}" - return ret - return code + # Restore previous postproc func to avoid impacting other tests + if prev_postproc is None: + tvm._ffi.registry.remove_global_func(func_name) + else: + tvm.register_func(func_name, prev_postproc, override=True) @tvm.testing.requires_cuda -def test_cp_async_in_if_then_else(): - global support_async - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - if major < 8: - # At least sm80 is required - support_async = False - +def test_cp_async_in_if_then_else(postproc_if_missing_async_support): @T.prim_func def simple_compute( A: T.Buffer((16, 14), "float32"), @@ -422,22 +443,12 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") + generated_code = postproc_if_missing_async_support() assert generated_code == expected_cuda_script - if not support_async: - # avoid return dummy code to other tests - support_async = True - @tvm.testing.requires_cuda -def test_vectorize_cp_async_in_if_then_else(): - global support_async - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - if major < 8: - # At least sm80 is required - support_async = False - +def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support): @T.prim_func def complex_compute( A: T.Buffer((2, 16, 16, 1280), "float16"), @@ -887,16 +898,10 @@ def complex_compute( mod = tvm.IRModule.from_expr(complex_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") + generated_code = postproc_if_missing_async_support() # generated_code must contain " setp.ne.b32 p, %0, 0;" assert "setp.ne.b32" in generated_code - if not support_async: - # avoid return dummy code to other tests - support_async = True - if __name__ == "__main__": - test_inject_async_copy() - test_inject_async_copy_shared_dyn() - test_cp_async_in_if_then_else() - test_vectorize_cp_async_in_if_then_else() + tvm.testing.main()