Change scheduler & input tensor shape (#1381)
WoosukKwon authored Oct 17, 2023
1 parent 651c614 commit c1376e0
Showing 13 changed files with 181 additions and 179 deletions.
28 changes: 14 additions & 14 deletions csrc/activation_kernels.cu
@@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) {

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
- scalar_t* __restrict__ out, // [num_tokens, d]
- const scalar_t* __restrict__ input, // [num_tokens, 2, d]
+ scalar_t* __restrict__ out, // [..., d]
+ const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
} // namespace vllm

void silu_and_mul(
- torch::Tensor& out, // [num_tokens, d]
- torch::Tensor& input) // [num_tokens, 2 * d]
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., 2 * d]
{
- int num_tokens = input.size(0);
- int d = input.size(1) / 2;
+ int num_tokens = input.numel() / input.size(-1);
+ int d = input.size(-1) / 2;

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
@@ -52,8 +52,8 @@ namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
- scalar_t* __restrict__ out, // [num_tokens, d]
- const scalar_t* __restrict__ input, // [num_tokens, d]
+ scalar_t* __restrict__ out, // [..., d]
+ const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -66,8 +66,8 @@ __global__ void activation_kernel(

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
- int num_tokens = input.size(0); \
- int d = input.size(1); \
+ int d = input.size(-1); \
+ int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
} // namespace vllm

void gelu_new(
- torch::Tensor& out, // [num_tokens, d]
- torch::Tensor& input) // [num_tokens, d]
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
- torch::Tensor& out, // [num_tokens, d]
- torch::Tensor& input) // [num_tokens, d]
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
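
For context, the launchers above now flatten every leading dimension into a single token count (input.numel() / input.size(-1)), so the kernels accept both [num_tokens, ...] and [batch_size, seq_len, ...] inputs. A reference sketch of the same arithmetic in PyTorch (illustrative only, not part of this commit; the function name is made up):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d]; mirrors the launcher's flattening of leading dims.
    d = x.size(-1) // 2
    num_tokens = x.numel() // x.size(-1)   # grid size: one block per flattened token
    flat = x.reshape(num_tokens, 2 * d)
    out = F.silu(flat[:, :d]) * flat[:, d:]
    return out.reshape(*x.shape[:-1], d)

y = silu_and_mul_ref(torch.randn(2, 5, 8))   # [batch_size, seq_len, 2 * d]
assert y.shape == (2, 5, 4)
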
9 changes: 7 additions & 2 deletions csrc/cache_kernels.cu
@@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel(
const int x) {
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
+ if (slot_idx < 0) {
+ // Padding token that should be ignored.
+ return;
+ }

const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;

@@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel(
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
- key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
- value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+ key_cache[tgt_key_idx] = key[src_key_idx];
+ value_cache[tgt_value_idx] = value[src_value_idx];
}
}

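With padded prompt batches, slot_mapping can now carry a negative slot for padding positions, and the kernel returns early so nothing is written into the KV cache for them. A small illustration of the convention (hypothetical values, not from this commit):

import torch

# Two prompts of true lengths 3 and 1, padded to a common length of 3;
# padding positions are marked with slot -1 and skipped by reshape_and_cache.
slot_mapping = torch.tensor([10, 11, 12,
                             20, -1, -1])
num_cached = (slot_mapping >= 0).sum().item()
print(num_cached)  # 4 real tokens are written into the cache
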
12 changes: 6 additions & 6 deletions csrc/layernorm_kernels.cu
@@ -9,8 +9,8 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
- scalar_t* __restrict__ out, // [num_tokens, hidden_size]
- const scalar_t* __restrict__ input, // [num_tokens, hidden_size]
+ scalar_t* __restrict__ out, // [..., hidden_size]
+ const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
@@ -37,12 +37,12 @@ __global__ void rms_norm_kernel(
} // namespace vllm

void rms_norm(
- torch::Tensor& out, // [num_tokens, hidden_size]
- torch::Tensor& input, // [num_tokens, hidden_size]
+ torch::Tensor& out, // [..., hidden_size]
+ torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
- int num_tokens = input.size(0);
- int hidden_size = input.size(1);
+ int hidden_size = input.size(-1);
+ int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
22 changes: 11 additions & 11 deletions csrc/pos_encoding_kernels.cu
@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(

template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
- const int64_t* __restrict__ positions, // [num_tokens]
- scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
- scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
+ const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
+ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int query_stride,
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
} // namespace vllm

void rotary_embedding(
- torch::Tensor& positions, // [num_tokens]
- torch::Tensor& query, // [num_tokens, num_heads * head_size]
- torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
+ torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
+ torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+ torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
- int num_tokens = query.size(0);
+ int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
- int num_heads = query.size(1) / head_size;
- int num_kv_heads = key.size(1) / head_size;
- int query_stride = query.stride(0);
- int key_stride = key.stride(0);
+ int num_heads = query.size(-1) / head_size;
+ int num_kv_heads = key.size(-1) / head_size;
+ int query_stride = query.stride(-2);
+ int key_stride = key.stride(-2);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
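Because the kernel now steps between tokens using stride(-2), the same launch works for both the flat [num_tokens, ...] layout and the padded [batch_size, seq_len, ...] layout. A quick check of the quantities the host code computes (illustrative only, not part of this commit):

import torch

num_heads, head_size = 8, 16
q_flat = torch.randn(6, num_heads * head_size)       # [num_tokens, num_heads * head_size]
q_padded = torch.randn(2, 3, num_heads * head_size)  # [batch_size, seq_len, num_heads * head_size]
for q in (q_flat, q_padded):
    num_tokens = q.numel() // q.size(-1)
    query_stride = q.stride(-2)   # elements between consecutive tokens
    print(num_tokens, query_stride)  # 6 and 128 for both layouts (contiguous tensors)
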
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -268,13 +268,15 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
+ max_paddings: Maximum number of paddings to be added to a batch.
"""

def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
+ max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -284,6 +286,7 @@ def __init__(
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
+ self.max_paddings = max_paddings
self._verify_args()

def _verify_args(self) -> None:
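The new argument simply flows through the config object; a hedged construction example with made-up values (not from this commit):

from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                   max_num_seqs=256,
                                   max_model_len=2048,
                                   max_paddings=256)
print(scheduler_config.max_paddings)  # 256
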
15 changes: 11 additions & 4 deletions vllm/core/scheduler.py
@@ -131,7 +131,8 @@ def _schedule(self) -> SchedulerOutputs:
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
- num_batched_tokens = 0
+ seq_lens: List[int] = []

# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
@@ -157,7 +158,9 @@ def _schedule(self) -> SchedulerOutputs:
break

# If the number of batched tokens exceeds the limit, stop.
- if (num_batched_tokens + num_prompt_tokens >
+ new_seq_lens = seq_lens + [num_prompt_tokens]
+ num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
+ if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break

@@ -168,18 +171,22 @@
self.scheduler_config.max_num_seqs):
break

+ num_paddings = num_batched_tokens - sum(new_seq_lens)
+ if num_paddings > self.scheduler_config.max_paddings:
+ break
+ seq_lens = new_seq_lens

seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
- num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)

if scheduled or ignored_seq_groups:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
- num_batched_tokens=num_batched_tokens,
+ num_batched_tokens=len(seq_lens) * max(seq_lens),
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
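The key change here: prompts scheduled in one batch are padded to the longest prompt, so the scheduler now sizes the batch as len(seq_lens) * max(seq_lens) and rejects a candidate when either the padded token count or the padding overhead grows too large. A standalone sketch of that admission test (the helper function and values are hypothetical, not from the commit):

from typing import List

def can_admit(seq_lens: List[int], num_prompt_tokens: int,
              max_num_batched_tokens: int, max_paddings: int) -> bool:
    # Prompts are padded to the longest prompt in the batch, so the batch
    # occupies len(new_seq_lens) * max(new_seq_lens) token slots in total.
    new_seq_lens = seq_lens + [num_prompt_tokens]
    num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
    num_paddings = num_batched_tokens - sum(new_seq_lens)
    return (num_batched_tokens <= max_num_batched_tokens
            and num_paddings <= max_paddings)

print(can_admit([512, 480], 500, 2560, 256))  # True: 3 * 512 = 1536 slots, 44 of them padding
print(can_admit([32, 48], 2000, 2560, 256))   # False: 6000 slots and 3920 paddings exceed both caps
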
8 changes: 7 additions & 1 deletion vllm/engine/arg_utils.py
@@ -27,6 +27,7 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
+ max_paddings: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
@@ -156,6 +157,10 @@ def add_cli_args(
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
+ parser.add_argument('--max-paddings',
+ type=int,
+ default=EngineArgs.max_paddings,
+ help='maximum number of paddings in a batch')
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
@@ -193,7 +198,8 @@ def create_engine_configs(
self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
- model_config.max_model_len)
+ model_config.max_model_len,
+ self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config


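The new flag plugs into the usual engine construction path (either --max-paddings on the command line or the EngineArgs field). A hedged usage sketch with placeholder values, not taken from this commit:

from vllm import EngineArgs, LLMEngine

# max_paddings caps how many padding token slots a prompt batch may waste.
engine_args = EngineArgs(model="facebook/opt-125m", max_paddings=256)
engine = LLMEngine.from_engine_args(engine_args)
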
11 changes: 5 additions & 6 deletions vllm/model_executor/input_metadata.py
@@ -39,28 +39,28 @@ def __init__(
self.max_context_len = max_context_len
self.block_tables = block_tables

+ self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None
if sliding_window is not None:
# We need to keep the positions of sliding windows within
# the key / value tables, this is helpful to know which
- # elements we need to cache and where
+ # elements we need to cache.
to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens:
to_cache.extend(
range(
start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len,
))
- start_idx += prompt_len
+ start_idx += self.max_prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache,
dtype=torch.int32,
device=self.slot_mapping.device)

self.num_prompts = len(prompt_lens)
- self.num_prompt_tokens = sum(prompt_lens)
+ self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
self.num_generation_tokens = context_lens.shape[0]
- self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
@@ -69,12 +69,11 @@ def __init__(
assert context_lens.shape[0] == self.num_generation_tokens

# Set during the execution of the first attention op.
- self.attn_bias: List[AttentionBias] = []
+ self.attn_bias: Optional[AttentionBias] = None

def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
- f'num_valid_tokens={self.num_valid_tokens}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
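With padding, every prompt occupies max_prompt_len slots in the flattened key/value tensors, which is why the sliding-window offset above now advances by max_prompt_len rather than the true prompt length. A small illustration of that bookkeeping (values made up, not from the commit):

prompt_lens, sliding_window = [7, 3], 4
max_prompt_len = max(prompt_lens)
to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    # Cache only the last `sliding_window` tokens of each prompt.
    to_cache.extend(range(start_idx + max(0, prompt_len - sliding_window),
                          start_idx + prompt_len))
    start_idx += max_prompt_len   # step over the padding slots as well
print(to_cache)  # [3, 4, 5, 6, 7, 8, 9]: tokens 3-6 of prompt 0, tokens 0-2 of prompt 1
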
20 changes: 8 additions & 12 deletions vllm/model_executor/layers/activation.py
@@ -8,37 +8,33 @@
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
- The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+ The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
- x: (num_tokens, 2 * d)
- return: (num_tokens, d)
+ x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+ return: (batch_size, seq_len, d) or (num_tokens, d)
"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
- num_tokens = x.shape[0]
- d = x.shape[1] // 2
- out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+ d = x.shape[-1] // 2
+ output_shape = (x.shape[:-1] + (d, ))
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out


class NewGELU(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
- num_tokens = x.shape[0]
- d = x.shape[1]
- out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+ out = torch.empty_like(x)
activation_ops.gelu_new(out, x)
return out


class FastGELU(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
- num_tokens = x.shape[0]
- d = x.shape[1]
- out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+ out = torch.empty_like(x)
activation_ops.gelu_fast(out, x)
return out

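After this change the activation modules accept inputs with any number of leading dimensions. A hedged usage example (requires a GPU and the compiled vLLM CUDA ops; values are placeholders):

import torch
from vllm.model_executor.layers.activation import SiluAndMul

act = SiluAndMul()
x = torch.randn(2, 5, 128, dtype=torch.float16, device="cuda")  # [batch_size, seq_len, 2 * d]
y = act(x)
print(y.shape)  # torch.Size([2, 5, 64])
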
(Diffs for the remaining changed files are not shown.)
