diff --git a/python/tvm/tl/engine.py b/python/tvm/tl/engine.py
index 8297f44c92be..d5faea475fb8 100644
--- a/python/tvm/tl/engine.py
+++ b/python/tvm/tl/engine.py
@@ -108,7 +108,6 @@ def lower(func, target="cuda", runtime_only=False):
 
     mod = tir.transform.VerifyMemory()(mod)
     mod = tir.transform.AnnotateEntryFunc()(mod)
-    mod = tir.transform.ThreadSync("shared")(mod)
     # TODO(lei): This is a hack to make sure the
     # thread level allreduce pass can be applied
     # in TL. As Tl ony use one thread dimension
@@ -119,13 +118,15 @@ def lower(func, target="cuda", runtime_only=False):
     # of putting the LowerThreadAllreduce before
     # the Legalization.
     mod = tir.transform.LowerThreadAllreduce()(mod)
-    mod = tir.transform.ThreadSync("shared.dyn")(mod)
     mod = tl.transform.LowerHopperIntrin()(mod)
     mod = tir.transform.InjectPTXAsyncCopy()(mod)
 
     mod = tir.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
     mod = tir.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tir.transform.ThreadSync("shared")(mod)
+    mod = tir.transform.ThreadSync("shared.dyn")(mod)
+    mod = tir.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)
 
     host_mod = tir.transform.Filter(is_host_call)(mod)
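
Note on the change: the two ThreadSync passes ("shared" and "shared.dyn") move from early in the pipeline to just after MergeSharedMemoryAllocations, so that barrier insertion runs over the merged shared-memory allocations rather than the original per-buffer ones, and an explicit MakePackedAPI step is added before LowerDeviceKernelLaunch. What follows is a minimal sketch of the reordered device-side tail, using only passes that appear in the diff and assuming a stock TVM build where they live under tvm.tir.transform; the helper name lower_device_tail is hypothetical and not part of the patch.

from tvm import tir


def lower_device_tail(mod):
    # Hypothetical helper: applies the post-allreduce passes in the
    # order this patch establishes.
    mod = tir.transform.AnnotateDeviceRegions()(mod)
    mod = tir.transform.SplitHostDevice()(mod)
    # Merge shared allocations first, then insert thread barriers, so the
    # sync analysis sees the merged (and possibly reused) buffers.
    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
    mod = tir.transform.ThreadSync("shared")(mod)
    mod = tir.transform.ThreadSync("shared.dyn")(mod)
    # Packed-API conversion now happens explicitly before the kernel
    # launch lowering.
    mod = tir.transform.MakePackedAPI()(mod)
    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
    return mod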