Make max_num_batched_tokens behavior more verbose, add legacy mode #208

Merged (2 commits) on Aug 28, 2024
70 changes: 58 additions & 12 deletions vllm/worker/habana_model_runner.py
```diff
@@ -96,14 +96,44 @@ def warmup_range(config: Tuple[int, int, int]):
 
 def warmup_buckets(bs_bucket_config, seq_bucket_config,
                    max_num_batched_tokens):
-    buckets = itertools.product(warmup_range(bs_bucket_config),
-                                warmup_range(seq_bucket_config))
+    buckets = list(
+        itertools.product(warmup_range(bs_bucket_config),
+                          warmup_range(seq_bucket_config)))
+    if len(buckets) == 0:
+        msg = ("No buckets could be captured with following config "
+               f"(min, step, max_warmup): "
+               f"bs:{bs_bucket_config}, "
+               f"seq:{seq_bucket_config}")
+        raise ValueError(msg)
+
     # Remove buckets exceeding batch token budget
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-        buckets)
-    return list(
+    filtered_buckets = list(
+        filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
+               buckets))
+
+    if len(filtered_buckets) == 0:
+        # legacy case - we can handle this if we ignore max_num_batched_tokens
+        min_bucket_bs, min_bucket_seq = min(buckets,
+                                            key=lambda b: (b[0] * b[1]))
+        min_reqd_budget = min_bucket_bs * min_bucket_seq
+        msg = (
+            "The current bucketing configuration "
+            f"(min, step, max_warmup): "
+            f"bs:{bs_bucket_config}, "
+            f"seq:{seq_bucket_config} cannot be used with specified "
+            f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
+            f"smallest bucket ({min_reqd_budget}) would exceed token budget. "
+            "Please increase max_num_batched_tokens or decrease bucket minimum. "
+            "Ignoring max_num_batched_tokens at risk of out-of-memory errors.")
+        logger.error(msg)
+        return list(sorted(buckets, key=lambda b:
+                           (b[0] * b[1], b[1], b[0]))), []
+
+    captured_buckets = list(
         sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+    omitted_buckets = list(
+        sorted([x for x in buckets if x not in filtered_buckets]))
+    return captured_buckets, omitted_buckets
 
 
 def next_pow2(value: int):
```
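Reviewer note: a minimal runnable sketch of the new two-phase flow follows. Here `warmup_range` is stubbed to return precomputed values and a plain module logger stands in for vLLM's; both are assumptions of the sketch, not the real helpers in `habana_model_runner.py`.

```python
# Minimal sketch of the new warmup_buckets flow. warmup_range is stubbed:
# the real one expands a (min, step, max_warmup) config into values.
import itertools
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def warmup_range(values):
    # Stub: pretend the bucketing config was already expanded.
    return list(values)


def warmup_buckets(bs_bucket_config, seq_bucket_config,
                   max_num_batched_tokens):
    buckets = list(
        itertools.product(warmup_range(bs_bucket_config),
                          warmup_range(seq_bucket_config)))
    if len(buckets) == 0:
        raise ValueError("No buckets could be captured")

    # Keep only buckets whose batch_size * seq_len fits the token budget.
    filtered_buckets = list(
        filter(lambda b: b[0] * b[1] <= max_num_batched_tokens, buckets))

    if len(filtered_buckets) == 0:
        # Legacy mode: every bucket busts the budget, so log an error and
        # warm up all buckets anyway (risking OOM) instead of failing.
        logger.error("smallest bucket exceeds max_num_batched_tokens; "
                     "ignoring the token budget")
        return sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])), []

    captured = sorted(filtered_buckets,
                      key=lambda b: (b[0] * b[1], b[1], b[0]))
    omitted = sorted(b for b in buckets if b not in filtered_buckets)
    return captured, omitted


# With a 2048-token budget, only (4, 1024) exceeds bs * seq_len <= 2048.
captured, omitted = warmup_buckets([1, 2, 4], [128, 1024], 2048)
print(captured)  # [(1, 128), (2, 128), (4, 128), (1, 1024), (2, 1024)]
print(omitted)   # [(4, 1024)]
```

Run as-is, the demo at the bottom shows one bucket being omitted under the budget while the rest are captured in ascending token-cost order.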
```diff
@@ -531,9 +561,9 @@ def _setup_buckets(self) -> None:
                f"bs:{self.prompt_bs_bucket_cfg}, "
                f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
-        self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
-                                             self.prompt_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+        self.prompt_buckets, prompt_omitted_buckets = warmup_buckets(
+            self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
+            self.max_num_batched_tokens)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
```
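The call sites change because `warmup_buckets` now returns a pair. A hedged sketch of the effective signature (the `Bucket` alias and type hints are illustrative additions, not part of the diff):

```python
from typing import List, Tuple

# Illustrative alias: a bucket pairs a batch size with a sequence length.
Bucket = Tuple[int, int]


def warmup_buckets(
    bs_bucket_config: Tuple[int, int, int],
    seq_bucket_config: Tuple[int, int, int],
    max_num_batched_tokens: int,
) -> Tuple[List[Bucket], List[Bucket]]:
    """Return (captured_buckets, omitted_buckets); body as in the diff."""
    ...
```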
```diff
@@ -545,13 +575,21 @@ def _setup_buckets(self) -> None:
                f"prompt buckets: {list(sorted(self.prompt_buckets))}")
         logger.info(msg)
 
+        msg = (f"Omitted {len(prompt_omitted_buckets)} "
+               "prompt buckets due to exceeded token budget "
+               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
+        logger.info(msg)
+
+        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
+        logger.debug(msg)
+
         msg = ("Decode bucket config (min, step, max_warmup) "
                f"bs:{self.decode_bs_bucket_cfg}, "
                f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
-        self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
-                                             self.decode_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+        self.decode_buckets, decode_omitted_buckets = warmup_buckets(
+            self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg,
+            self.max_num_batched_tokens)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets
```
```diff
@@ -561,6 +599,14 @@ def _setup_buckets(self) -> None:
                 f"{list(sorted(self.decode_buckets))}")
         logger.info(msg)
 
+        msg = (f"Omitted {len(decode_omitted_buckets)} "
+               "decode buckets due to exceeded token budget "
+               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
+        logger.info(msg)
+
+        msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}"
+        logger.debug(msg)
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
```
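For illustration, the snippet below re-renders the new INFO/DEBUG message templates from this diff with assumed example values (one omitted bucket, a 2048-token budget); actual output also depends on the logger configuration.

```python
# Re-renders the new log message templates with assumed example values;
# in the real code these come from self.* attributes.
prompt_omitted_buckets = [(4, 1024)]  # hypothetical omitted bucket
max_num_batched_tokens = 2048         # hypothetical token budget

msg = (f"Omitted {len(prompt_omitted_buckets)} "
       "prompt buckets due to exceeded token budget "
       f"(max_num_batched_tokens={max_num_batched_tokens})")
print(msg)
# Omitted 1 prompt buckets due to exceeded token budget (max_num_batched_tokens=2048)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
print(msg)
# Omitted prompt buckets: [(4, 1024)]
```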