
Commit 73f213a

Merge remote-tracking branch 'origin/habana_main' into private/jmaksymczuk/fake_hpu_cpu

jmaksymczuk committed Sep 11, 2024
2 parents 4ab0063 + 2091161
Showing 1 changed file with 52 additions and 21 deletions.
vllm/worker/habana_model_runner.py (73 changes: 52 additions & 21 deletions)
@@ -210,6 +210,26 @@ def align_workers(value, op):
return value_t.item()


def setup_profiler():
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
DEVICE = 'hpu'
activities = [torch.profiler.ProfilerActivity.CPU]
activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
'hpu' else [])
#from habana_frameworks.torch.activity_profiler import DebugActivity
#debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]

profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
#debug_activities=debug_activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
use_gzip=True),
record_shapes=False,
with_stack=True)
return profiler
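
A usage sketch for the profiler returned above, assuming a toy placeholder workload (the real callers profile execute_model): with schedule(wait=0, warmup=2, active=1, repeat=1), at least three step() calls are needed before one active step is captured, which matches the times = 3 used for profiling runs in warmup_scenario below.

profiler = setup_profiler()
profiler.start()
for _ in range(3):  # wait=0 + warmup=2 + active=1 -> 3 steps per repeat
    torch.matmul(torch.ones(8, 8), torch.ones(8, 8))  # placeholder workload
    profiler.step()  # advance the schedule; only the final step is recorded
profiler.stop()  # tensorboard_trace_handler('.') writes a gzipped trace
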


def pad_list(list, k, v):
target_len = round_up(len(list), k)
padding = target_len - len(list)
@@ -579,8 +599,6 @@ def load_model(self) -> None:
htcore.mark_step()
torch.hpu.synchronize()

# FIXME: Running with disable_tensor_cache=True causes
# RuntimeErrors. This needs to be debugged
with HabanaMemoryProfiler() as m_wrap:
self.model = _maybe_wrap_in_hpu_graph(
self.model,
@@ -892,6 +910,9 @@ def _prepare_decode(
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)

dummy_slots = itertools.cycle(
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
@@ -921,8 +942,11 @@ def _prepare_decode(

block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
if block_number == _PAD_BLOCK_ID:
slot = next(dummy_slots)
else:
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)
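
A self-contained sketch of the slot computation above, with placeholder values for _PAD_SLOT_ID, _PAD_BLOCK_ID and block_size (the real ones come from the runner): cycling over block_size dummy slots spreads padded tokens across distinct cache slots instead of writing them all to one slot.

import itertools

_PAD_SLOT_ID, _PAD_BLOCK_ID, block_size = 0, -1, 4  # placeholder values

dummy_slots = itertools.cycle(range(_PAD_SLOT_ID, _PAD_SLOT_ID + block_size))

def slot_for(block_number, position):
    if block_number == _PAD_BLOCK_ID:  # padded sequence: rotate through dummy slots
        return next(dummy_slots)
    return block_number * block_size + position % block_size

print([slot_for(_PAD_BLOCK_ID, p) for p in range(6)])  # [0, 1, 2, 3, 0, 1]
print(slot_for(block_number=5, position=6))  # 5 * 4 + 6 % 4 = 22
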
@@ -943,12 +967,6 @@ def _prepare_decode(
dtype=torch.long,
device=self.device)

dummy_slots = itertools.cycle(
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
slot_mapping = [[
s if s != _PAD_SLOT_ID else next(dummy_slots) for s in sl
] for sl in slot_mapping]

num_decode_tokens = sum(seq_lens)

blocks_used = [len(bt) for bt in block_tables]
@@ -1242,11 +1260,7 @@ def profile_run(self) -> None:
max_seq_len = min(self.prompt_seq_bucket_cfg[-1],
self.max_num_batched_tokens // max_batch_size)

self.warmup_scenario(max_batch_size,
max_seq_len,
True,
kv_caches,
is_profile_run=True)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches)
return

def warmup_scenario(self,
@@ -1286,7 +1300,7 @@ def warmup_scenario(self,
for idx in range(max_num_seqs)
]
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs else 1
times = 3 if use_graphs or is_profile_run else 1
if self.lora_config and not is_profile_run:
lora_mapping = LoRAMapping(
[0] * batch_size * seq_len,
@@ -1317,10 +1331,19 @@ def warmup_scenario(self,
for i, b in enumerate(blocks)
]
torch.hpu.synchronize()
profiler = None
if is_profile_run and self.is_driver_worker:
profiler = setup_profiler()
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=False)
self.execute_model(inputs, kv_caches, warmup_mode=True)
torch.hpu.synchronize()
if profiler:
profiler.step()
if profiler:
profiler.stop()
self.profiler.end()
gc.collect()

def remove_all_loras(self):
@@ -1432,6 +1455,15 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):

@torch.inference_mode()
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
raise AssertionError("Finished profiling")
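
The value parsed above follows a phase_batchsize_seqlen_graphflag pattern: the phase is compared against 'prompt' (anything else is treated as decode) and a trailing 't' enables HPU graphs for that bucket. A hypothetical invocation, consistent with the split('_') parsing but not taken from any documentation:

# Profile one decode bucket (batch size 32, seq len 128) with HPU graphs enabled:
os.environ['VLLM_PT_PROFILE'] = 'decode_32_128_t'
# warmup_model() then runs that single scenario under the PT profiler and exits
# via AssertionError("Finished profiling") once the trace has been written.
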
if os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true':
logger.info("Skipping warmup...")
return
@@ -1579,10 +1611,9 @@ def mem_margin(self, value):


def _maybe_wrap_in_hpu_graph(*args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
*args, **
kwargs)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
*args, **kwargs)
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)


class HabanaProfilerCounterHelper():
