From e243dd0c9b305fa06e7ae6aa3f95586aa0bed962 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Sat, 26 Oct 2024 02:05:20 +0000 Subject: [PATCH] move kv transfer to nvshmem folder --- CMakeLists.txt | 2 +- src/runtime/{relax_vm => contrib/nvshmem}/kv_transfer.cu | 2 +- .../relax/test_runtime_builtin_kv_cache_transfer.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) rename src/runtime/{relax_vm => contrib/nvshmem}/kv_transfer.cu (99%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 16ec66b93a1f..563823451550 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -480,7 +480,7 @@ if (USE_CUDA AND USE_NVSHMEM) endif() set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) - tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/relax_vm/*.cu) + tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu) list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS}) endif() diff --git a/src/runtime/relax_vm/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu similarity index 99% rename from src/runtime/relax_vm/kv_transfer.cu rename to src/runtime/contrib/nvshmem/kv_transfer.cu index 218191bbe322..f9902ab7bdb1 100644 --- a/src/runtime/relax_vm/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -169,4 +169,4 @@ int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_posi return 0; } -TVM_REGISTER_GLOBAL("vm.builtin.KVTransfer").set_body_typed(_KVTransfer); +TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); diff --git a/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py index 7496ac45ef9c..b6348d631737 100644 --- a/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py +++ b/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py @@ -64,7 +64,7 @@ def test_kv_transfer_without_disco(): v = tvm.nd.array(v_np, dev) remote_position_map_np = np.array(position_map_array, dtype=np.int32) remote_position_map = tvm.nd.array(remote_position_map_np, dev) - transfer_func = tvm.get_global_func("vm.builtin.KVTransfer") + transfer_func = tvm.get_global_func("nvshmem.KVTransfer") transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 1) dev.sync() comm.Barrier() @@ -114,7 +114,7 @@ def test_kv_transfer_with_disco(): remote_position_map = sess.empty((len(position_map_array),), "int32") remote_position_map.debug_copy_from(0, remote_position_map_np) remote_position_map.debug_copy_from(1, remote_position_map_np) - transfer_func = sess.get_global_func("vm.builtin.KVTransfer") + transfer_func = sess.get_global_func("nvshmem.KVTransfer") transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 2) for i in range(2): sess._sync_worker(i) @@ -150,5 +150,5 @@ def test_kv_transfer_with_disco(): if __name__ == "__main__": # FIXME: only one test can be run at a time - test_kv_transfer_without_disco() - # test_kv_transfer_with_disco() + # test_kv_transfer_without_disco() + test_kv_transfer_with_disco()