[Kernel][Model] Improve continuous batching for Jamba and Mamba (vllm-project#9189)

Signed-off-by: charlifu <[email protected]>
mzusman authored and charlifu committed Oct 23, 2024
1 parent f22047d commit 39c9698
Showing 15 changed files with 511 additions and 439 deletions.
37 changes: 25 additions & 12 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor out,
const c10::optional<at::Tensor>& bias,
bool silu_activation,
+int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
+params.pad_slot_id = pad_slot_id;

params.silu_activation = silu_activation;

@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
}


-at::Tensor
-causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
+void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
-bool silu_activation) {
+bool silu_activation,
+// used to identify padding entries if cache_indices provided
+// in case of padding, the kernel will return early
+int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(cache_indices_, batch_size);
}

-at::Tensor out = torch::empty_like(x);
+at::Tensor out = x;

ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
+pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
-return out;
}


-at::Tensor
-causal_conv1d_update(const at::Tensor &x,
+void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation,
const c10::optional<at::Tensor> &cache_seqlens_,
-const c10::optional<at::Tensor> &conv_state_indices_) {
+const c10::optional<at::Tensor> &conv_state_indices_,
+// used to identify padding entries if cache_indices provided
+// in case of padding, the kernel will return early
+int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim);
}

-at::Tensor out = torch::empty_like(x);
+at::Tensor out = x;

ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
-silu_activation);
+silu_activation,
+pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
-return out;
}

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
-
+// cache_index == params.pad_slot_id is defined as padding, so we exit early
+if (cache_index == params.pad_slot_id){
+return;
+}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;

@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
+// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
+if (conv_state_batch_coord == params.pad_slot_id){
+return;
+}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
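The pattern repeated through this file: each kernel now looks up its batch slot's cache index and returns before touching any convolution state when that index equals pad_slot_id. Below is a minimal host-side sketch of the same convention, with a hypothetical sentinel value; in vLLM the real value arrives through the new pad_slot_id argument rather than a constant.

#include <cstdint>
#include <vector>

constexpr int64_t kPadSlotId = -1;  // hypothetical sentinel, not vLLM's actual value

void process_batch(const std::vector<int64_t>& cache_indices) {
  for (size_t batch_id = 0; batch_id < cache_indices.size(); ++batch_id) {
    // A slot whose cache index equals the pad sentinel carries no real
    // sequence, so skip it, mirroring the kernels' early return.
    if (cache_indices[batch_id] == kPadSlotId) continue;
    // ... per-sequence work would go here ...
  }
}

Padded slots can therefore sit in a fixed-size batch without corrupting any cached state.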
1 change: 1 addition & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.h
@@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t;

int batch, dim, seqlen, width;
+int64_t pad_slot_id;
bool silu_activation;

index_t x_batch_stride;
1 change: 1 addition & 0 deletions csrc/mamba/mamba_ssm/selective_scan.h
@@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
+int64_t pad_slot_id;

bool delta_softplus;

24 changes: 14 additions & 10 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+// cache_index == params.pad_slot_id is defined as padding, so we exit early
+if (cache_index == params.pad_slot_id){
+return;
+}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const size_t seqlen,
const size_t dstate,
const size_t n_groups,
-const size_t n_chunks,
const bool is_variable_B,
const bool is_variable_C,
// device pointers
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
-bool varlen) {
+bool varlen,
+int64_t pad_slot_id) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.seqlen = seqlen;
params.dstate = dstate;
params.n_groups = n_groups;
-params.n_chunks = n_chunks;
params.dim_ngroups_ratio = dim / n_groups;
+params.pad_slot_id = pad_slot_id;

params.delta_softplus = delta_softplus;

@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
-const torch::Tensor &ssm_states) {
+const torch::Tensor &ssm_states,
+// used to identify padding entries if cache_indices provided
+// in case of padding, the kernel will return early
+int64_t pad_slot_id) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

out_z = z;

-const int n_chunks = (seqlen + 2048 - 1) / 2048;
-// const int n_chunks = (seqlen + 1024 - 1) / 1024;
-// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);

SSMParamsBase params;
-set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
+set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
u, delta, A, B, C, out, z, out_z,
D_,
delta_bias_,
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
query_start_loc,
cache_indices,
has_initial_state,
-varlen
+varlen,
+pad_slot_id
);


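Besides threading pad_slot_id into SSMParamsBase, this file drops the per-call n_chunks computation and stops allocating a fresh output tensor: out now aliases delta (whose HBL layout already matches the desired output, per the surviving comment), so the kernel writes its result in place. A sketch of that difference, assuming only libtorch; the function names are illustrative, not vLLM's.

#include <torch/torch.h>

// Old style: a fresh allocation per call, result returned to the caller.
at::Tensor fwd_alloc(const at::Tensor& u) {
  at::Tensor out = torch::empty_like(u);
  // ... kernel writes into out ...
  return out;
}

// New style: out aliases an existing input buffer, so the kernel's writes
// land directly in that tensor and nothing needs to be returned.
void fwd_inplace(at::Tensor& delta) {
  at::Tensor out = delta;  // shares storage; no allocation
  // ... kernel writes through out, i.e. into delta's storage ...
}

Skipping the per-call allocation removes one temporary tensor from every forward step.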
32 changes: 17 additions & 15 deletions csrc/ops.h
@@ -152,21 +152,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
-const torch::Tensor& ssm_states);
-
-at::Tensor causal_conv1d_update(
-const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
-const c10::optional<at::Tensor>& bias_, bool silu_activation,
-const c10::optional<at::Tensor>& cache_seqlens_,
-const c10::optional<at::Tensor>& conv_state_indices_);
-
-at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
-const c10::optional<at::Tensor>& bias_,
-const c10::optional<at::Tensor>& conv_states,
-const c10::optional<at::Tensor>& query_start_loc,
-const c10::optional<at::Tensor>& cache_indices,
-const c10::optional<at::Tensor>& has_initial_state,
-bool silu_activation);
+const torch::Tensor& ssm_states, int64_t pad_slot_id);
+
+void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
+const at::Tensor& weight,
+const c10::optional<at::Tensor>& bias_,
+bool silu_activation,
+const c10::optional<at::Tensor>& cache_seqlens_,
+const c10::optional<at::Tensor>& conv_state_indices_,
+int64_t pad_slot_id);
+
+void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
+const c10::optional<at::Tensor>& bias_,
+const c10::optional<at::Tensor>& conv_states,
+const c10::optional<at::Tensor>& query_start_loc,
+const c10::optional<at::Tensor>& cache_indices,
+const c10::optional<at::Tensor>& has_initial_state,
+bool silu_activation, int64_t pad_slot_id);

#ifndef USE_ROCM
using fptr_t = int64_t;
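With the conv1d entry points now returning void, a call site reads the result out of x itself. A hedged sketch of the new calling convention against the declaration above; the helper name and the pad value are illustrative only.

#include <torch/torch.h>

// Declaration as in csrc/ops.h above.
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                       const c10::optional<at::Tensor>& bias_,
                       const c10::optional<at::Tensor>& conv_states,
                       const c10::optional<at::Tensor>& query_start_loc,
                       const c10::optional<at::Tensor>& cache_indices,
                       const c10::optional<at::Tensor>& has_initial_state,
                       bool silu_activation, int64_t pad_slot_id);

void call_site(at::Tensor& x, const at::Tensor& weight, int64_t pad_slot_id) {
  // Before this change: at::Tensor out = causal_conv1d_fwd(x, weight, ...);
  // Now the op writes into x and returns nothing:
  causal_conv1d_fwd(x, weight,
                    /*bias_=*/c10::nullopt,
                    /*conv_states=*/c10::nullopt,
                    /*query_start_loc=*/c10::nullopt,
                    /*cache_indices=*/c10::nullopt,
                    /*has_initial_state=*/c10::nullopt,
                    /*silu_activation=*/true,
                    pad_slot_id);
  // x now holds the convolution output.
}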
9 changes: 6 additions & 3 deletions csrc/torch_bindings.cpp
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices) -> Tensor");
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif

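The schema edits above change the two conv1d return annotations from "-> Tensor" to "-> ()" (selective_scan_fwd already returned nothing) and append a trailing int pad_slot_id to all three ops: each op now mutates tensors in place rather than handing back a fresh one. A toy registration showing the same schema convention, with a hypothetical namespace and op that are not part of vLLM:

#include <torch/library.h>

// Hypothetical op that scales a tensor in place.
void scale_inplace(at::Tensor t, double s) { t.mul_(s); }

TORCH_LIBRARY(demo_ops, m) {
  // "Tensor! t" marks t as written to; "-> ()" declares no return value,
  // the same shape as the new conv1d / selective_scan schemas.
  m.def("scale_inplace(Tensor! t, float s) -> ()");
}

TORCH_LIBRARY_IMPL(demo_ops, CPU, m) {
  m.impl("scale_inplace", scale_inplace);
}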
