Support blocked KV cache for flash decoding #678

Closed · wants to merge 3 commits
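The new entry point reads keys and values from a blocked (paged) KV cache: kcache and vcache are laid out as num_blocks x block_size x num_heads_k x head_size, and a per-batch block_table (batch_size x max_num_blocks_per_seq) maps each sequence's logical blocks to physical blocks. As a rough sketch of the addressing this implies (a standalone illustration with made-up names, not code from this PR), token t of sequence b is found via block_table[b][t / block_size] at row t % block_size:

#include <cstdio>
#include <vector>

// Illustrative only: flat element offset of (token t, head h, dim d) for sequence b in a
// blocked cache laid out as num_blocks x block_size x num_heads_k x head_size.
long blocked_kv_offset(const std::vector<std::vector<int>> &block_table,
                       int b, int t, int h, int d,
                       int block_size, int num_heads_k, int head_size) {
    const int physical_block = block_table[b][t / block_size];  // block that holds token t
    const int row_in_block = t % block_size;                    // position inside that block
    return (((long)physical_block * block_size + row_in_block) * num_heads_k + h) * (long)head_size + d;
}

int main() {
    // block_size = 4; sequence 0 owns physical blocks {2, 0}, sequence 1 owns {1}.
    std::vector<std::vector<int>> block_table = {{2, 0}, {1}};
    // Token 5 of sequence 0 -> logical block 1 -> physical block 0, row 1.
    std::printf("%ld\n", blocked_kv_offset(block_table, 0, 5, 0, 0, /*block_size=*/4,
                                           /*num_heads_k=*/2, /*head_size=*/8));
    return 0;
}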
248 changes: 231 additions & 17 deletions csrc/flash_attn/flash_api.cpp
@@ -605,23 +605,23 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d <= 32) {
run_mha_bwd_<elem_type, 32>(params, stream, configure);
} else if (params.d <= 64) {
run_mha_bwd_<elem_type, 64>(params, stream, configure);
} else if (params.d <= 96) {
run_mha_bwd_<elem_type, 96>(params, stream, configure);
} else if (params.d <= 128) {
run_mha_bwd_<elem_type, 128>(params, stream, configure);
} else if (params.d <= 160) {
run_mha_bwd_<elem_type, 160>(params, stream, configure);
} else if (params.d <= 192) {
run_mha_bwd_<elem_type, 192>(params, stream, configure);
} else if (params.d <= 224) {
run_mha_bwd_<elem_type, 224>(params, stream, configure);
} else if (params.d <= 256) {
run_mha_bwd_<elem_type, 256>(params, stream, configure);
}
});
}

@@ -1305,11 +1305,225 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
return {out, softmax_lse};
}

std::vector<at::Tensor>
mha_fwd_blocked_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x block_size x num_heads_k x head_size
const at::Tensor &vcache, // num_blocks x block_size x num_heads_k x head_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const at::Tensor &seqlens_k, // batch_size
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

const auto sizes = q.sizes();

const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size = sizes[3];
const int num_blocks = kcache.size(0);
const int block_size = kcache.size(1);
const int max_num_blocks_per_seq = block_table.size(1);
const int seqlen_k = max_num_blocks_per_seq * block_size;
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size <= 512, "FlashAttention forward only supports head dimension at most 512");
TORCH_CHECK(head_size % 8 == 0, "head_size must be a multiple of 8");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
} else {
out = torch::empty_like(q);
}

if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }

// Faster to transpose q from (b, 1, nheads_kv * ngroups, d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
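// Example: num_heads = 32, num_heads_k = 4 gives ngroups = 8, so q is viewed as
// (batch_size, 8, 4, head_size) and the kernel runs with seqlen_q = 8, num_heads = 4.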
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size % 8 == 0;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
out = out.view({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
}

CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(kcache, num_blocks, block_size, num_heads_k, head_size);
CHECK_SHAPE(vcache, num_blocks, block_size, num_heads_k, head_size);

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();

auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, kcache, vcache, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
/*p_ptr=*/nullptr,
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right);
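// cu_seqlens_q / cu_seqlens_k are left null here; the per-batch cache lengths are supplied
// below through seqlens_k, with is_seqlens_k_cumulative = false.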

at::Tensor k, v;
if (k_.has_value()) {
TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
k = k_.value();
v = v_.value();
TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
int seqlen_knew = k.size(1);
CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size);
params.seqlen_knew = seqlen_knew;
params.knew_ptr = k.data_ptr();
params.vnew_ptr = v.data_ptr();
// All strides are in elements, not bytes.
params.knew_batch_stride = k.stride(0);
params.vnew_batch_stride = v.stride(0);
params.knew_row_stride = k.stride(-3);
params.vnew_row_stride = v.stride(-3);
params.knew_head_stride = k.stride(-2);
params.vnew_head_stride = v.stride(-2);
}

TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
CHECK_DEVICE(block_table);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
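// block_table[b][i] is the physical block in kcache/vcache that holds logical block i of
// sequence b, so token t of sequence b lives in block block_table[b][t / block_size] at row t % block_size.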
params.block_table = static_cast<int *>(block_table.data_ptr());
params.block_size = block_size;
params.block_table_batch_stride = block_table.stride(0);

TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
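// seqlens_k gives the current number of cached tokens per batch element (not cumulative
// offsets), hence is_seqlens_k_cumulative = false below.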
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
params.is_seqlens_k_cumulative = false;

if (rotary_cos_.has_value()) {
TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
auto rotary_cos = rotary_cos_.value();
CHECK_DEVICE(rotary_cos);
params.rotary_dim = rotary_cos.size(1) * 2;
TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
const int seqlen_ro = rotary_cos.size(0);
TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
CHECK_CONTIGUOUS(rotary_cos);
TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");

TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
auto rotary_sin = rotary_sin_.value();
CHECK_DEVICE(rotary_sin);
CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
CHECK_CONTIGUOUS(rotary_sin);
TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
params.rotary_cos_ptr = rotary_cos.data_ptr();
params.rotary_sin_ptr = rotary_sin.data_ptr();
params.is_rotary_interleaved = is_rotary_interleaved;
} else {
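// No rotary embedding is applied when appending the new keys/queries.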
params.rotary_dim = 0;
}

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
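// Example: head_size = 128, seqlen_k = 4096, seqlen_q = 1 gives block_n = 128, num_n_blocks = 32, num_m_blocks = 1.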
params.num_splits = num_splits;
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
// Keep the split accumulation buffers alive until after the kernel launch below; the
// kernel reads them through the raw pointers stored in params.
at::Tensor softmax_lse_accum, out_accum;
if (params.num_splits > 1) {
    softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();
}

auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream, /*force_split_kernel=*/true);

if (seqlenq_ngroups_swapped) {
out = out.transpose(1, 2).view({batch_size, 1, num_heads_k * seqlen_q, head_size});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
m.def("fwd_blocked_kvcache", &mha_fwd_blocked_kvcache, "Forward pass, with blocked KV-cache");
}
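For context, here is a small standalone sketch (hypothetical host-side bookkeeping, not part of this PR) of how a caller might hand out physical blocks and build the block_table / seqlens_k inputs that fwd_blocked_kvcache expects:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int block_size = 16;
    const int num_blocks = 64;

    // Pool of free physical blocks; block 0 is handed out first.
    std::vector<int> free_blocks;
    for (int i = num_blocks - 1; i >= 0; --i) free_blocks.push_back(i);

    // Two sequences with 40 and 9 cached tokens need 3 and 1 physical blocks respectively.
    std::vector<int> seqlens_k = {40, 9};
    std::vector<std::vector<int>> block_table;
    for (int len : seqlens_k) {
        std::vector<int> row;
        for (int i = 0; i < (len + block_size - 1) / block_size; ++i) {
            row.push_back(free_blocks.back());
            free_blocks.pop_back();
        }
        block_table.push_back(row);
    }

    // Pad every row to max_num_blocks_per_seq; padded entries past a sequence's length should
    // not be reached, since seqlens_k bounds how far the kernel walks each row.
    size_t max_blocks = 0;
    for (const auto &row : block_table) max_blocks = std::max(max_blocks, row.size());
    for (auto &row : block_table) {
        row.resize(max_blocks, 0);
        for (int b : row) std::printf("%d ", b);
        std::printf("\n");
    }
    return 0;
}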
5 changes: 5 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -98,6 +98,11 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;

// The block table.
int *__restrict__ block_table;
int block_size;
index_t block_table_batch_stride;

// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
