Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
jinhongyii committed Oct 26, 2024
1 parent 49eeb0e commit ed1dd01
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/disco/*.cc
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
src/runtime/relax_vm/*.cu
)
set(TVM_RUNTIME_EXT_OBJS "")

Expand Down Expand Up @@ -481,7 +480,7 @@ if (USE_CUDA AND USE_NVSHMEM)
endif()
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/relax_vm/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()

Expand Down
11 changes: 6 additions & 5 deletions src/runtime/contrib/nvshmem/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
}
//FIXME: this is a hack to avoid the issue of NVSHMEM using Multi-process-per-GPU to initialize
// FIXME: this is a hack to avoid the issue of NVSHMEM using Multi-process-per-GPU to initialize
cudaSetDevice(worker_id);
nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
CUDA_CALL(cudaSetDevice(mype_node));
if(worker!=nullptr){
if(worker->default_device.device_type == DLDeviceType::kDLCPU){
if (worker != nullptr) {
if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
} else {
ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
worker->default_device.device_id == mype_node)
<< "The default device of the worker is inconsistent with the device used for NVSHMEM. "
<< "The default device is " << worker->default_device << ", but the device used for NVSHMEM is "
<< Device{DLDeviceType::kDLCUDA, mype_node} << ".";
<< "The default device is " << worker->default_device
<< ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node}
<< ".";
}
}
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,6 @@ def test_kv_transfer_with_disco():
sess._sync_worker(i)

if __name__ == "__main__":
# FIXME: only one test can be run at a time
test_kv_transfer_without_disco()
# test_kv_transfer_with_disco()

0 comments on commit ed1dd01

Please sign in to comment.