[WIP] support paged attention #710

Draft · wants to merge 18 commits into main

Changes from 2 commits
9 changes: 8 additions & 1 deletion include/custom_op/custom_op_lite.h
@@ -454,7 +454,7 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
public:
static const int cuda_resource_ver = 1;

- OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
+ OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api), kernel_context_(ctx) {
api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
if (!cuda_stream_) {
ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
@@ -521,9 +521,16 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
int GetCudaDeviceId() const override {
return device_id_;
}

void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* mem_info, size_t count_or_bytes) override {
void* ret = nullptr;
api_.KernelContext_GetScratchBuffer(&kernel_context_, mem_info, count_or_bytes, &ret);
return ret;
}

private:
const OrtApi& api_;
const OrtKernelContext& kernel_context_;
OrtAllocator* cpu_allocator_;
OrtAllocator* cuda_allocator_;
void* cuda_stream_ = {};
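The new GetScratchBufferUnderMultiStream override forwards to KernelContext_GetScratchBuffer, which hands out scratch memory owned by ORT. A minimal sketch of the intended call pattern, assuming a CUDAKernelContext* ctx and a byte count num_bytes (the mem_info setup mirrors the call site in paged_attention.h below):

OrtMemoryInfo* mem_info = nullptr;
ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info));
// Scratch obtained through the kernel context is stream-aware, hence "under multi-stream";
// it is owned by ORT rather than allocated with a raw cudaMalloc.
void* scratch = ctx->GetScratchBufferUnderMultiStream(mem_info, num_bytes);
OrtW::API::ReleaseMemoryInfo(mem_info);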
2 changes: 2 additions & 0 deletions include/custom_op/kernel_context.h
@@ -2,6 +2,7 @@
#include <optional>
#include <numeric>
#include <type_traits>
#include "onnxruntime_c_api.h"

namespace Ort {
namespace Custom {
@@ -26,6 +27,7 @@ class CUDAKernelContext : public KernelContext {
virtual void* GetCudaStream() const = 0;
virtual void* GetCublasHandle() const = 0;
virtual int GetCudaDeviceId() const = 0;
virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo*, size_t) { return nullptr; }
};
#endif

3 changes: 3 additions & 0 deletions include/ort_c_to_cpp.h
@@ -81,6 +81,9 @@ class API {
return instance()->KernelContext_GetAllocator(context, mem_info, out);
}
#endif
static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) {
return instance()->ReleaseMemoryInfo(mem_info);
}
private:
const OrtApi* operator->() const {
return &api_;
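With ReleaseMemoryInfo now exposed on the wrapper, call sites can also protect against leaks on early error returns. A minimal RAII sketch, assuming the declarations above (ScopedMemoryInfo and MakeScopedMemoryInfo are illustrative helpers, not part of this PR):

#include <memory>

// Hypothetical RAII guard: an early ORTX_RETURN_IF_ERROR can no longer leak the OrtMemoryInfo.
using ScopedMemoryInfo = std::unique_ptr<OrtMemoryInfo, void (*)(OrtMemoryInfo*)>;

inline ScopedMemoryInfo MakeScopedMemoryInfo(OrtMemoryInfo* raw) {
  return ScopedMemoryInfo(raw, [](OrtMemoryInfo* p) { OrtW::API::ReleaseMemoryInfo(p); });
}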
7 changes: 6 additions & 1 deletion operators/cuda/cuda_ops.cc
@@ -5,6 +5,9 @@

#ifdef USE_CUDA
#include "cuda/fast_gelu.h"
#if ORT_API_VERSION >= 18
#include "cuda/paged_attention.h"
#endif
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -13,8 +16,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
#ifdef USE_CUDA
,
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
#if ORT_API_VERSION >= 18
CustomCudaStructV2("PagedAttention", contrib::PagedAttention<ortc::MFloat16>),
#endif
#if ORT_API_VERSION >= 16

CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>)
#endif
111 changes: 111 additions & 0 deletions operators/cuda/paged_attention.h
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cassert>
#include <optional>
#include <vector>

#include "paged_attention_impl.h"

namespace contrib {

template <typename T>
struct PagedAttention {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
int64_t num_heads = 0, head_size = 0;
ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads));
assert(num_heads > 0);
num_heads_ = static_cast<int32_t>(num_heads);
num_kv_heads_ = static_cast<int32_t>(OrtW::GetOpAttributeOrDefault<int64_t>(info, "num_kv_heads", num_heads));

ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "head_size", &head_size));
assert(head_size > 0);
head_size_ = static_cast<int32_t>(head_size);

ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_float(&info, "scale", &scale_));
assert(scale_ > 0);

num_queries_per_kv_ = num_heads_ / num_kv_heads_;
std::vector<int32_t> head_mapping_host(num_heads_);
for (int i = 0; i < num_kv_heads_; i++) {
for (int j = 0; j < num_queries_per_kv_; j++) {
head_mapping_host[i * num_queries_per_kv_ + j] = i;
}
}
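// Example (grouped-query attention): num_heads = 8, num_kv_heads = 2 gives num_queries_per_kv_ = 4 and
// head_mapping_host = {0, 0, 0, 0, 1, 1, 1, 1}: query heads 0-3 read KV head 0, heads 4-7 read KV head 1.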

OrtAllocator* allocator = nullptr;
ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator));
allocator_ = UniquePtrWithDeletor<OrtAllocator>{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}};
// Alloc() takes a byte count and InitializeHeadMapping() copies raw bytes, so scale by sizeof(int32_t).
head_mapping_ = GetScratchBuffer<int32_t>(allocator_->Alloc(allocator_.get(), num_heads_ * sizeof(int32_t)), allocator_.get());
InitializeHeadMapping(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size() * sizeof(int32_t));
return nullptr;
}

Review comment (Member): If you want the kernel to support eager execution, you may use OrtxStatus instead of OrtStatusPtr.

OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor<T>& query, const ortc::Tensor<T>& key,
const ortc::Tensor<T>& value, const ortc::Tensor<T>& key_cache, const ortc::Tensor<T>& value_cache,
const ortc::Tensor<int32_t>& block_tables, const ortc::Tensor<int32_t>& slot_mappings,
std::optional<const ortc::Tensor<int32_t>*> context_lens,
std::optional<const ortc::Tensor<int64_t>*> positions,
std::optional<const ortc::Tensor<T>*> cos_sin_cache, ortc::Tensor<T>& attn_out) const {
InputMetadata input_metadata;
ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), allocator_.get(), num_heads_, head_size_, query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, input_metadata));
const std::vector<int64_t>& query_shape = query.Shape();
T* output_data = attn_out.Allocate(query_shape);

if (cos_sin_cache.has_value()) {
int64_t rot_dim = (*cos_sin_cache)->Shape()[1];
assert(rot_dim == head_size_);
rotary_embedding_neox(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), (*positions)->Data(), const_cast<void*>(query.DataRaw()), const_cast<void*>(key.DataRaw()), head_size_,
(*cos_sin_cache)->DataRaw(), input_metadata.num_valid_tokens, rot_dim, num_heads_, num_kv_heads_, 1);
}

const std::vector<int64_t>& key_cache_shape = key_cache.Shape();
if (input_metadata.num_valid_tokens > 0 && key_cache_shape.size() > 3) {
int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_};
int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_};
int block_size = gsl::narrow<int>(key_cache_shape[3]);
reshape_and_cache(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(),
key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1);
}

using TT = typename CudaT<T>::MappedType;
if (input_metadata.num_prompt_tokens > 0) {
//TODO(leca): flash attention for prompt > 0 case
return nullptr; // Don't handle prompt with decoding case for now
}
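// Worked example for the v1/v2 split below: with max_context_len = 1300 and PARTITION_SIZE = 512,
// max_num_partitions = (1300 + 511) / 512 = 3, so the partitioned v2 kernel is used unless the
// problem already has enough parallelism (the query_shape[0] * query_shape[1] check) to keep the GPU busy with v1.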

if (input_metadata.num_generation_tokens > 0) {
constexpr int PARTITION_SIZE = 512;
int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE;
int64_t generation_query_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_};
if (use_v1) {
paged_attention_v1(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), reinterpret_cast<TT*>(output_data), query.DataRaw(),
key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_,
block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr,
value_cache.Shape()[3], input_metadata.max_context_len, nullptr,
input_metadata.max_num_blocks_per_seq, generation_query_shape, num_queries_per_kv_, 1);
} else {
OrtMemoryInfo* mem_info = nullptr;
ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info));
// tmp_out is [num_seqs, num_heads, max_num_partitions, head_size]; for the 2-D query layout,
// query_shape[0] * query_shape[1] covers num_seqs * num_heads * head_size.
void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * max_num_partitions * sizeof(T));
UniquePtrWithDeletor<T> tmp_output = GetScratchBuffer<T>(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT
// exp_sums and max_logits are [num_seqs, num_heads, max_num_partitions].
void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * num_heads_ * max_num_partitions * sizeof(T));
UniquePtrWithDeletor<T> exp_sums = GetScratchBuffer<T>(exp_sums_raw, allocator_.get());
void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * num_heads_ * max_num_partitions * sizeof(T));
UniquePtrWithDeletor<T> max_logits = GetScratchBuffer<T>(max_logits_raw, allocator_.get());
paged_attention_v2(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), reinterpret_cast<TT*>(output_data), exp_sums_raw, max_logits_raw, tmp_output_raw, query.DataRaw(),
key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_,
block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr,
value_cache.Shape()[3], input_metadata.max_context_len, nullptr,
input_metadata.max_num_blocks_per_seq, generation_query_shape, num_queries_per_kv_, 1);

OrtW::API::ReleaseMemoryInfo(mem_info);
}
}
return nullptr;
}

private:
int32_t num_heads_; // number of attention heads
int32_t num_kv_heads_; // number of attention kv heads
int32_t head_size_; // dimension of each attention head
float scale_; // 1/sqrt(head_size_)
UniquePtrWithDeletor<int32_t> head_mapping_;
int32_t num_queries_per_kv_;
UniquePtrWithDeletor<OrtAllocator> allocator_;
};

}  // namespace contrib
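For readers new to paged attention: the KV cache is carved into fixed-size blocks, block_tables maps each sequence's logical block index to a physical block, and slot_mappings gives each token one flat slot (block * block_size + offset). A CPU sketch of that addressing, illustrative only (LogicalToSlot is not part of the PR):

#include <cstdint>
#include <vector>

// Translate a token's logical position in its sequence into a physical cache slot,
// mirroring what slot_mappings encodes for the kernels.
int64_t LogicalToSlot(const std::vector<int32_t>& block_table,  // physical block ids for one sequence
                      int64_t position, int64_t block_size) {
  const int64_t logical_block = position / block_size;  // which logical block holds this token
  const int64_t offset = position % block_size;         // offset inside that block
  return static_cast<int64_t>(block_table[logical_block]) * block_size + offset;
}

For example, with block_size = 16 and block_table = {7, 2}, the token at position 20 lives in physical block 2 at offset 4, i.e. slot 2 * 16 + 4 = 36.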
145 changes: 145 additions & 0 deletions operators/cuda/paged_attention_impl.cu
@@ -0,0 +1,145 @@
#include "paged_attention_impl.h"
#include <vector>

namespace cuda {

inline OrtStatusPtr CudaCall(cudaError_t cuda_error) {
if (cuda_error == cudaSuccess) return nullptr;
return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str());
}

void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) {
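// count is a byte count; the call site passes head_mapping_host.size() * sizeof(int32_t).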
cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice);
}

template <typename T>
OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, int32_t num_heads, int32_t head_size, const ortc::Tensor<T>& query, const ortc::Tensor<T>& key,
const ortc::Tensor<T>& value, const ortc::Tensor<T>& key_cache, const ortc::Tensor<T>& value_cache,
const ortc::Tensor<int32_t>& block_tables, const ortc::Tensor<int32_t>& slot_mappings,
std::optional<const ortc::Tensor<int32_t>*> context_lens,
std::optional<const ortc::Tensor<int64_t>*> positions, InputMetadata& input_metadata) {
const std::vector<int64_t>& query_shape = query.Shape();
if (query_shape.size() < 2 || query_shape.size() > 3) {
return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT);
}
if (query_shape.back() != num_heads * head_size) {
return OrtW::CreateStatus(MakeString("query shape's last dimension should equal num_heads * head_size"), ORT_INVALID_ARGUMENT);
}

// TODO(leca): Cpu input or CUDA input?
int seq_len = query_shape.size() == 3 ? query_shape[1] : query_shape[0];
if (positions.has_value()) {
std::vector<int64_t> positions_host((*positions)->SizeInBytes() / sizeof(int64_t));
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost)));
// Trim trailing zero padding; guard against an all-zero vector before calling back().
while (!positions_host.empty() && positions_host.back() == 0) {
positions_host.pop_back();
seq_len--;
}

input_metadata.max_num_blocks_per_seq = 0;
// in prompt mode
if (positions_host.size() > 1 || positions_host.back() == 0) {
input_metadata.num_prompt_tokens = seq_len;
input_metadata.num_generation_tokens = 0;
} else {
input_metadata.num_prompt_tokens = 0;
input_metadata.num_generation_tokens = seq_len;
input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty?

int32_t block_size = gsl::narrow<int32_t>(key_cache.Shape()[3]);
for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++;
}
} else {
// TODO(leca): context_lens is nullptr?
std::vector<int32_t> context_len_host((*context_lens)->SizeInBytes() / sizeof(int32_t));
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost)));
std::vector<int64_t> position_ids;
for (size_t i = 0; i < context_len_host.size(); i++) {
if (context_len_host[i] == 0) continue;
std::vector<int64_t> position_id(context_len_host[i]);
std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_host[i]-1}
position_ids.insert(position_ids.end(), position_id.begin(), position_id.end());
}
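// e.g. context_len_host = {3, 0, 2} yields position_ids = {0, 1, 2, 0, 1}.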
const size_t cnt = position_ids.size();
input_metadata.position_ids = GetScratchBuffer<int64_t>(allocator->Alloc(allocator, cnt * sizeof(int64_t)), allocator);
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), cnt * sizeof(int64_t), cudaMemcpyHostToDevice, stream)));
}
input_metadata.num_valid_tokens = seq_len;

return nullptr;
}

void paged_attention_v1(
const cudaStream_t stream,
void* out, // [num_seqs, num_heads, head_size]
const void* query, // [num_seqs, num_heads, head_size]
const void* key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const void* value_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* head_mapping, // [num_heads]
float scale,
const int* block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* context_lens, // [num_seqs]
int block_size,
int max_context_len,
const float* __restrict__ alibi_slopes,
const int max_num_blocks_per_seq,
const int64_t* query_shapes,
int num_queries_per_kv,
int dtype) {

}

void paged_attention_v2(
const cudaStream_t stream,
void* out, // [num_seqs, num_heads, head_size]
void* exp_sums, // [num_seqs, num_heads, max_num_partitions]
void* max_logits, // [num_seqs, num_heads, max_num_partitions]
void* tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const void* query, // [num_seqs, num_heads, head_size]
const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const void* value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* head_mapping, // [num_heads]
float scale,
const int* block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* context_lens, // [num_seqs]
int block_size,
int max_context_len,
const float* alibi_slopes,
const int max_num_blocks_per_seq,
const int64_t* query_shapes,
int num_queries_per_kv,
int dtype) {

}

void rotary_embedding_neox(
const cudaStream_t stream,
const int64_t* positions, // [num_tokens]
void* query, // [num_tokens, num_heads * head_size]
void* key, // [num_tokens, num_kv_heads * head_size]
int head_size,
const void* cos_sin_cache, // [max_position, rot_dim]
int num_tokens,
int rot_dim,
int num_heads,
int num_kv_heads,
int dtype) {

}

void reshape_and_cache(
const cudaStream_t stream,
const void* key, // [num_tokens, num_heads, head_size]
const void* value, // [num_tokens, num_heads, head_size]
const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const void* value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* slot_mapping, // [num_tokens]
const int64_t* key_shapes,
const int64_t* value_shapes,
const int64_t block_size,
const int vec_x,
int dtype) {

}

} // namespace cuda
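All four kernels above are empty stubs in this commit; their names, signatures, and shape comments follow vLLM's paged-attention CUDA kernels. As a plain-C++ reference for what reshape_and_cache is expected to do with the tiled key layout, a hedged sketch (float-only, key half only; ReshapeAndCacheKeyRef is illustrative, not part of the PR):

#include <cstdint>

// Illustrative CPU reference for the key half of reshape_and_cache.
// key:       [num_tokens, num_heads, head_size]
// key_cache: [num_blocks, num_heads, head_size/x, block_size, x]
// slot_mapping[i] is the flat slot (block * block_size + offset) for token i.
void ReshapeAndCacheKeyRef(const float* key, float* key_cache, const int* slot_mapping,
                           int64_t num_tokens, int64_t num_heads, int64_t head_size,
                           int64_t block_size, int64_t x) {
  for (int64_t token = 0; token < num_tokens; ++token) {
    const int64_t slot = slot_mapping[token];
    const int64_t block = slot / block_size;
    const int64_t offset = slot % block_size;
    for (int64_t head = 0; head < num_heads; ++head) {
      for (int64_t d = 0; d < head_size; ++d) {
        // Destination index in the tiled [block, head, d/x, offset, d%x] layout.
        const int64_t dst = (((block * num_heads + head) * (head_size / x) + d / x) * block_size + offset) * x + d % x;
        key_cache[dst] = key[(token * num_heads + head) * head_size + d];
      }
    }
  }
}

The x-tiling splits each head dimension into head_size/x vectors of x contiguous elements so the real kernel can issue vectorized loads per cache block.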