Enable mllm support on Gaudi (#145)
* Enable mllm support on Gaudi

Signed-off-by: Xue, Chendi <[email protected]>

* reformat

Signed-off-by: Xue, Chendi <[email protected]>

---------

Signed-off-by: Xue, Chendi <[email protected]>
xuechendi authored Mar 22, 2024
1 parent 4abcd74 commit e707c78
47 changes: 30 additions & 17 deletions llm_on_ray/inference/mllm_predictor.py
@@ -32,24 +32,35 @@ def __init__(self, infer_conf: InferenceConfig):
         )
 
         model = model.eval().to(self.device)
-        # # to channels last
-        model = model.to(memory_format=torch.channels_last)
-        # to ipex
-        if infer_conf.ipex.enabled:
-            import intel_extension_for_pytorch as ipex
+        if self.device.type == "hpu":
+            self.use_hpu_graphs = model_desc.use_hpu_graphs
+            if self.use_hpu_graphs:
+                from habana_frameworks.torch.hpu import (
+                    wrap_in_hpu_graph,
+                )  # pylint: disable=E0401
 
-            torch._C._jit_set_texpr_fuser_enabled(False)
-            try:
-                ipex._C.disable_jit_linear_repack()
-            except Exception:
-                pass
-            model = ipex.llm.optimize(
-                model.eval(),
-                dtype=torch.bfloat16
-                if infer_conf.ipex.precision == PRECISION_BF16
-                else torch.float32,
-                inplace=True,
-            )
+                model = wrap_in_hpu_graph(model)
+            else:
+                print("Warning: use_hpu_graphs is set to False. This will hurt the performance.")
+        else:
+            # # to channels last
+            model = model.to(memory_format=torch.channels_last)
+            # to ipex
+            if infer_conf.ipex.enabled:
+                import intel_extension_for_pytorch as ipex
+
+                torch._C._jit_set_texpr_fuser_enabled(False)
+                try:
+                    ipex._C.disable_jit_linear_repack()
+                except Exception:
+                    pass
+                model = ipex.optimize_transformers(
+                    model.eval(),
+                    dtype=torch.bfloat16
+                    if infer_conf.ipex.precision == PRECISION_BF16
+                    else torch.float32,
+                    inplace=True,
+                )
         self.model = model
         self.processor = processor
 
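Note: the hunk above makes model preparation device-dependent. On Gaudi (device type "hpu") the model is optionally wrapped in HPU graphs, while the existing channels-last plus IPEX path is kept for non-HPU devices. Below is a minimal standalone sketch of that dispatch pattern, assuming a loaded Hugging Face model and flags mirroring the config fields used here (use_hpu_graphs, ipex.enabled, ipex.precision); the helper name prepare_model is illustrative and not part of the repository.

import torch


def prepare_model(model, device, use_hpu_graphs=True, ipex_enabled=False, use_bf16=True):
    """Illustrative sketch of the device dispatch introduced by this commit."""
    model = model.eval().to(device)
    if device.type == "hpu":
        if use_hpu_graphs:
            # Wrapping in HPU graphs caches the compiled graph and reduces
            # host-side launch overhead on Gaudi.
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph  # pylint: disable=E0401

            model = wrap_in_hpu_graph(model)
        else:
            print("Warning: use_hpu_graphs is set to False. This will hurt the performance.")
    else:
        # Non-HPU path: channels-last layout plus optional IPEX transformer
        # optimization, as the predictor did before this change (minus the JIT
        # fuser tweaks kept in the real code).
        model = model.to(memory_format=torch.channels_last)
        if ipex_enabled:
            import intel_extension_for_pytorch as ipex

            model = ipex.optimize_transformers(
                model.eval(),
                dtype=torch.bfloat16 if use_bf16 else torch.float32,
                inplace=True,
            )
    return model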
@@ -65,6 +76,8 @@ def _process_config(self, config):
 
     def _tokenize_inputs(self, image, text_prompt):
         input_tokens = self.processor(text=text_prompt, images=image, return_tensors="pt")
+        if self.device.type != "cpu":
+            input_tokens = input_tokens.to(device=self.device)
         return input_tokens
 
     def streaming_generate(self, image, prompt, streamer, **config):
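Note: the second hunk moves the processor output onto the accelerator before generation, so non-CPU devices (HPU included) do not depend on implicit host-to-device copies later on. A small sketch of the same pattern outside the class, assuming any Hugging Face multimodal processor; the function name tokenize_inputs is illustrative.

def tokenize_inputs(processor, image, text_prompt, device):
    """Sketch of the updated _tokenize_inputs behavior."""
    input_tokens = processor(text=text_prompt, images=image, return_tensors="pt")
    if device.type != "cpu":
        # BatchFeature.to(...) moves every tensor in the returned batch onto the
        # target device in one call.
        input_tokens = input_tokens.to(device=device)
    return input_tokens


# Example usage (hypothetical model/processor names):
# processor = transformers.AutoProcessor.from_pretrained("<multimodal-model-id>")
# tokens = tokenize_inputs(processor, image, "Describe the image.", torch.device("hpu"))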
