Skip to content

Commit

Permalink
add fp16 NPU Linear support and fix version 1.0 support
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Jun 18, 2024
1 parent 83082e5 commit 646730c
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions python/llm/src/ipex_llm/transformers/npu_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from transformers.dynamic_module_utils import get_imports

import intel_npu_acceleration_library as npu_lib
from intel_npu_acceleration_library.dtypes import int8, int4

from ipex_llm.utils.common.log4Error import invalidInputError

Expand Down Expand Up @@ -55,28 +54,40 @@ def from_pretrained(cls,
The loaded model will run supported OPs on NPU, then run other OPs on CPU.
Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``, ``'fp32'``.
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'sym_int8'``,
``'fp16'``, ``'fp32'``.
Relevant low bit optimizations will be applied to the model.
:return: a model instance
"""
if kwargs.get('device_map', None) not in [None, 'cpu', 'auto']:
warnings.warn("`device_map` will be ignored")
kwargs['device_map'] = 'cpu'

low_bit = kwargs.pop('load_in_low_bit', None)
low_bit_to_dtype_map = {
'sym_int4': int4,
'sym_int8': int8,
'fp32': torch.float,
}
if low_bit is not None:
dtype = low_bit_to_dtype_map[low_bit]
else:
dtype = kwargs.get('torch_dtype', torch.float)
dtype = torch.float if dtype == 'auto' else dtype
invalidInputError(dtype in low_bit_to_dtype_map.values(),
f"unsupported dtype: {dtype}, "
"only `sym_int4`, `sym_int8`, `fp32` are supported")
if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float]:
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
kwargs['torch_dtype'] = torch.float

low_bit = kwargs.pop('load_in_low_bit', torch.float)
try:
# for intel_npu_acceleration_library >= 1.1.0
from intel_npu_acceleration_library.dtypes import int8, int4
qtype_map = {
'sym_int4': int4,
'sym_int8': int8,
'fp16': torch.half,
'fp32': torch.float,
}
except ImportError as _e:
# for intel_npu_acceleration_library < 1.1.0
qtype_map = {
'sym_int8': torch.int8,
'fp16': torch.half,
'fp32': torch.float,
}
invalidInputError(low_bit in qtype_map.keys(),
f"unsupported low_bit: {low_bit}, "
f"only {list(qtype_map.keys())} are supported")
qtype = qtype_map[low_bit]

kwargs["low_cpu_mem_usage"] = True

Expand All @@ -96,7 +107,7 @@ def from_pretrained(cls,
ignore_argument(kwargs, "pipeline_parallel_stages")

model = cls.HF_Model.from_pretrained(*args, **kwargs)
model = npu_lib.compile(model, dtype, False)
model = npu_lib.compile(model, qtype, False)

return model

Expand Down

0 comments on commit 646730c

Please sign in to comment.