Skip to content

Commit

Permalink
move kv transfer to nvshmem folder
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii committed Oct 26, 2024
1 parent ed1dd01 commit e243dd0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ if (USE_CUDA AND USE_NVSHMEM)
endif()
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/relax_vm/*.cu)
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_posi
return 0;
}

TVM_REGISTER_GLOBAL("vm.builtin.KVTransfer").set_body_typed(_KVTransfer);
TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
8 changes: 4 additions & 4 deletions tests/python/relax/test_runtime_builtin_kv_cache_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_kv_transfer_without_disco():
v = tvm.nd.array(v_np, dev)
remote_position_map_np = np.array(position_map_array, dtype=np.int32)
remote_position_map = tvm.nd.array(remote_position_map_np, dev)
transfer_func = tvm.get_global_func("vm.builtin.KVTransfer")
transfer_func = tvm.get_global_func("nvshmem.KVTransfer")
transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 1)
dev.sync()
comm.Barrier()
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_kv_transfer_with_disco():
remote_position_map = sess.empty((len(position_map_array),), "int32")
remote_position_map.debug_copy_from(0, remote_position_map_np)
remote_position_map.debug_copy_from(1, remote_position_map_np)
transfer_func = sess.get_global_func("vm.builtin.KVTransfer")
transfer_func = sess.get_global_func("nvshmem.KVTransfer")
transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 2)
for i in range(2):
sess._sync_worker(i)
Expand Down Expand Up @@ -150,5 +150,5 @@ def test_kv_transfer_with_disco():

if __name__ == "__main__":
# FIXME: only one test can be run at a time
test_kv_transfer_without_disco()
# test_kv_transfer_with_disco()
# test_kv_transfer_without_disco()
test_kv_transfer_with_disco()

0 comments on commit e243dd0

Please sign in to comment.