forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#1 from cmu-catalyst/disaggregation
Add kv transfer kernel
- Loading branch information
Showing
7 changed files
with
365 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
#include <cuda_fp16.h> | ||
#include <dlpack/dlpack.h> | ||
#include <nvshmem.h> | ||
#include <tvm/runtime/logging.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/runtime/disco/disco_worker.h> | ||
|
||
// Convert a multi-dimensional index into a linear offset, assuming the
// row-major (C-order) layout described by `shape`. Evaluated entirely in
// registers; the loop trip count is the compile-time template parameter,
// so `#pragma unroll` fully unrolls it.
template <int dim>
__device__ int calc_flattened_index(int shape[dim], int index[dim]) {
  int offset = 0;
#pragma unroll
  for (int d = 0; d < dim; ++d) {
    // Horner-style accumulation: offset = ((i0*s1 + i1)*s2 + i2)...
    offset = offset * shape[d] + index[d];
  }
  return offset;
}
|
||
/*!
 * \brief Push K/V vectors for a batch of tokens into a remote PE's paged KV
 *        cache using NVSHMEM one-sided, non-blocking puts.
 *
 * launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
 *   - threadIdx.x spans one warp (32 lanes); each warp-collective put copies
 *     one head_dim-sized vector.
 *   - threadIdx.y selects the local KV head handled by that warp.
 *
 * Layouts (element counts, row-major):
 *   pages (remote): [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
 *   k_data:         [ntokens, local_num_kv_head, head_dim]
 *   v_data:         [ntokens, local_num_kv_head, head_dim]
 *
 * remote_position_map[t] gives the destination slot (page * page_size +
 * offset) of token t in the remote cache, or -1 when token t must be skipped.
 */
template <typename T, int local_num_kv_head, int remote_num_kv_head, int head_dim, int page_size>
__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_position_map,
                           int ntokens, int remote_layer_id, int local_tp_rank,
                           int remote_tp_group_pe_offset, int remote_num_pages) {
  // launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
  // pages(remote): [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
  // k_data: [ntokens, local_num_kv_head, head_dim]
  // v_data: [ntokens, local_num_kv_head, head_dim]
  int remote_pe;
  int remote_kv_head_index;
  int h = threadIdx.y;  // local kv head index
  if (local_num_kv_head <= remote_num_kv_head) {
    // gather: several local TP ranks feed one remote PE, so each local rank's
    // heads land in a distinct slice of the remote head dimension.
    assert(remote_num_kv_head % local_num_kv_head == 0);
    int gather_factor = remote_num_kv_head / local_num_kv_head;
    remote_pe = remote_tp_group_pe_offset + local_tp_rank / gather_factor;
    remote_kv_head_index = (local_tp_rank % gather_factor) * local_num_kv_head + h;
  } else {
    // scatter: one local rank's heads are spread across several remote PEs;
    // which PE receives head h depends on h / remote_num_kv_head.
    assert(local_num_kv_head % remote_num_kv_head == 0);
    int scatter_factor = local_num_kv_head / remote_num_kv_head;
    remote_pe = remote_tp_group_pe_offset + local_tp_rank * scatter_factor + h / remote_num_kv_head;
    remote_kv_head_index = h % remote_num_kv_head;
  }

  // Grid-stride over tokens; blocks independently handle disjoint tokens.
  for (int global_pos = blockIdx.x; global_pos < ntokens; global_pos += gridDim.x) {
    int position = remote_position_map[global_pos];
    if (position == -1) {
      // -1 marks a token with no remote destination; skip it.
      continue;
    };
    int page_id = position / page_size;
    int offset_in_page = position % page_size;
    // Leading extent is set to 1 but its value never enters the offset:
    // calc_flattened_index multiplies a zero accumulator by shape[0] before
    // adding remote_layer_id, so the layer stride comes from the remaining
    // five dimensions.
    int pages_shape[6] = {1, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim};
    int k_page_index[6] = {remote_layer_id, page_id, 0, remote_kv_head_index, offset_in_page, 0};
    int v_page_index[6] = {remote_layer_id, page_id, 1, remote_kv_head_index, offset_in_page, 0};
    int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim};
    int k_v_index[3] = {global_pos, h, 0};
    // Warp-collective non-blocking puts: all 32 lanes of the warp must reach
    // these calls together (no divergent exit inside the loop body).
    nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, k_page_index),
                             k_data + calc_flattened_index<3>(k_v_shape, k_v_index),
                             head_dim * sizeof(T), remote_pe);
    nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, v_page_index),
                             v_data + calc_flattened_index<3>(k_v_shape, k_v_index),
                             head_dim * sizeof(T), remote_pe);
  }
  if (threadIdx.x == 0) {
    // NOTE(review): quiet is issued only by lane 0 of each warp, with no
    // preceding __syncthreads()/__syncwarp(). Confirm nvshmem_quiet() here
    // is intended to order the nbi puts issued by the other lanes/warps of
    // this block — presumably it completes PE-wide outstanding puts, but
    // verify against the NVSHMEM memory-ordering model in use.
    nvshmem_quiet();
  }
}
|
||
// Bind a compile-time CUDA element type `cuda_dtype` from a runtime DLPack
// dtype and run the body (__VA_ARGS__) with it in scope. Only fp16
// (kDLFloat, 16 bits) is supported; anything else aborts via LOG(FATAL).
// (Comments stay outside the macro: a // comment ending in '\' would splice
// into the next line during preprocessing.)
#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...)    \
  if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) {     \
    using cuda_dtype = half;                                  \
    __VA_ARGS__                                               \
  } else {                                                    \
    LOG(FATAL) << "Unsupported data type " << dl_dtype.code;  \
  }
|
||
// Bind a constexpr `const_head_dim` from the runtime head_dim and run the
// body (__VA_ARGS__). Only head_dim == 128 is instantiated; other values
// abort via LOG(FATAL). Extend with more branches to support new head dims.
#define DISPATCH_HEAD_DIM(head_dim, const_head_dim, ...)  \
  if (head_dim == 128) {                                  \
    constexpr int const_head_dim = 128;                   \
    __VA_ARGS__                                           \
  } else {                                                \
    LOG(FATAL) << "Unsupported head dim " << head_dim;    \
  }
|
||
// Bind a constexpr `const_page_size` from the runtime page_size and run the
// body (__VA_ARGS__). Supported page sizes: 16 and 4; other values abort
// via LOG(FATAL).
#define DISPATCH_PAGE_SIZE(page_size, const_page_size, ...)  \
  if (page_size == 16) {                                     \
    constexpr int const_page_size = 16;                      \
    __VA_ARGS__                                              \
  } else if (page_size == 4) {                               \
    constexpr int const_page_size = 4;                       \
    __VA_ARGS__                                              \
  } else {                                                   \
    LOG(FATAL) << "Unsupported page size " << page_size;     \
  }
|
||
// Bind a constexpr `const_num_kv_head` from the runtime num_kv_head and run
// the body (__VA_ARGS__). Supported head counts: 1, 2, 4, 8 (powers of two,
// matching common tensor-parallel splits); other values abort via LOG(FATAL).
// Used twice (local and remote head counts), so the kernel is instantiated
// for every supported pair.
#define DISPATCH_NUM_KV_HEAD(num_kv_head, const_num_kv_head, ...)  \
  if (num_kv_head == 1) {                                          \
    constexpr int const_num_kv_head = 1;                           \
    __VA_ARGS__                                                    \
  } else if (num_kv_head == 2) {                                   \
    constexpr int const_num_kv_head = 2;                           \
    __VA_ARGS__                                                    \
  } else if (num_kv_head == 4) {                                   \
    constexpr int const_num_kv_head = 4;                           \
    __VA_ARGS__                                                    \
  } else if (num_kv_head == 8) {                                   \
    constexpr int const_num_kv_head = 8;                           \
    __VA_ARGS__                                                    \
  } else {                                                         \
    LOG(FATAL) << "Unsupported num_kv_head " << num_kv_head;       \
  }
|
||
/*!
 * \brief Host-side entry: validate tensor arguments, resolve the local TP
 *        rank, and launch the KVTransfer kernel for one layer.
 *
 * \param pages Remote paged KV cache (NVSHMEM-symmetric), last two dims are
 *              [page_size, head_dim]; must have at least 6 dims.
 * \param k K vectors to send, [..., kv_len, local_num_kv_heads, head_dim].
 * \param v V vectors to send, same trailing shape as k.
 * \param remote_position_map int32 map from token index to remote cache
 *        position (-1 = skip).
 * \param remote_num_pages Number of pages per layer in the remote cache.
 * \param remote_num_layers Unused here (layer stride is implied by the
 *        remaining dims); kept for interface stability.
 * \param remote_num_kv_head KV head count on the remote side.
 * \param remote_layer_id Destination layer index in the remote cache.
 * \param remote_tp_group_pe_offset First NVSHMEM PE of the remote TP group.
 * \return 0 on success (failures abort via CHECK/LOG(FATAL)).
 */
int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
                int remote_num_pages, int remote_num_layers, int remote_num_kv_head,
                int remote_layer_id, int remote_tp_group_pe_offset) {
  // Fix of copy-pasted messages: each CHECK now names the tensor it tests.
  CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of pages matrix must be CUDA.";
  CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
  CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA.";
  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
      << "The device of remote_position_map matrix must be CUDA.";

  // All tensors must live on the same CUDA device as `pages`.
  size_t dev_id = pages->device.device_id;
  CHECK_EQ(k->device.device_id, dev_id) << "The device id of pages and k matrix doesn't match.";
  CHECK_EQ(v->device.device_id, dev_id) << "The device id of pages and v matrix doesn't match.";
  CHECK_EQ(remote_position_map->device.device_id, dev_id)
      << "The device id of pages and remote_position_map matrix doesn't match.";

  // Trailing dims of `pages` fix the compile-time page_size / head_dim.
  CHECK_GE(pages->ndim, 6);
  int page_size = pages->shape[pages->ndim - 2];
  int head_dim = pages->shape[pages->ndim - 1];

  CHECK_GE(k->ndim, 3);
  int kv_len = k->shape[k->ndim - 3];
  int local_num_kv_heads = k->shape[k->ndim - 2];
  CHECK_EQ(head_dim, k->shape[k->ndim - 1]);

  // v must match k's trailing [kv_len, num_heads, head_dim].
  CHECK_GE(v->ndim, 3);
  CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
  CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
  CHECK_EQ(head_dim, v->shape[v->ndim - 1]);

  // Element types of pages/k/v must agree (scalar lanes only).
  CHECK(pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
  CHECK(pages->dtype.bits == k->dtype.bits && pages->dtype.code == k->dtype.code);
  CHECK(pages->dtype.bits == v->dtype.bits && pages->dtype.code == v->dtype.code);

  // Tensor-parallel rank comes from the Disco worker when running inside a
  // Disco session; default to rank 0 when run standalone.
  int local_tp_rank;
  tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
  if (worker == nullptr) {
    local_tp_rank = 0;
  } else {
    local_tp_rank = worker->worker_id;
  }

  // Fixed small grid; the kernel's grid-stride loop covers any kv_len.
  // One warp (32 lanes) per local KV head in the y-dimension.
  dim3 blocks(8, 1, 1);
  dim3 threads(32, local_num_kv_heads, 1);
  DISPATCH_TVM_CUDA_DTYPE(
      pages->dtype, dtype_in,
      {DISPATCH_HEAD_DIM(
          head_dim, HEAD_DIM,
          {DISPATCH_PAGE_SIZE(
              page_size, PAGE_SIZE,
              {DISPATCH_NUM_KV_HEAD(
                  remote_num_kv_head, REMOTE_NUM_KV_HEAD,
                  {DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
                    KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
                        <<<blocks, threads>>>(
                            (dtype_in*)pages->data, (dtype_in*)k->data, (dtype_in*)v->data,
                            (int32_t*)remote_position_map->data, kv_len, remote_layer_id,
                            local_tp_rank, remote_tp_group_pe_offset, remote_num_pages);
                  })})})})})
  // Surface launch-configuration errors immediately instead of at the next
  // unrelated synchronizing call.
  cudaError_t err = cudaGetLastError();
  CHECK_EQ(err, cudaSuccess) << "KVTransfer launch failed: " << cudaGetErrorString(err);

  return 0;
}
|
||
TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.