Port PT Profiler to habana_main (#256)
Porting PT Profiler from: 81a23a7 and e805b88
adobrzyniewicz-habana authored Sep 11, 2024
1 parent 53f96b7 commit 2091161
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions vllm/worker/habana_model_runner.py
@@ -210,6 +210,26 @@ def align_workers(value, op):
     return value_t.item()
 
 
+def setup_profiler():
+    schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
+    DEVICE = 'hpu'
+    activities = [torch.profiler.ProfilerActivity.CPU]
+    activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
+                      'hpu' else [])
+    #from habana_frameworks.torch.activity_profiler import DebugActivity
+    #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]
+
+    profiler = torch.profiler.profile(
+        schedule=schedule,
+        activities=activities,
+        #debug_activities=debug_activities,
+        on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
+                                                                use_gzip=True),
+        record_shapes=False,
+        with_stack=True)
+    return profiler
+
+
 def pad_list(list, k, v):
     target_len = round_up(len(list), k)
     padding = target_len - len(list)
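Note (not part of the commit): the object returned by setup_profiler() is a standard torch.profiler.profile instance, so it is driven with start()/step()/stop(), exactly as the warmup_scenario hunk further down does. A minimal standalone sketch, assuming an HPU-enabled PyTorch build and a purely illustrative placeholder workload:

import torch

profiler = setup_profiler()   # the helper added above
profiler.start()
for _ in range(3):            # wait=0 + warmup=2 + active=1 steps
    torch.ones(1024) * 2      # placeholder work, not from the commit
    profiler.step()
profiler.stop()

With use_gzip=True, tensorboard_trace_handler('.', ...) writes a *.pt.trace.json.gz file into the working directory once the active step completes; the trace can be opened in TensorBoard (torch-tb-profiler plugin) or Perfetto.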
@@ -1237,11 +1257,7 @@ def profile_run(self) -> None:
         max_seq_len = min(self.prompt_seq_bucket_cfg[-1],
                           self.max_num_batched_tokens // max_batch_size)
 
-        self.warmup_scenario(max_batch_size,
-                             max_seq_len,
-                             True,
-                             kv_caches,
-                             is_profile_run=True)
+        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches)
         return
 
     def warmup_scenario(self,
@@ -1281,7 +1297,7 @@ def warmup_scenario(self,
             for idx in range(max_num_seqs)
         ]
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs else 1
+        times = 3 if use_graphs or is_profile_run else 1
         if self.lora_config and not is_profile_run:
             lora_mapping = LoRAMapping(
                 [0] * batch_size * seq_len,
@@ -1312,10 +1328,19 @@
             for i, b in enumerate(blocks)
         ]
         torch.hpu.synchronize()
+        profiler = None
+        if is_profile_run and self.is_driver_worker:
+            profiler = setup_profiler()
+            profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
             self.execute_model(inputs, kv_caches, warmup_mode=True)
             torch.hpu.synchronize()
+            if profiler:
+                profiler.step()
+        if profiler:
+            profiler.stop()
         self.profiler.end()
         gc.collect()
 
     def remove_all_loras(self):
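Side note (not from the commit): the schedule defined in setup_profiler() (wait=0, warmup=2, active=1, repeat=1) lines up with the times = 3 forced for profile runs above, so the first two iterations serve as profiler warmup and only the third is actually recorded. A quick check of the step-to-phase mapping, independent of vLLM:

import torch

sched = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
for step in range(3):
    # Expected: steps 0 and 1 -> ProfilerAction.WARMUP,
    #           step 2       -> ProfilerAction.RECORD_AND_SAVE
    print(step, sched(step))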
@@ -1427,6 +1452,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
 
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        if profile := os.environ.get('VLLM_PT_PROFILE', None):
+            phase, bs, seq_len, graph = profile.split('_')
+            is_prompt = phase == 'prompt'
+            graphs = graph == 't'
+            if graphs:
+                self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
+            self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
+                                 True)
+            raise AssertionError("Finished profiling")
         if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
             logger.info("Skipping warmup...")
             return
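Usage note (inferred from the parsing above, not stated in the commit): VLLM_PT_PROFILE appears to take the form '<phase>_<batch_size>_<seq_len>_<t|f>', where the phase selects a prompt or decode bucket and the trailing flag marks whether the bucket runs with HPU graphs; warmup_model then executes that single scenario under the PT profiler and aborts with AssertionError("Finished profiling"). The concrete values below are only an illustration:

import os

# Hypothetical example: profile a prompt bucket with batch size 4, seq len 128,
# HPU graphs enabled; these numbers are not taken from the commit.
os.environ['VLLM_PT_PROFILE'] = 'prompt_4_128_t'

phase, bs, seq_len, graph = os.environ['VLLM_PT_PROFILE'].split('_')
assert (phase, int(bs), int(seq_len), graph) == ('prompt', 4, 128, 't')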