diff --git a/CMakeLists.txt b/CMakeLists.txt
index 16831ba4eeb2..563823451550 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -478,7 +478,9 @@ if (USE_CUDA AND USE_NVSHMEM)
   if (NOT NVSHMEM_FOUND)
     message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
   endif()
-  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
 endif()
diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc
index 50fdde4c49d8..2733c595720a 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -38,9 +38,14 @@ ShapeTuple InitNVSHMEMUID() {
   return ShapeTuple(uid_64);
 }
 
-void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
-  DiscoWorker* worker = DiscoWorker::ThreadLocal();
-  ICHECK(worker != nullptr);
+void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
+  DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
+  int worker_id;
+  if (worker == nullptr) {
+    worker_id = worker_id_start;
+  } else {
+    worker_id = worker_id_start + worker->worker_id;
+  }
   CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
       << "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
       << uid_64.size() << ".";
@@ -52,10 +57,24 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   for (int i = 0; i < UNIQUEID_PADDING; ++i) {
     uid.internal[i] = static_cast<char>(uid_64[i + 1]);
   }
-  nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+  // FIXME: this is a hack to avoid NVSHMEM initializing in multi-process-per-GPU mode
+  cudaSetDevice(worker_id);
+  nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
   int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
   CUDA_CALL(cudaSetDevice(mype_node));
+  if (worker != nullptr) {
+    if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+      worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
+    } else {
+      ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
+             worker->default_device.device_id == mype_node)
+          << "The default device of the worker is inconsistent with the device used for NVSHMEM. "
" + << "The default device is " << worker->default_device + << ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node} + << "."; + } + } LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " " << ", npes=" << nvshmem_n_pes(); } diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu new file mode 100644 index 000000000000..f9902ab7bdb1 --- /dev/null +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include + +template +__device__ int calc_flattened_index(int shape[dim], int index[dim]) { + int flattened_index = 0; +#pragma unroll + for (int i = 0; i < dim; i++) { + flattened_index *= shape[i]; + flattened_index += index[i]; + } + return flattened_index; +} + +template +__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_position_map, + int ntokens, int remote_layer_id, int local_tp_rank, + int remote_tp_group_pe_offset, int remote_num_pages) { + // launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1] + // pages(remote): [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim] + // k_data: [ntokens, local_num_kv_head, head_dim] + // v_data: [ntokens, local_num_kv_head, head_dim] + int remote_pe; + int remote_kv_head_index; + int h = threadIdx.y; // local kv head index + if (local_num_kv_head <= remote_num_kv_head) { + // gather + assert(remote_num_kv_head % local_num_kv_head == 0); + int gather_factor = remote_num_kv_head / local_num_kv_head; + remote_pe = remote_tp_group_pe_offset + local_tp_rank / gather_factor; + remote_kv_head_index = (local_tp_rank % gather_factor) * local_num_kv_head + h; + } else { + // scatter + assert(local_num_kv_head % remote_num_kv_head == 0); + int scatter_factor = local_num_kv_head / remote_num_kv_head; + remote_pe = remote_tp_group_pe_offset + local_tp_rank * scatter_factor + h / remote_num_kv_head; + remote_kv_head_index = h % remote_num_kv_head; + } + + for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos += gridDim.x) { + int position = remote_position_map[global_pos]; + if (position == -1) { + continue; + }; + int page_id = position / page_size; + int offset_in_page = position % page_size; + int pages_shape[6] = {1, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim}; + int k_page_index[6] = {remote_layer_id, page_id, 0, remote_kv_head_index, offset_in_page, 0}; + int v_page_index[6] = {remote_layer_id, page_id, 1, remote_kv_head_index, offset_in_page, 0}; + int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim}; + int k_v_index[3] = {global_pos, h, 0}; + nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, k_page_index), + k_data + calc_flattened_index<3>(k_v_shape, k_v_index), + head_dim * sizeof(T), remote_pe); + nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, v_page_index), + v_data + calc_flattened_index<3>(k_v_shape, k_v_index), + head_dim * sizeof(T), remote_pe); + } + if (threadIdx.x == 0) { + nvshmem_quiet(); + } +} + +#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...) \ + if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { \ + using cuda_dtype = half; \ + __VA_ARGS__ \ + } else { \ + LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \ + } + +#define DISPATCH_HEAD_DIM(head_dim, const_head_dim, ...) 
+  if (head_dim == 128) {                               \
+    constexpr int const_head_dim = 128;                \
+    __VA_ARGS__                                        \
+  } else {                                             \
+    LOG(FATAL) << "Unsupported head dim " << head_dim; \
+  }
+
+#define DISPATCH_PAGE_SIZE(page_size, const_page_size, ...) \
+  if (page_size == 16) {                                    \
+    constexpr int const_page_size = 16;                     \
+    __VA_ARGS__                                             \
+  } else if (page_size == 4) {                              \
+    constexpr int const_page_size = 4;                      \
+    __VA_ARGS__                                             \
+  } else {                                                  \
+    LOG(FATAL) << "Unsupported page size " << page_size;    \
+  }
+
+#define DISPATCH_NUM_KV_HEAD(num_kv_head, const_num_kv_head, ...) \
+  if (num_kv_head == 1) {                                         \
+    constexpr int const_num_kv_head = 1;                          \
+    __VA_ARGS__                                                   \
+  } else if (num_kv_head == 2) {                                  \
+    constexpr int const_num_kv_head = 2;                          \
+    __VA_ARGS__                                                   \
+  } else if (num_kv_head == 4) {                                  \
+    constexpr int const_num_kv_head = 4;                          \
+    __VA_ARGS__                                                   \
+  } else if (num_kv_head == 8) {                                  \
+    constexpr int const_num_kv_head = 8;                          \
+    __VA_ARGS__                                                   \
+  } else {                                                        \
+    LOG(FATAL) << "Unsupported num_kv_head " << num_kv_head;      \
+  }
+
+int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
+                int remote_num_pages, int remote_num_layers, int remote_num_kv_head,
+                int remote_layer_id, int remote_tp_group_pe_offset) {
+  CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of the pages tensor must be CUDA.";
+  CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of the k tensor must be CUDA.";
+  CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of the v tensor must be CUDA.";
+  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+      << "The device of the remote_position_map tensor must be CUDA.";
+
+  size_t dev_id = pages->device.device_id;
+  CHECK_EQ(k->device.device_id, dev_id) << "The device ids of pages and k don't match.";
+  CHECK_EQ(v->device.device_id, dev_id) << "The device ids of pages and v don't match.";
+  CHECK_EQ(remote_position_map->device.device_id, dev_id)
+      << "The device ids of pages and remote_position_map don't match.";
+
+  CHECK_GE(pages->ndim, 6);
+  int page_size = pages->shape[pages->ndim - 2];
+  int head_dim = pages->shape[pages->ndim - 1];
+
+  CHECK_GE(k->ndim, 3);
+  int kv_len = k->shape[k->ndim - 3];
+  int local_num_kv_heads = k->shape[k->ndim - 2];
+  CHECK_EQ(head_dim, k->shape[k->ndim - 1]);
+
+  CHECK_GE(v->ndim, 3);
+  CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
+  CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
+  CHECK_EQ(head_dim, v->shape[v->ndim - 1]);
+
+  CHECK(pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
+  CHECK(pages->dtype.bits == k->dtype.bits && pages->dtype.code == k->dtype.code);
+  CHECK(pages->dtype.bits == v->dtype.bits && pages->dtype.code == v->dtype.code);
+
+  int local_tp_rank;
+  tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
+  if (worker == nullptr) {
+    local_tp_rank = 0;
+  } else {
+    local_tp_rank = worker->worker_id;
+  }
+  dim3 blocks(8, 1, 1);
+  dim3 threads(32, local_num_kv_heads, 1);
+  DISPATCH_TVM_CUDA_DTYPE(
+      pages->dtype, dtype_in,
+      {DISPATCH_HEAD_DIM(
+          head_dim, HEAD_DIM,
+          {DISPATCH_PAGE_SIZE(
+              page_size, PAGE_SIZE,
+              {DISPATCH_NUM_KV_HEAD(
+                  remote_num_kv_head, REMOTE_NUM_KV_HEAD,
+                  {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
+                    KVTransfer<PAGE_SIZE, HEAD_DIM, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, dtype_in>
+                        <<<blocks, threads>>>(
+                            (dtype_in*)pages->data, (dtype_in*)k->data, (dtype_in*)v->data,
+                            (int32_t*)remote_position_map->data, kv_len, remote_layer_id,
+                            local_tp_rank, remote_tp_group_pe_offset, remote_num_pages);
+                  })})})})})
+
+  return 0;
+}
+
+TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
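The subtlest part of kv_transfer.cu above is the per-thread routing from a local (tp_rank, kv_head) pair to a remote (PE, kv_head) pair when sender and receiver run at different tensor-parallel degrees. Below is a minimal Python sketch of that routing, for illustration only; the function name route and the example shapes are assumptions, not part of the patch, but the arithmetic mirrors the kernel's gather/scatter branches line for line.

# Hypothetical host-side mirror of KVTransfer's PE routing (illustration only).
def route(local_tp_rank, h, local_num_kv_head, remote_num_kv_head, remote_tp_group_pe_offset=0):
    """Return (remote_pe, remote_kv_head_index) for local kv head h."""
    if local_num_kv_head <= remote_num_kv_head:
        # gather: several local ranks write into one remote PE
        assert remote_num_kv_head % local_num_kv_head == 0
        gather_factor = remote_num_kv_head // local_num_kv_head
        remote_pe = remote_tp_group_pe_offset + local_tp_rank // gather_factor
        remote_kv_head_index = (local_tp_rank % gather_factor) * local_num_kv_head + h
    else:
        # scatter: one local rank spreads its heads over several remote PEs
        assert local_num_kv_head % remote_num_kv_head == 0
        scatter_factor = local_num_kv_head // remote_num_kv_head
        remote_pe = remote_tp_group_pe_offset + local_tp_rank * scatter_factor + h // remote_num_kv_head
        remote_kv_head_index = h % remote_num_kv_head
    return remote_pe, remote_kv_head_index

# Scatter example: 2 local ranks with 4 kv heads each -> 4 remote PEs with 2 heads each.
for local_tp_rank in range(2):
    for h in range(4):
        print(local_tp_rank, h, route(local_tp_rank, h, 4, 2))

In the scatter example every (rank, head) pair lands on a distinct remote slot and remote PEs 0-3 are each written by exactly two heads, which is why the kernel can issue nvshmemx_putmem_nbi_warp with no remote-side coordination.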
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc
index 89d56ed3dc81..5ab4295b6f29 100644
--- a/src/runtime/contrib/nvshmem/memory_allocator.cc
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -26,6 +26,7 @@
 
 #include "../../cuda/cuda_common.h"
 #include "../../memory/pooled_allocator.h"
+#include "../../disco/utils.h"
 
 namespace tvm {
 namespace runtime {
@@ -88,7 +89,7 @@ class NVSHMEMAllocator final : public PooledAllocator {
 };
 
 NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
-  return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+  return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device));
 }
 
 TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 6ee54e14f37b..75e7db483ec7 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -93,7 +93,15 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
   StreamCreate(&ctx->default_stream);
 #endif
   Device device{TVM_DISCO_DEVICE_TYPE, device_id};
-  worker->default_device = device;
+  if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+    worker->default_device = device;
+  } else {
+    ICHECK(worker->default_device.device_type == device.device_type &&
+           worker->default_device.device_id == device.device_id)
+        << "The default device of the worker is inconsistent with the device used for CCL. "
+        << "The default device is " << worker->default_device << ", but the device used for CCL is "
+        << device << ".";
+  }
   worker->ccl = TVM_DISCO_CCL_NAME;
   ctx->worker = worker;
   ctx->device_id = device_id;
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index b304d145aa38..1c4ffc9c4d08 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -107,7 +107,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
     f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
-    init_dfunc(uid, num_workers)
+    init_dfunc(uid, num_workers, 0)
     sess.sync_worker_0()
     finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
     finalize_dfunc()
@@ -123,7 +123,7 @@ def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
     f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
-    init_dfunc(uid, num_workers)
+    init_dfunc(uid, num_workers, 0)
     sess.sync_worker_0()
     empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
     a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
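The new worker_id_start argument to init_nvshmem is what lets several processes, each hosting its own Disco session, join a single NVSHMEM domain: each process passes the offset of its first worker, and InitNVSHMEM derives the global PE id as worker_id_start + worker->worker_id (or worker_id_start alone when no Disco worker is attached). A tiny illustrative sketch of the PE layout produced by the init_func(uid, 4, rank * 2) call in the new test below; the loop is hypothetical, not part of the patch:

# 2 MPI processes x 2 Disco workers each -> NVSHMEM PEs 0..3
workers_per_proc = 2
for rank in range(2):
    worker_id_start = rank * workers_per_proc
    for local_worker_id in range(workers_per_proc):
        # mirrors worker_id = worker_id_start + worker->worker_id in InitNVSHMEM
        print(f"process {rank}, worker {local_worker_id} -> PE {worker_id_start + local_worker_id}")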
diff --git a/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py
new file mode 100644
index 000000000000..b6348d631737
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_kv_cache_transfer.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import pytest
+import scipy.special
+
+import tvm
+import tvm.testing
+from tvm import dlight as dl
+from tvm import tir
+from tvm.runtime import ShapeTuple
+from tvm.script import tir as T
+from mpi4py import MPI
+from tvm.runtime import disco as di
+from tvm._ffi.runtime_ctypes import Device
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+
+page_size = 4
+num_layers = 4
+num_kv_heads = 4
+head_dim = 128
+num_pages = 100
+ntokens = 16
+
+
+def test_kv_transfer_without_disco():
+    dev = tvm.cuda(rank)
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, 2, rank)
+    empty_func = tvm.get_global_func("runtime.disco.nvshmem.empty")
+    pages = empty_func(ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", dev)
+    position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
+    np.random.seed(0)
+    k_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+
+    if rank == 0:
+        k = tvm.nd.array(k_np, dev)
+        v = tvm.nd.array(v_np, dev)
+        remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+        remote_position_map = tvm.nd.array(remote_position_map_np, dev)
+        transfer_func = tvm.get_global_func("nvshmem.KVTransfer")
+        transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 1)
+        dev.sync()
+        comm.Barrier()
+    else:
+        comm.Barrier()
+        pages_np = pages.numpy()
+        for i, position in enumerate(position_map_array):
+            page_id = position // page_size
+            offset_in_page = position % page_size
+            original_k = k_np[i]
+            transferred_k = pages_np[0, page_id, 0, :, offset_in_page, :]
+            np.testing.assert_allclose(original_k, transferred_k)
+            original_v = v_np[i]
+            transferred_v = pages_np[0, page_id, 1, :, offset_in_page, :]
+            np.testing.assert_allclose(original_v, transferred_v)
+    finalize_func = tvm.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_func()
+    comm.Barrier()
+
+
+def test_kv_transfer_with_disco():
+    if rank == 0:
+        f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+        uid = f_init_nvshmem_uid()
+    else:
+        uid = None
+    uid = comm.bcast(uid, root=0)
+    sess = di.ProcessSession(num_workers=2)
+    init_func = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_func(uid, 4, rank * 2)
+    empty_func = sess.get_global_func("runtime.disco.nvshmem.empty")
+    pages = empty_func(ShapeTuple((num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)), "float16", Device(device_type=0, device_id=0))
+    position_map_array = [0, 1, 2, 3, 4, 5, 10, 11, 12, 15, 16, 17, 18, 19, 25, 27]
+    np.random.seed(0)
+    k_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np_0 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    np.random.seed(1)
+    k_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    v_np_1 = np.random.rand(ntokens, num_kv_heads, head_dim).astype(np.float16)
+    if rank == 0:
+        k = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+        v = sess.empty((ntokens, num_kv_heads, head_dim), "float16")
+        k.debug_copy_from(0, k_np_0)
+        k.debug_copy_from(1, k_np_1)
+        v.debug_copy_from(0, v_np_0)
+        v.debug_copy_from(1, v_np_1)
+        remote_position_map_np = np.array(position_map_array, dtype=np.int32)
+        remote_position_map = sess.empty((len(position_map_array),), "int32")
+        remote_position_map.debug_copy_from(0, remote_position_map_np)
+        remote_position_map.debug_copy_from(1, remote_position_map_np)
+        transfer_func = sess.get_global_func("nvshmem.KVTransfer")
+        transfer_func(pages, k, v, remote_position_map, num_pages, num_layers, num_kv_heads, 0, 2)
+        for i in range(2):
+            sess._sync_worker(i)
+        for i in range(2):
+            tvm.cuda(i).sync()
+        comm.Barrier()
+    else:
+        comm.Barrier()
+        pages_np = pages.debug_get_from_remote(0).numpy()
+        for i, position in enumerate(position_map_array):
+            page_id = position // page_size
+            offset_in_page = position % page_size
+            original_k = k_np_0[i]
+            transferred_k = pages_np[0, page_id, 0, :, offset_in_page, :]
+            np.testing.assert_allclose(original_k, transferred_k)
+            original_v = v_np_0[i]
+            transferred_v = pages_np[0, page_id, 1, :, offset_in_page, :]
+            np.testing.assert_allclose(original_v, transferred_v)
+        pages_np = pages.debug_get_from_remote(1).numpy()
+        for i, position in enumerate(position_map_array):
+            page_id = position // page_size
+            offset_in_page = position % page_size
+            original_k = k_np_1[i]
+            transferred_k = pages_np[0, page_id, 0, :, offset_in_page, :]
+            np.testing.assert_allclose(original_k, transferred_k)
+            original_v = v_np_1[i]
+            transferred_v = pages_np[0, page_id, 1, :, offset_in_page, :]
+            np.testing.assert_allclose(original_v, transferred_v)
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    for i in range(2):
+        sess._sync_worker(i)
+
+
+if __name__ == "__main__":
+    # FIXME: only one test can be run at a time
+    # test_kv_transfer_without_disco()
+    test_kv_transfer_with_disco()
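Two closing notes on the new test file. First, calc_flattened_index in the kernel is plain row-major flattening, so the destination slot checked by the verification loops can be cross-checked on the host with numpy.ravel_multi_index. A hypothetical sketch, not part of the patch; the constants match the test's globals:

import numpy as np

page_size, num_layers, num_kv_heads, head_dim, num_pages = 4, 4, 4, 128, 100
pages_shape = (num_layers, num_pages, 2, num_kv_heads, page_size, head_dim)

position, head = 27, 2  # one entry of position_map_array, one kv head
page_id, offset_in_page = divmod(position, page_size)
# Element offset of pages[0, page_id, 0 (K), head, offset_in_page, 0], which is
# what calc_flattened_index<6> computes in the kernel for remote_layer_id = 0.
flat = np.ravel_multi_index((0, page_id, 0, head, offset_in_page, 0), pages_shape)
assert flat == ((((0 * num_pages + page_id) * 2 + 0) * num_kv_heads + head)
                * page_size + offset_in_page) * head_dim

Second, both tests read the MPI rank at import time, so the file is presumably launched with two MPI processes (e.g. something like mpirun -np 2 python tests/python/relax/test_runtime_builtin_kv_cache_transfer.py); the Disco variant then occupies four GPUs, two per process, matching num_workers=4 passed to init_nvshmem above.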