diff --git a/llm_on_ray/inference/mllm_predictor.py b/llm_on_ray/inference/mllm_predictor.py
index 6168abe5c..006ecb855 100644
--- a/llm_on_ray/inference/mllm_predictor.py
+++ b/llm_on_ray/inference/mllm_predictor.py
@@ -32,24 +32,35 @@ def __init__(self, infer_conf: InferenceConfig):
         )
         model = model.eval().to(self.device)
-        # # to channels last
-        model = model.to(memory_format=torch.channels_last)
-        # to ipex
-        if infer_conf.ipex.enabled:
-            import intel_extension_for_pytorch as ipex
+        if self.device.type == "hpu":
+            self.use_hpu_graphs = model_desc.use_hpu_graphs
+            if self.use_hpu_graphs:
+                from habana_frameworks.torch.hpu import (
+                    wrap_in_hpu_graph,
+                )  # pylint: disable=E0401
 
-            torch._C._jit_set_texpr_fuser_enabled(False)
-            try:
-                ipex._C.disable_jit_linear_repack()
-            except Exception:
-                pass
-            model = ipex.llm.optimize(
-                model.eval(),
-                dtype=torch.bfloat16
-                if infer_conf.ipex.precision == PRECISION_BF16
-                else torch.float32,
-                inplace=True,
-            )
+                model = wrap_in_hpu_graph(model)
+            else:
+                print("Warning: use_hpu_graphs is set to False. This will hurt the performance.")
+        else:
+            # # to channels last
+            model = model.to(memory_format=torch.channels_last)
+            # to ipex
+            if infer_conf.ipex.enabled:
+                import intel_extension_for_pytorch as ipex
+
+                torch._C._jit_set_texpr_fuser_enabled(False)
+                try:
+                    ipex._C.disable_jit_linear_repack()
+                except Exception:
+                    pass
+                model = ipex.optimize_transformers(
+                    model.eval(),
+                    dtype=torch.bfloat16
+                    if infer_conf.ipex.precision == PRECISION_BF16
+                    else torch.float32,
+                    inplace=True,
+                )
         self.model = model
         self.processor = processor
@@ -65,6 +76,8 @@ def _process_config(self, config):
 
     def _tokenize_inputs(self, image, text_prompt):
         input_tokens = self.processor(text=text_prompt, images=image, return_tensors="pt")
+        if self.device.type != "cpu":
+            input_tokens = input_tokens.to(device=self.device)
         return input_tokens
 
     def streaming_generate(self, image, prompt, streamer, **config):
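
For context, below is a minimal standalone sketch of the device-dispatch pattern this diff introduces: on Gaudi (`hpu`) the model is wrapped in an HPU graph, otherwise the existing CPU path (channels-last plus optional IPEX optimization) is kept, and tokenized inputs are moved onto the model's device for non-CPU back ends. The helper names `prepare_model` and `move_inputs_to_device` and the boolean flags are hypothetical illustrations only; the real predictor reads these values from `InferenceConfig` and `model_desc`.

```python
# Hypothetical sketch of the device-dispatch pattern added by this PR.
# prepare_model / move_inputs_to_device and the flags are illustrative names,
# not part of llm_on_ray.
import torch


def prepare_model(
    model: torch.nn.Module,
    device: torch.device,
    use_hpu_graphs: bool = True,
    use_ipex: bool = False,
) -> torch.nn.Module:
    model = model.eval().to(device)

    if device.type == "hpu":
        if use_hpu_graphs:
            # On Gaudi, replay a cached HPU graph for repeated forward passes.
            # Requires the Habana SynapseAI / habana_frameworks stack.
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

            model = wrap_in_hpu_graph(model)
        else:
            print("Warning: use_hpu_graphs is set to False. This will hurt the performance.")
    else:
        # CPU path: channels-last layout, then optional IPEX optimization.
        model = model.to(memory_format=torch.channels_last)
        if use_ipex:
            import intel_extension_for_pytorch as ipex  # requires intel-extension-for-pytorch

            model = ipex.optimize_transformers(model.eval(), dtype=torch.bfloat16, inplace=True)
    return model


def move_inputs_to_device(inputs, device: torch.device):
    # Mirrors the _tokenize_inputs change: processor output defaults to CPU tensors,
    # so move it to the model's device when running on HPU (or any non-CPU device).
    if device.type != "cpu":
        inputs = inputs.to(device=device)
    return inputs
```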