[Bugfix] Add Prefix Caching Warmup Step #3901
base: main
Changes from 7 commits
@@ -788,6 +788,53 @@ def list_loras(self) -> Set[int]:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def warmup_prefix_attn(self, kv_caches: List[torch.Tensor]) -> None:
        """Prefix attention uses a Triton JIT kernel.

        In our profile_run() step we profile with random data, so the
        cache-hit case is never executed. The Triton kernel is compiled on
        the fly, so the first call to context_attention_fwd takes ~3s.
        Without this warmup, that JIT compilation would happen on the hot
        path.

        Here we build two requests over the same prompt. Request 0 runs a
        prompt forward with self.block_size * NUM_BLOCKS + 1 tokens,
        filling the physical blocks in its block table. Request 1 then
        runs the same prompt, but with metadata marking the leading blocks
        as already computed. This triggers context_attention_fwd and
        compiles the kernel.
        """
        NUM_ITERATIONS = 10
        NUM_BLOCKS = 10
        NUM_COMPUTED_BLOCKS = NUM_BLOCKS - 1
        prompt_tokens = list(range(self.block_size * NUM_BLOCKS + 1))
        block_table = list(range(1, NUM_BLOCKS + 2))

        # Prompt forward to fill the KV cache for the prompt's blocks.
        request_0 = SequenceGroupMetadata(
            request_id="first_request",
            is_prompt=True,
            seq_data={0: SequenceData(prompt_tokens)},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: block_table},
        )
        self.execute_model([request_0], kv_caches)

        # Same prompt, but with the leading blocks marked as computed.
        # (Triggers context_attention_fwd.)
        request_1 = SequenceGroupMetadata(
            request_id="second_request",
            is_prompt=True,
            seq_data={0: SequenceData(prompt_tokens)},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: block_table},
            computed_block_nums=block_table[:NUM_COMPUTED_BLOCKS],
        )

Review comment (on the request_1 construction above): Can we just run […]?
Reply: Yeah, I think that should work.

        for _ in range(NUM_ITERATIONS):
            self.execute_model([request_1], kv_caches)

        return

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.
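For readers unfamiliar with the compilation behaviour the docstring describes: a `@triton.jit` kernel is compiled the first time it is launched (per specialization), and that first launch can be orders of magnitude slower than later, cached launches. The standalone sketch below is illustrative only; it uses a toy add kernel rather than vLLM's `context_attention_fwd`, and simply demonstrates the first-call penalty that `warmup_prefix_attn` moves out of the request hot path.

```python
# Toy demonstration of Triton JIT warm-up cost (requires a CUDA GPU and triton).
# This is NOT vLLM code; the kernel below stands in for context_attention_fwd.
import time

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def timed_launch(x, y, out):
    """Launch the kernel once and return wall-clock seconds."""
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    torch.cuda.synchronize()
    start = time.perf_counter()
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    torch.cuda.synchronize()
    return time.perf_counter() - start


x = torch.rand(1 << 20, device="cuda")
y = torch.rand_like(x)
out = torch.empty_like(x)
print(f"first launch (includes JIT compile): {timed_launch(x, y, out):.4f}s")
print(f"second launch (kernel cached):       {timed_launch(x, y, out):.6f}s")
```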
@@ -165,6 +165,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
    def warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        if self.cache_config.enable_prefix_caching:
            self.model_runner.warmup_prefix_attn(self.gpu_cache)

Review comment (on lines +168 to +169): Is this called only in profiling? Or each inference?

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
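Regarding the question in the thread above: `warm_up_model` is part of the worker's one-time setup after the KV cache is allocated, not the per-request path, so the extra warmup cost is paid once at engine startup. The sketch below is a hedged usage example showing how prefix caching is enabled from the public API so that this code path runs; the model name is arbitrary and the constructor argument is assumed to be forwarded to the engine's cache config.

```python
# Hedged usage sketch (not part of the diff): enabling prefix caching so that
# warmup_prefix_attn runs during engine initialization.
from vllm import LLM, SamplingParams

# enable_prefix_caching=True is assumed to set cache_config.enable_prefix_caching,
# the flag checked in warm_up_model above. The warmup happens here, once, at init.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

shared_prefix = "You are a helpful assistant. Document: ..."
prompts = [shared_prefix + " Question 1?", shared_prefix + " Question 2?"]

# Prompts sharing a prefix can reuse cached KV blocks without paying the
# Triton JIT cost on the first cache hit.
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)
```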
Review comment: I feel one iteration should be good enough?
Reply: I thought so too, but empirically it seemed I needed more than one run to make the timing stable, so I just picked 10. This takes <1s, so it's not impactful to UX, but I agree the code is a bit silly. Let me do some more experiments.
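One way to settle the one-vs-ten-iterations question empirically is to time each warmup call and check when the latency flattens out. A minimal sketch, assuming access to the same `model_runner`, `request_1`, and `kv_caches` objects that `warmup_prefix_attn` builds (the helper below is hypothetical, not part of the PR):

```python
import time

import torch


def time_warmup_iterations(model_runner, request_1, kv_caches, n=10):
    """Time successive warmup calls; the first should carry the Triton JIT cost."""
    timings = []
    for i in range(n):
        torch.cuda.synchronize()
        start = time.perf_counter()
        model_runner.execute_model([request_1], kv_caches)
        torch.cuda.synchronize()
        timings.append(time.perf_counter() - start)
        print(f"warmup iteration {i}: {timings[-1]:.3f}s")
    return timings
```

If only the first call is slow and the rest are flat, a single iteration (perhaps two, to also cover any lazy autotuning) should suffice.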