
[Bugfix] Add Prefix Caching Warmup Step #3901

Status: Open. Wants to merge 12 commits into base: main.
47 changes: 47 additions & 0 deletions vllm/worker/model_runner.py
@@ -788,6 +788,53 @@ def list_loras(self) -> Set[int]:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def warmup_prefix_attn(self, kv_caches: List[torch.Tensor]) -> None:
        """Warm up the Triton JIT used by prefix attention.

        profile_run() profiles with random data, so the cache-hit path is
        never exercised there. The Triton kernel is compiled on the fly,
        which makes the first call to context_attention_fwd take ~3s.
        Without this warmup, that JIT compilation would land on the hot
        path.

        We build two requests over the same prompt. Request 0 runs a
        normal prompt forward, filling the KV cache for the blocks in its
        block table. Request 1 then runs the same prompt with metadata
        marking most of those blocks as already computed, which triggers
        context_attention_fwd and compiles the kernel.
        """
        NUM_ITERATIONS = 10
Member: I feel one iteration should be good enough?

Collaborator (Author): I thought so too, but empirically it seemed I needed more than one run to make the timing stable, so I just picked 10. This takes <1s, so it's not impactful to UX, but I agree the code is a bit silly. Let me do some more experiments.

        NUM_BLOCKS = 10
        NUM_COMPUTED_BLOCKS = NUM_BLOCKS - 1
        prompt_tokens = list(range(self.block_size * NUM_BLOCKS + 1))
        block_table = list(range(1, NUM_BLOCKS + 2))

        # Prompt forward to fill the KV cache for the blocks in the block
        # table.
        request_0 = SequenceGroupMetadata(
            request_id="first_request",
            is_prompt=True,
            seq_data={0: SequenceData(prompt_tokens)},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: block_table},
        )
        self.execute_model([request_0], kv_caches)

        # Prompt forward with the first NUM_COMPUTED_BLOCKS blocks marked as
        # computed. (Triggers context_attention_fwd.)
        request_1 = SequenceGroupMetadata(
            request_id="second_request",
            is_prompt=True,
            seq_data={0: SequenceData(prompt_tokens)},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: block_table},
            computed_block_nums=block_table[:NUM_COMPUTED_BLOCKS],
        )
Member: Can we just run request_1? I believe the only goal here is to activate the Triton kernel.

Collaborator (Author): Yeah, I think that should work.

        for _ in range(NUM_ITERATIONS):
            self.execute_model([request_1], kv_caches)

        return

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.
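
To ground the NUM_ITERATIONS discussion above, here is a minimal timing sketch (not part of this diff) for checking how many warmup iterations are needed before per-call latency stabilizes. It assumes access to an initialized model runner, its GPU KV caches, and a request_1 built as in warmup_prefix_attn above; the helper name time_warmup_iterations is hypothetical.

import time

import torch


def time_warmup_iterations(model_runner, kv_caches, request_1, num_iters=10):
    """Measure per-call latency of the prefix-attention warmup request.

    The first call should absorb the Triton JIT compilation cost (~3s);
    later calls should plateau at steady-state latency, which shows how
    many warmup iterations are actually needed.
    """
    latencies = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        model_runner.execute_model([request_1], kv_caches)
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start)
        print(f"iteration {i}: {latencies[-1] * 1000:.1f} ms")
    return latencies
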
2 changes: 2 additions & 0 deletions vllm/worker/worker.py
@@ -165,6 +165,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
    def warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        if self.cache_config.enable_prefix_caching:
            self.model_runner.warmup_prefix_attn(self.gpu_cache)
Comment on lines +168 to +169
Contributor: Is this called only during profiling, or on each inference?

Collaborator (Author): warm_up_model is not called on the hot path.

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
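
For context on the thread above: warm_up_model runs once during engine initialization, not per request, so the JIT cost is paid at startup. A minimal usage sketch follows, assuming the standard vLLM offline API and that enable_prefix_caching is the engine flag backing the cache_config.enable_prefix_caching check in this diff; the model name and prompt are placeholders.

from vllm import LLM, SamplingParams

# With prefix caching enabled, Worker.warm_up_model() calls
# model_runner.warmup_prefix_attn() during startup, so the Triton
# context_attention_fwd kernel is compiled before the first real
# request arrives.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0),
)
print(outputs[0].outputs[0].text)
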