
Commit

Merge pull request apache#1 from cmu-catalyst/disaggregation
Add kv transfer kernel
tqchen authored Oct 26, 2024
2 parents 35a317f + e243dd0 commit 6e3d8b9
Showing 7 changed files with 365 additions and 9 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -478,7 +478,9 @@ if (USE_CUDA AND USE_NVSHMEM)
if (NOT NVSHMEM_FOUND)
message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
endif()
- tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+ set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+ tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()

27 changes: 23 additions & 4 deletions src/runtime/contrib/nvshmem/init.cc
@@ -38,9 +38,14 @@ ShapeTuple InitNVSHMEMUID() {
return ShapeTuple(uid_64);
}

- void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
- DiscoWorker* worker = DiscoWorker::ThreadLocal();
- ICHECK(worker != nullptr);
+ void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
+ DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
+ int worker_id;
+ if (worker == nullptr) {
+ worker_id = worker_id_start;
+ } else {
+ worker_id = worker_id_start + worker->worker_id;
+ }
CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
<< "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
<< uid_64.size() << ".";
@@ -52,10 +57,24 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
}
- nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+ // FIXME: this is a hack to keep NVSHMEM from initializing in multi-process-per-GPU mode
+ cudaSetDevice(worker_id);
+ nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
CUDA_CALL(cudaSetDevice(mype_node));
+ if (worker != nullptr) {
+ if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+ worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
+ } else {
+ ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
+ worker->default_device.device_id == mype_node)
+ << "The default device of the worker is inconsistent with the device used for NVSHMEM. "
+ << "The default device is " << worker->default_device
+ << ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node}
+ << ".";
+ }
+ }
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}
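With this change, worker_id_start lets more than one Disco worker group join a single NVSHMEM world: each worker registers as PE worker_id_start + worker->worker_id, while num_workers counts the PEs across all groups. Below is a minimal Python sketch of that pattern; the ProcessSession usage, group names, and sizes are illustrative assumptions, not part of this commit.

# Hypothetical two-group setup: a 2-worker "prefill" session takes PEs 0-1 and a
# 2-worker "decode" session takes PEs 2-3 of the same NVSHMEM world.
import tvm
from tvm.runtime import disco as di

num_prefill, num_decode = 2, 2
total_pes = num_prefill + num_decode

uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")()

prefill_sess = di.ProcessSession(num_workers=num_prefill)
decode_sess = di.ProcessSession(num_workers=num_decode)

# Same uid and total PE count for both groups; only the starting offset differs.
prefill_sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")(uid, total_pes, 0)
decode_sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")(uid, total_pes, num_prefill)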
172 changes: 172 additions & 0 deletions src/runtime/contrib/nvshmem/kv_transfer.cu
@@ -0,0 +1,172 @@
#include <cuda_fp16.h>
#include <dlpack/dlpack.h>
#include <nvshmem.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/disco/disco_worker.h>

template <int dim>
__device__ int calc_flattened_index(int shape[dim], int index[dim]) {
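// Row-major flattening: e.g. shape = {2, 3, 4}, index = {1, 2, 3} -> ((1 * 3) + 2) * 4 + 3 = 23.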
int flattened_index = 0;
#pragma unroll
for (int i = 0; i < dim; i++) {
flattened_index *= shape[i];
flattened_index += index[i];
}
return flattened_index;
}

template <typename T, int local_num_kv_head, int remote_num_kv_head, int head_dim, int page_size>
__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_position_map,
int ntokens, int remote_layer_id, int local_tp_rank,
int remote_tp_group_pe_offset, int remote_num_pages) {
// launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
// pages(remote): [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
// k_data: [ntokens, local_num_kv_head, head_dim]
// v_data: [ntokens, local_num_kv_head, head_dim]
int remote_pe;
int remote_kv_head_index;
int h = threadIdx.y; // local kv head index
if (local_num_kv_head <= remote_num_kv_head) {
// gather
assert(remote_num_kv_head % local_num_kv_head == 0);
int gather_factor = remote_num_kv_head / local_num_kv_head;
remote_pe = remote_tp_group_pe_offset + local_tp_rank / gather_factor;
remote_kv_head_index = (local_tp_rank % gather_factor) * local_num_kv_head + h;
} else {
// scatter
assert(local_num_kv_head % remote_num_kv_head == 0);
int scatter_factor = local_num_kv_head / remote_num_kv_head;
remote_pe = remote_tp_group_pe_offset + local_tp_rank * scatter_factor + h / remote_num_kv_head;
remote_kv_head_index = h % remote_num_kv_head;
}
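// Worked examples (offset = remote_tp_group_pe_offset):
// gather: local_num_kv_head=2, remote_num_kv_head=8 -> gather_factor=4; local_tp_rank=5, h=1
// gives remote_pe = offset + 5/4 = offset + 1 and remote_kv_head_index = (5%4)*2 + 1 = 3.
// scatter: local_num_kv_head=8, remote_num_kv_head=2 -> scatter_factor=4; local_tp_rank=1, h=5
// gives remote_pe = offset + 1*4 + 5/2 = offset + 6 and remote_kv_head_index = 5%2 = 1.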

for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos += gridDim.x) {
int position = remote_position_map[global_pos];
if (position == -1) {
continue;
}
int page_id = position / page_size;
int offset_in_page = position % page_size;
int pages_shape[6] = {1, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim};
int k_page_index[6] = {remote_layer_id, page_id, 0, remote_kv_head_index, offset_in_page, 0};
int v_page_index[6] = {remote_layer_id, page_id, 1, remote_kv_head_index, offset_in_page, 0};
int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim};
int k_v_index[3] = {global_pos, h, 0};
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, k_page_index),
k_data + calc_flattened_index<3>(k_v_shape, k_v_index),
head_dim * sizeof(T), remote_pe);
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, v_page_index),
v_data + calc_flattened_index<3>(k_v_shape, k_v_index),
head_dim * sizeof(T), remote_pe);
}
if (threadIdx.x == 0) {
nvshmem_quiet();
}
}

#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...) \
if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { \
using cuda_dtype = half; \
__VA_ARGS__ \
} else { \
LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
}

#define DISPATCH_HEAD_DIM(head_dim, const_head_dim, ...) \
if (head_dim == 128) { \
constexpr int const_head_dim = 128; \
__VA_ARGS__ \
} else { \
LOG(FATAL) << "Unsupported head dim " << head_dim; \
}

#define DISPATCH_PAGE_SIZE(page_size, const_page_size, ...) \
if (page_size == 16) { \
constexpr int const_page_size = 16; \
__VA_ARGS__ \
} else if (page_size == 4) { \
constexpr int const_page_size = 4; \
__VA_ARGS__ \
} else { \
LOG(FATAL) << "Unsupported page size " << page_size; \
}

#define DISPATCH_NUM_KV_HEAD(num_kv_head, const_num_kv_head, ...) \
if (num_kv_head == 1) { \
constexpr int const_num_kv_head = 1; \
__VA_ARGS__ \
} else if (num_kv_head == 2) { \
constexpr int const_num_kv_head = 2; \
__VA_ARGS__ \
} else if (num_kv_head == 4) { \
constexpr int const_num_kv_head = 4; \
__VA_ARGS__ \
} else if (num_kv_head == 8) { \
constexpr int const_num_kv_head = 8; \
__VA_ARGS__ \
} else { \
LOG(FATAL) << "Unsupported num_kv_head " << num_kv_head; \
}

int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
int remote_num_pages, int remote_num_layers, int remote_num_kv_head,
int remote_layer_id, int remote_tp_group_pe_offset) {
CHECK_EQ(pages->device.device_type, kDLCUDA) << "The pages tensor must be on a CUDA device.";
CHECK_EQ(k->device.device_type, kDLCUDA) << "The k tensor must be on a CUDA device.";
CHECK_EQ(v->device.device_type, kDLCUDA) << "The v tensor must be on a CUDA device.";
CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
<< "The remote_position_map tensor must be on a CUDA device.";

size_t dev_id = pages->device.device_id;
CHECK_EQ(k->device.device_id, dev_id) << "The device ids of pages and k don't match.";
CHECK_EQ(v->device.device_id, dev_id) << "The device ids of pages and v don't match.";
CHECK_EQ(remote_position_map->device.device_id, dev_id)
<< "The device ids of pages and remote_position_map don't match.";

CHECK_GE(pages->ndim, 6);
int page_size = pages->shape[pages->ndim - 2];
int head_dim = pages->shape[pages->ndim - 1];

CHECK_GE(k->ndim, 3);
int kv_len = k->shape[k->ndim - 3];
int local_num_kv_heads = k->shape[k->ndim - 2];
CHECK_EQ(head_dim, k->shape[k->ndim - 1]);

CHECK_GE(v->ndim, 3);
CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
CHECK_EQ(head_dim, v->shape[v->ndim - 1]);

CHECK(pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
CHECK(pages->dtype.bits == k->dtype.bits && pages->dtype.code == k->dtype.code);
CHECK(pages->dtype.bits == v->dtype.bits && pages->dtype.code == v->dtype.code);
int local_tp_rank;
tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
if (worker == nullptr){
local_tp_rank = 0;
} else {
local_tp_rank = worker->worker_id;
}
dim3 blocks(8, 1, 1);
dim3 threads(32, local_num_kv_heads, 1);
DISPATCH_TVM_CUDA_DTYPE(
pages->dtype, dtype_in,
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM,
{DISPATCH_PAGE_SIZE(
page_size, PAGE_SIZE,
{DISPATCH_NUM_KV_HEAD(
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
<<<blocks, threads>>>(
(dtype_in*)pages->data, (dtype_in*)k->data, (dtype_in*)v->data,
(int32_t*)remote_position_map->data, kv_len, remote_layer_id,
local_tp_rank, remote_tp_group_pe_offset, remote_num_pages);
})})})})})

return 0;
}

TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
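For orientation, here is a minimal Python sketch of how the nvshmem.KVTransfer global registered above might be invoked. It is an illustration, not part of this change: it assumes NVSHMEM has already been initialized in the calling process (see init.cc and the tests below), that pages comes from the NVSHMEM symmetric heap via runtime.disco.nvshmem.empty, and that the shapes satisfy the dispatch macros (head_dim 128, page_size 16, kv-head counts in {1, 2, 4, 8}); all concrete sizes are made up.

import numpy as np
import tvm
from tvm.runtime import ShapeTuple

dev = tvm.cuda(0)
ntokens, local_num_kv_head, head_dim, page_size = 4, 8, 128, 16
remote_num_layers, remote_num_pages, remote_num_kv_head = 32, 64, 8

# Symmetric KV cache mirroring the remote layout documented in the kernel:
# [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
nvshmem_empty = tvm.get_global_func("runtime.disco.nvshmem.empty")
pages = nvshmem_empty(
    ShapeTuple((remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim)),
    "float16", dev)

k = tvm.nd.array(np.zeros((ntokens, local_num_kv_head, head_dim), "float16"), dev)
v = tvm.nd.array(np.zeros((ntokens, local_num_kv_head, head_dim), "float16"), dev)
# Destination slot for each token in the remote paged cache; -1 means "skip this token".
remote_position_map = tvm.nd.array(np.array([0, 1, 2, -1], "int32"), dev)

kv_transfer = tvm.get_global_func("nvshmem.KVTransfer")
kv_transfer(pages, k, v, remote_position_map,
            remote_num_pages, remote_num_layers, remote_num_kv_head,
            0,   # remote_layer_id
            0)   # remote_tp_group_pe_offset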
3 changes: 2 additions & 1 deletion src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -26,6 +26,7 @@

#include "../../cuda/cuda_common.h"
#include "../../memory/pooled_allocator.h"
#include "../../disco/utils.h"

namespace tvm {
namespace runtime {
@@ -88,7 +89,7 @@ class NVSHMEMAllocator final : public PooledAllocator {
};

NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
- return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+ return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device));
}

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
10 changes: 9 additions & 1 deletion src/runtime/disco/nccl/nccl.cc
@@ -93,7 +93,15 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
StreamCreate(&ctx->default_stream);
#endif
Device device{TVM_DISCO_DEVICE_TYPE, device_id};
- worker->default_device = device;
+ if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+ worker->default_device = device;
+ } else {
+ ICHECK(worker->default_device.device_type == device.device_type &&
+ worker->default_device.device_id == device.device_id)
+ << "The default device of the worker is inconsistent with the device used for CCL. "
+ << "The default device is " << worker->default_device << ", but the device used for CCL is "
+ << device << ".";
+ }
worker->ccl = TVM_DISCO_CCL_NAME;
ctx->worker = worker;
ctx->device_id = device_id;
4 changes: 2 additions & 2 deletions tests/python/disco/test_nvshmem.py
@@ -107,7 +107,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
- init_dfunc(uid, num_workers)
+ init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
@@ -123,7 +123,7 @@ def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
- init_dfunc(uid, num_workers)
+ init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)