Allow overscheduling prompt (#116)
madamczykhabana authored Jul 23, 2024
1 parent 969bd83 commit b419b07
Showing 1 changed file with 19 additions and 6 deletions.
vllm/core/scheduler.py: 25 changes (19 additions, 6 deletions)
@@ -23,6 +23,7 @@
 ARTIFICIAL_PREEMPTION_PROB = 0.5
 ARTIFICIAL_PREEMPTION_MAX_CNT = 500
 
+VLLM_OVERSCHEDULE = os.environ.get('VLLM_OVERSCHEDULE', 'true') == 'true'
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -55,15 +56,18 @@ class SchedulingBudget:
     _num_batched_tokens: int = 0
     _num_curr_seqs: int = 0
 
-    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
+    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int, overschedule: int = 0):
         assert num_new_tokens != 0
         assert num_new_seqs != 0
         return (self.num_batched_tokens + num_new_tokens <= self.token_budget
-                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
+                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs + overschedule)
 
     def remaining_token_budget(self):
         return self.token_budget - self.num_batched_tokens
 
+    def remaining_seq_budget(self):
+        return self.max_num_seqs - self.num_curr_seqs
+
     def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
         if req_id in self._requeset_ids_num_batched_tokens:
             return
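
The core of this hunk is the new overschedule argument on can_schedule: it relaxes only the sequence-count cap, while the token budget stays a hard limit. A minimal standalone sketch of that behaviour (a toy class for illustration, not vLLM's SchedulingBudget):

from dataclasses import dataclass

@dataclass
class ToyBudget:
    token_budget: int
    max_num_seqs: int
    num_batched_tokens: int = 0
    num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int,
                     overschedule: int = 0) -> bool:
        # The token budget is always enforced; only the sequence cap is relaxed.
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs
                <= self.max_num_seqs + overschedule)

budget = ToyBudget(token_budget=2048, max_num_seqs=4, num_curr_seqs=4)
print(budget.can_schedule(num_new_tokens=128, num_new_seqs=1))                  # False: seq cap reached
print(budget.can_schedule(num_new_tokens=128, num_new_seqs=1, overschedule=2))  # True: cap relaxed by 2
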
@@ -406,7 +410,11 @@ def _schedule_running(
         # groups to preempt.
         now = time.time()
         running_queue = policy.sort_by_priority(now, running_queue)
+        i = 1
         while running_queue:
+            if i > self.scheduler_config.max_num_seqs:
+                break
+            i += 1
             seq_group = running_queue[0]
             num_running_tokens = self._get_num_new_tokens(
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
@@ -631,6 +639,10 @@ def _schedule_prefills(
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         i = 0
         max_prefill_batch_size = int(os.getenv("VLLM_PROMPT_BS_BUCKET_MAX", budget.max_num_seqs))
+        if budget.remaining_seq_budget() > 0 and VLLM_OVERSCHEDULE:
+            overschedule = max_prefill_batch_size
+        else:
+            overschedule = 0
         while self._passed_delay(time.time()) and waiting_queue and i < max_prefill_batch_size:
             i += 1
             seq_group = waiting_queue[0]
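
In _schedule_prefills the allowance is granted only when the budget still has at least one free sequence slot, and its size is the maximum prefill batch size taken from VLLM_PROMPT_BS_BUCKET_MAX (falling back to max_num_seqs). A self-contained sketch of that decision, using a hypothetical compute_overschedule helper and assumed values:

import os

# Module-level toggle, as in the diff above: defaults to enabled.
VLLM_OVERSCHEDULE = os.environ.get('VLLM_OVERSCHEDULE', 'true') == 'true'

def compute_overschedule(remaining_seq_budget: int, max_num_seqs: int) -> int:
    # VLLM_PROMPT_BS_BUCKET_MAX falls back to max_num_seqs when unset.
    max_prefill_batch_size = int(os.getenv("VLLM_PROMPT_BS_BUCKET_MAX", max_num_seqs))
    if remaining_seq_budget > 0 and VLLM_OVERSCHEDULE:
        return max_prefill_batch_size
    return 0

print(compute_overschedule(remaining_seq_budget=3, max_num_seqs=64))   # 64 with the defaults above
print(compute_overschedule(remaining_seq_budget=0, max_num_seqs=64))   # 0: no free slots, no overscheduling
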
@@ -687,9 +699,9 @@
                     continue
 
             num_new_seqs = seq_group.get_max_num_running_seqs()
-            if (num_new_tokens == 0
-                    or not budget.can_schedule(num_new_tokens=num_new_tokens,
-                                               num_new_seqs=num_new_seqs)):
+
+            can_fit = budget.can_schedule(num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs, overschedule=overschedule)
+            if num_new_tokens == 0 or not can_fit:
                 break
 
             # Can schedule this request.
@@ -768,7 +780,8 @@ def _schedule_default(self) -> SchedulerOutputs:
 
         assert (budget.num_batched_tokens <=
                 self.scheduler_config.max_num_batched_tokens)
-        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
+        if not VLLM_OVERSCHEDULE:
+            assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
 
         # Update waiting requests.
         self.waiting = remaining_waiting
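
One operational note implied by the diff: VLLM_OVERSCHEDULE is read once at module import, so opting out has to happen in the environment before the scheduler module is loaded. A usage sketch (the settings shown are illustrative choices, not prescribed by this commit):

import os

# Must be set before vllm.core.scheduler is imported, since the flag is
# evaluated at module import time.
os.environ["VLLM_OVERSCHEDULE"] = "false"        # opt out of prompt overscheduling
os.environ["VLLM_PROMPT_BS_BUCKET_MAX"] = "16"   # also bounds the overschedule allowance

# ...import and construct the engine afterwards...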