refactor device check and remove cohere/mixtral support (#12659)
MeouSker77 authored Jan 7, 2025
1 parent ea65e4f commit ddc0ef3
Showing 9 changed files with 44 additions and 1,359 deletions.
82 changes: 11 additions & 71 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1710,31 +1710,6 @@ def _optimize_post(model):
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
-    elif model.config.model_type == "cohere":
-        # for CohereForAI/c4ai-command-r-v01
-        invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
-                          "Please upgrade transformers to 4.40.0 or higher version "
-                          "to run Mixtral models.")
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        if version.parse(trans_version) >= version.parse("4.41.0"):
-            from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
-            convert_forward(model,
-                            module.CohereModel,
-                            cohere_model_forward_4_41)
-        else:
-            from ipex_llm.transformers.models.cohere import cohere_model_forward
-            convert_forward(model,
-                            module.CohereModel,
-                            cohere_model_forward)
-
-        from ipex_llm.transformers.models.cohere import cohere_attention_forward
-        convert_forward(model,
-                        module.CohereAttention,
-                        cohere_attention_forward)
-        convert_forward(model,
-                        module.CohereMLP,
-                        mlp_silu_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1746,31 +1721,6 @@ def _optimize_post(model):
         convert_forward(model,
                         module.AquilaRMSNorm,
                         rms_norm_forward)
-    elif model.config.model_type == "mixtral":
-        # For mistralai/Mixtral-8x7B-v0.1
-        invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
-                          "Please upgrade transformers to 4.36.0 or higher version "
-                          "to run Mixtral models.")
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward, mixtral_mlp_forward, mixtral_model_forward
-        convert_forward(model,
-                        module.MixtralAttention,
-                        mixtral_attention_forward)
-        convert_forward(model,
-                        module.MixtralRMSNorm,
-                        rms_norm_forward)
-        convert_forward(model,
-                        module.MixtralSparseMoeBlock,
-                        mixtral_moeblock_forward)
-        convert_forward(model,
-                        module.MixtralBLockSparseTop2MLP,
-                        mixtral_mlp_forward)
-        convert_forward(model,
-                        module.MixtralModel,
-                        mixtral_model_forward)
-
     elif model.config.model_type == "phi-msft" and \
             hasattr(model.config, "num_local_experts"):
         # For phixtral, limit the condition to avoid applying on phi-2 hosted by ModelScope
@@ -1785,29 +1735,19 @@ def _optimize_post(model):
                         module.MLP,
                         phixtral_mlp_forward)
     elif model.config.model_type == "mistral":
-        if model.config.architectures is not None and \
-                model.config.architectures[0] == "MixtralForCausalLM":
-            # For DiscoResearch/mixtral-7b-8expert
-            invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
-                              "Please upgrade transformers to 4.36.0 or higher version "
-                              "to run Mixtral models.")
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
-        else:
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
 
-            from ipex_llm.transformers.models.mistral import mistral_model_forward
-            from ipex_llm.transformers.models.mistral import mistral_attention_forward
-            from ipex_llm.transformers.models.common import rms_norm_forward
-            from ipex_llm.transformers.models.common import mlp_silu_forward
+        from ipex_llm.transformers.models.mistral import mistral_model_forward
+        from ipex_llm.transformers.models.mistral import mistral_attention_forward
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
 
-            convert_forward(model, module.MistralModel, mistral_model_forward)
-            convert_forward(model, module.MistralAttention, mistral_attention_forward)
-            convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
-            convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
-            convert_forward(model, module.MistralMLP, mlp_silu_forward)
+        convert_forward(model, module.MistralModel, mistral_model_forward)
+        convert_forward(model, module.MistralAttention, mistral_attention_forward)
+        convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
+        convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MistralMLP, mlp_silu_forward)
     elif model.config.model_type == "gemma":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
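Note on the convert.py change above: the dedicated cohere and mixtral branches are deleted outright, and the mistral branch loses its DiscoResearch "MixtralForCausalLM" special case, leaving a single code path. Every branch relies on the same convert_forward(model, ModuleClass, new_forward) monkey-patching pattern, so dropping a model family is just a matter of removing its elif block. The helper below is a simplified sketch of that pattern for orientation only; it is not the actual ipex_llm implementation, and the name convert_forward_sketch is made up for this example.

import types

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, module_class: type, new_forward) -> None:
    """Rebind `forward` on every submodule that is an instance of `module_class`."""
    for submodule in model.modules():
        if isinstance(submodule, module_class):
            # Bind the replacement function as an instance method so that
            # `self` is passed automatically on the next forward call.
            submodule.forward = types.MethodType(new_forward, submodule)

Called as convert_forward_sketch(model, module.MistralAttention, mistral_attention_forward), it mirrors the calls visible in the diff.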
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/lookup.py
@@ -33,7 +33,7 @@
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
     _prepare_generate_args_4_45
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.utils import get_xpu_device_type
+from ipex_llm.transformers.utils import get_xpu_device_name
 
 logger = logging.getLogger("ipex_llm.lookup")

@@ -295,7 +295,7 @@ def lookup_generate(self,
     invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")
 
-    device_name = get_xpu_device_type(input_ids)
+    device_name = get_xpu_device_name(input_ids.device)
 
     candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
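The only change to lookup.py is the call-site migration from get_xpu_device_type(input_ids) to get_xpu_device_name(input_ids.device): the renamed helper takes a torch.device rather than a tensor, so callers now pass the device explicitly. The sketch below only illustrates the shape of such a helper; the real one lives in ipex_llm.transformers.utils, and this body (including the use of torch.xpu.get_device_name and the substring matching against family tags) is an assumption, not the library's implementation.

import torch


def get_xpu_device_name_sketch(device: torch.device) -> str:
    """Map a torch.device to a short GPU family tag such as "arc", "pvc",
    "mtl", "lnl", "arl" or "bmg" (the tags this commit compares against)."""
    if device.type != "xpu":
        return device.type  # e.g. "cpu"
    # Assumed lookup: match the driver-reported device name against known tags.
    name = torch.xpu.get_device_name(device.index or 0).lower()
    for tag in ("bmg", "lnl", "arl", "mtl", "pvc", "arc"):
        if tag in name:
            return tag
    return name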
16 changes: 8 additions & 8 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -51,7 +51,7 @@
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
+from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \
     get_ipex_version
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
 
@@ -266,7 +266,7 @@ def reshape_lm_head_input(x):
 
 
 def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
-    device = get_xpu_device_type(x)
+    device_name = get_xpu_device_name(x.device)
     batch_size = x.shape[0]
     hard_condition = (
         x.dtype in [torch.float, torch.half]
@@ -286,7 +286,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
             or (
                 qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
                 and batch_size <= 48
-                and device in ["arc", "flex", "pvc", "mtl"]
+                and device_name in ["arc", "pvc", "mtl", "lnl", "arl"]
                 and x.shape[1] % 256 == 0
                 and output_len % 32 == 0
             )
@@ -295,8 +295,8 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
     if hard_condition:
         return (
             batch_size > 1
-            or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
-            or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
+            or (device in ["arc"] and qtype in [SYM_INT8, FP4])
+            or (device in ["arc", "mtl"] and qtype in [FP8E4])
             or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
             or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5])
         )
@@ -603,7 +603,7 @@ def forward(self, x: torch.Tensor):
             # empty cache before and after lm_head at first token when input > 1024
             # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
             if self.device is None:
-                self.device = get_xpu_device_type(self.weight.data)
+                self.device = get_xpu_device_name(self.weight.data.device)
                 self.low_memory_mode = \
                     self.low_memory_mode and \
                     (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
@@ -782,7 +782,7 @@ def forward(self, x: torch.Tensor):
         if not self.use_esimd_kernel(x):
             if (
                 get_ipex_version() < "2.1.10+xpu"
-                or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]
+                or get_xpu_device_name(x.device) not in ["arc", "pvc"]
                 or self.disable_fp16_opt
             ):
                 if self.weight_type == 2:
@@ -848,7 +848,7 @@ def forward(self, x: torch.Tensor):
         return result.to(x.dtype)
 
     def use_esimd_kernel(self, x):
-        gpu_type = get_xpu_device_type(x)
+        gpu_type = get_xpu_device_name(x.device)
         if self.disable_fp16_opt:
             return False
         # esimd kernel can only be used for Arc and Flex
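In low_bit_linear.py the same migration applies at three call sites, and "flex" drops out of every supported-device list. Purely as a self-contained illustration of the new gating style (the function name is invented and the condition is heavily trimmed compared with use_batch_forward above), the device check now reads along these lines:

import torch

from ipex_llm.transformers.utils import get_xpu_device_name  # helper used by this commit


def batch_kernel_ok_sketch(x: torch.Tensor, output_len: int) -> bool:
    """Trimmed-down sketch of the batch-forward hard condition: the device
    family is resolved from x.device and "flex" is no longer accepted."""
    device_name = get_xpu_device_name(x.device)
    return (
        x.dtype in (torch.float, torch.half)
        and device_name in ("arc", "pvc", "mtl", "lnl", "arl")
        and x.shape[1] % 256 == 0
        and output_len % 32 == 0
    )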