diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 2360e39fcba28..55f205915ea8c 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -238,11 +238,12 @@ def pad_list(list, k, v):
 class HpuModelAdapter():
 
-    def __init__(self, model, block_size, enforce_eager):
+    def __init__(self, model, block_size, dtype, enforce_eager):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
         self.block_size = block_size
+        self.dtype = dtype
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
@@ -304,7 +305,7 @@ def forward(self, *args, **kwargs):
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
-            input_ids.device, torch.bfloat16)
+            input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -600,6 +601,7 @@ def load_model(self) -> None:
                 self.model = _maybe_wrap_in_hpu_graph(
                     self.model,
                     self.block_size,
+                    dtype=self.model_config.dtype,
                     enforce_eager=self.enforce_eager)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
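
Note on the pattern: the adapter previously built its attention-metadata tensors in torch.bfloat16 regardless of the configured model dtype; with this change the dtype comes from model_config and is stored on the adapter. Below is a minimal, self-contained sketch of the same pattern for illustration only; DtypeAwareAdapter and _make_attn_bias are hypothetical names, not vLLM code.

    import torch


    class DtypeAwareAdapter:
        """Illustrative stand-in: thread the configured dtype through
        instead of hardcoding torch.bfloat16."""

        def __init__(self, model, block_size, dtype):
            self.model = model
            self.block_size = block_size
            # dtype is taken from the model config (e.g. float32, float16, bfloat16)
            self.dtype = dtype

        def _make_attn_bias(self, batch_size, seq_len, device):
            # Metadata tensors are created in the configured dtype so they
            # match the model's compute dtype.
            return torch.zeros(batch_size, seq_len, device=device, dtype=self.dtype)


    adapter = DtypeAwareAdapter(model=None, block_size=128, dtype=torch.float32)
    bias = adapter._make_attn_bias(batch_size=2, seq_len=8, device="cpu")
    assert bias.dtype == torch.float32  # follows the configured dtype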