diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 83ebe2448d0..78b83f88fa1 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -67,7 +67,7 @@ def from_pretrained(cls,
             warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
         kwargs['torch_dtype'] = torch.float
 
-        low_bit = kwargs.pop('load_in_low_bit', torch.float)
+        low_bit = kwargs.pop('load_in_low_bit', 'fp32')
         try:
             # for intel_npu_acceleration_library >= 1.1.0
             from intel_npu_acceleration_library.dtypes import int8, int4
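
The change replaces the fallback value for `load_in_low_bit` with the string `'fp32'` instead of the dtype object `torch.float`, presumably so the default has the same type as explicitly passed values such as `'sym_int4'` and downstream string comparisons behave uniformly. A minimal standalone sketch of the behavior (not the library code; `pop_low_bit` is a hypothetical helper used only for illustration):

```python
# Sketch of the patched fallback: the default is now the string 'fp32',
# matching the string identifiers callers pass explicitly.

def pop_low_bit(kwargs):
    # Mirrors the changed line in from_pretrained.
    return kwargs.pop('load_in_low_bit', 'fp32')

print(pop_low_bit({'load_in_low_bit': 'sym_int4'}))  # explicit value: 'sym_int4'
print(pop_low_bit({}))                               # new default:    'fp32'
```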