From 88e1016f83121dea6e0912f8a2e0daad6a4f5487 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Thu, 11 Jul 2024 17:59:20 +0800 Subject: [PATCH 1/4] add --- python/llm/src/ipex_llm/transformers/npu_model.py | 3 ++- .../src/ipex_llm/transformers/npu_models/convert.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 2a3ecffcda6..2cf7d13c000 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -150,8 +150,9 @@ def from_pretrained(cls, @classmethod def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg): - from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear + from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_QuantizedMLP replace_with_QuantizedLinear(optimize_model, q_k, device=device) + replace_with_QuantizedMLP(optimize_model, q_k, device=device) @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 6d3c95ee0bf..4dde08d83dc 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -61,6 +61,15 @@ def replace_with_QuantizedLinear(layer, qtype, device): return QuantizedLinear(qweights, scale, layer.bias) +@module_optimization +def replace_with_QuantizedMLP(layer, qtype, device): + from transformers.models.llama.modeling_llama import LlamaMLP + from ipex_llm.transformers.npu_models.fusedmlp import FusedLlamaQuantizedMLP + if isinstance(layer, LlamaMLP): + weights = [(layer.gate_proj.weight, layer.gate_proj.scale), (layer.up_proj.weight, layer.up_proj.scale), (layer.down_proj.weight, layer.down_proj.scale)] + return FusedLlamaQuantizedMLP(weights) # TODO: handle bias + + def convert_forward(m, target_m, new_forward): if m.__class__ == target_m: bound_method = new_forward.__get__(m, m.__class__) @@ -74,7 +83,7 @@ def optimize_llm(model: torch.nn.Module): from ipex_llm.transformers.npu_models.llama import merge_qkv from ipex_llm.transformers.npu_models.llama import merge_mlp model.apply(merge_qkv) - model.apply(merge_mlp) + # model.apply(merge_mlp) from ipex_llm.transformers.npu_models.llama import llama_model_forward from ipex_llm.transformers.npu_models.llama import llama_attention_forward @@ -84,7 +93,7 @@ def optimize_llm(model: torch.nn.Module): from transformers.models.llama.modeling_llama import LlamaMLP convert_forward(model, LlamaModel, llama_model_forward) convert_forward(model, LlamaAttention, llama_attention_forward) - convert_forward(model, LlamaMLP, llama_mlp_forward) + # convert_forward(model, LlamaMLP, llama_mlp_forward) elif model.config.model_type == "mistral": from ipex_llm.transformers.npu_models.mistral import merge_qkv From 722df738ebc1b2eec016a566b7da711d960909f3 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Thu, 11 Jul 2024 17:59:32 +0800 Subject: [PATCH 2/4] initial --- .../transformers/npu_models/fusedmlp.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py diff --git a/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py b/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py new file mode 100644 index 00000000000..d0fbb709e87 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py @@ -0,0 +1,163 @@ +from intel_npu_acceleration_library.backend.factory import NNFactory +from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function, adapt_output_tensor, _model_cache +from typing import Optional, Sequence, List, Union, Any +from functools import partial +from collections import deque +import numpy as np +import torch +import uuid + + +class QuantizedMLP(NNFactory): + """Quantized Linear class, computing a matrix matrix multiplication with weights prefetching.""" + + def __init__( + self, + input_shape: Sequence[int], + intermediate_size: int, + activation: str = "swiglu", + bias: Optional[bool] = False, + dtype: np.dtype = np.int8, + profile: bool = False, + device: str = "NPU", + **additional_args + ): + """Initialize the Linear class. + + Args: + input_shape (Sequence[int]): input shape channels + intermediate_size (int): intermediate_size + activation (str): activation function to use + bias (Optional[bool], optional): Enable/Disable bias. Defaults to False. + profile (bool): Enable/Disable profiling. Defaults to False. + device (str): Target device, default to "NPU". + additional_args: additional arguments + """ + super().__init__(profile, device) + self.intermediate_size = intermediate_size + self.batch, self.hidden_size = input_shape + input = self.parameter((self.batch, self.hidden_size)) + + mm1 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) + + if activation == "swiglu": + mm2 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + elif activation == "clamp": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("min"), additional_args.get("max")) + elif activation == "elu": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("alpha", 1.0)) + elif activation == "grn": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("grn_bias")) + else: + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1) + + _ = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=bias, wt_dtype=dtype) + self.compile() + + +class FusedLlamaQuantizedMLP(torch.nn.Module): + """LLAMA MLP operation NPU backend.""" + + def __init__( + self, + parameters: List[torch.Tensor], + ): + """Initialize LLAMA MLP operation. + + Args: + parameters (List[torch.Tensor]): model weights + """ + super().__init__() + self.op_parameters = parameters + self.op_id = str(uuid.uuid4()) + np_dtype = np.float16 + if isinstance(parameters[0], tuple): # from QuantizedLinear + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + intermediate_size, _ = parameters[0][0].shape + else: + intermediate_size, _ = parameters[0].shape + self.backend_cls = partial(QuantizedMLP, intermediate_size=intermediate_size, dtype=np_dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Torch module forward method. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: result + """ + original_shape = x.shape + if len(x.shape) > 2: + x = x.view([-1, x.shape[-1]]) + output = run_factory(x, self.op_parameters, self.backend_cls, self.op_id) + return output.view(original_shape) + + +# TODO: separate it into a single file +@torch.no_grad() +def run_factory( + x: Union[torch.Tensor, List[torch.Tensor]], + weights: List[torch.Tensor], + backend_cls: Any, + op_id: Optional[str] = None, +) -> torch.Tensor: + """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation. + + Args: + x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16 + weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8 + backend_cls (Any): Backend class to run + op_id (Optional[str], optional): Operation ID. Defaults to None. + + Returns: + torch.Tensor: result + """ + global _model_cache + + # Use or not op_id depending on the class used + op_kwargs = {"op_id": op_id} if op_id else {} + + if not isinstance(x, (list, tuple)): + x = [x] + + # Reshape input + input_dtype = x[0].dtype + x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x] + op_args = [] + op_args_flatten = [] + for w in weights: + if isinstance(w, tuple): # from QuantizedLinear + op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) + op_args_flatten.append(op_args[-1][0]) + op_args_flatten.append(op_args[-1][1]) + else: + op_args.append(set_contiguous(w).numpy()) + op_args_flatten.append(op_args[-1]) + + shape_dtype_signature = "_".join( + ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] + ) + key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" + models = _model_cache.get(key, None) + + input_shapes = [elem.shape for elem in x_np] + if models is None: + _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(4)]) + elif len(models) < 1: + _model_cache[key].append(backend_cls(*input_shapes)) + else: + _model_cache[key].rotate(1) + + # Get the model + model = _model_cache[key][0] + + with record_function(f"npu_factory_mul_{key}"): + ret = model.run(*x_np, *op_args, **op_kwargs) + + return adapt_output_tensor(ret, ret.shape, input_dtype) From 51f7326823ebd5e49a7d3d2b607743b2851a3aa4 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Fri, 12 Jul 2024 17:52:04 +0800 Subject: [PATCH 3/4] refactor and meet review --- .../src/ipex_llm/transformers/npu_model.py | 4 +- .../transformers/npu_models/convert.py | 8 +- .../transformers/npu_models/fusedmlp.py | 163 ------------------ .../transformers/npu_models/lowbitmlp.py | 123 +++++++++++++ .../transformers/npu_models/runtime.py | 91 ++++++++++ 5 files changed, 219 insertions(+), 170 deletions(-) delete mode 100644 python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/runtime.py diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 2cf7d13c000..3c42c1436e3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -150,9 +150,9 @@ def from_pretrained(cls, @classmethod def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg): - from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_QuantizedMLP + from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_LowBitMLP replace_with_QuantizedLinear(optimize_model, q_k, device=device) - replace_with_QuantizedMLP(optimize_model, q_k, device=device) + replace_with_LowBitMLP(optimize_model, q_k, device=device) @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 4dde08d83dc..29382af176a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -62,12 +62,12 @@ def replace_with_QuantizedLinear(layer, qtype, device): @module_optimization -def replace_with_QuantizedMLP(layer, qtype, device): +def replace_with_LowBitMLP(layer, qtype, device): from transformers.models.llama.modeling_llama import LlamaMLP - from ipex_llm.transformers.npu_models.fusedmlp import FusedLlamaQuantizedMLP + from ipex_llm.transformers.npu_models.lowbitmlp import FusedLlamaLowBitMLP if isinstance(layer, LlamaMLP): weights = [(layer.gate_proj.weight, layer.gate_proj.scale), (layer.up_proj.weight, layer.up_proj.scale), (layer.down_proj.weight, layer.down_proj.scale)] - return FusedLlamaQuantizedMLP(weights) # TODO: handle bias + return FusedLlamaLowBitMLP(weights) # TODO: handle bias def convert_forward(m, target_m, new_forward): @@ -83,7 +83,6 @@ def optimize_llm(model: torch.nn.Module): from ipex_llm.transformers.npu_models.llama import merge_qkv from ipex_llm.transformers.npu_models.llama import merge_mlp model.apply(merge_qkv) - # model.apply(merge_mlp) from ipex_llm.transformers.npu_models.llama import llama_model_forward from ipex_llm.transformers.npu_models.llama import llama_attention_forward @@ -93,7 +92,6 @@ def optimize_llm(model: torch.nn.Module): from transformers.models.llama.modeling_llama import LlamaMLP convert_forward(model, LlamaModel, llama_model_forward) convert_forward(model, LlamaAttention, llama_attention_forward) - # convert_forward(model, LlamaMLP, llama_mlp_forward) elif model.config.model_type == "mistral": from ipex_llm.transformers.npu_models.mistral import merge_qkv diff --git a/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py b/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py deleted file mode 100644 index d0fbb709e87..00000000000 --- a/python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py +++ /dev/null @@ -1,163 +0,0 @@ -from intel_npu_acceleration_library.backend.factory import NNFactory -from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function, adapt_output_tensor, _model_cache -from typing import Optional, Sequence, List, Union, Any -from functools import partial -from collections import deque -import numpy as np -import torch -import uuid - - -class QuantizedMLP(NNFactory): - """Quantized Linear class, computing a matrix matrix multiplication with weights prefetching.""" - - def __init__( - self, - input_shape: Sequence[int], - intermediate_size: int, - activation: str = "swiglu", - bias: Optional[bool] = False, - dtype: np.dtype = np.int8, - profile: bool = False, - device: str = "NPU", - **additional_args - ): - """Initialize the Linear class. - - Args: - input_shape (Sequence[int]): input shape channels - intermediate_size (int): intermediate_size - activation (str): activation function to use - bias (Optional[bool], optional): Enable/Disable bias. Defaults to False. - profile (bool): Enable/Disable profiling. Defaults to False. - device (str): Target device, default to "NPU". - additional_args: additional arguments - """ - super().__init__(profile, device) - self.intermediate_size = intermediate_size - self.batch, self.hidden_size = input_shape - input = self.parameter((self.batch, self.hidden_size)) - - mm1 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) - - if activation == "swiglu": - mm2 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - elif activation == "clamp": - atc_fn = getattr(self, activation) - mm1 = atc_fn(mm1, additional_args.get("min"), additional_args.get("max")) - elif activation == "elu": - atc_fn = getattr(self, activation) - mm1 = atc_fn(mm1, additional_args.get("alpha", 1.0)) - elif activation == "grn": - atc_fn = getattr(self, activation) - mm1 = atc_fn(mm1, additional_args.get("grn_bias")) - else: - atc_fn = getattr(self, activation) - mm1 = atc_fn(mm1) - - _ = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=bias, wt_dtype=dtype) - self.compile() - - -class FusedLlamaQuantizedMLP(torch.nn.Module): - """LLAMA MLP operation NPU backend.""" - - def __init__( - self, - parameters: List[torch.Tensor], - ): - """Initialize LLAMA MLP operation. - - Args: - parameters (List[torch.Tensor]): model weights - """ - super().__init__() - self.op_parameters = parameters - self.op_id = str(uuid.uuid4()) - np_dtype = np.float16 - if isinstance(parameters[0], tuple): # from QuantizedLinear - np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 - intermediate_size, _ = parameters[0][0].shape - else: - intermediate_size, _ = parameters[0].shape - self.backend_cls = partial(QuantizedMLP, intermediate_size=intermediate_size, dtype=np_dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Torch module forward method. - - Args: - x (torch.Tensor): Input tensor - - Returns: - torch.Tensor: result - """ - original_shape = x.shape - if len(x.shape) > 2: - x = x.view([-1, x.shape[-1]]) - output = run_factory(x, self.op_parameters, self.backend_cls, self.op_id) - return output.view(original_shape) - - -# TODO: separate it into a single file -@torch.no_grad() -def run_factory( - x: Union[torch.Tensor, List[torch.Tensor]], - weights: List[torch.Tensor], - backend_cls: Any, - op_id: Optional[str] = None, -) -> torch.Tensor: - """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation. - - Args: - x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16 - weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8 - backend_cls (Any): Backend class to run - op_id (Optional[str], optional): Operation ID. Defaults to None. - - Returns: - torch.Tensor: result - """ - global _model_cache - - # Use or not op_id depending on the class used - op_kwargs = {"op_id": op_id} if op_id else {} - - if not isinstance(x, (list, tuple)): - x = [x] - - # Reshape input - input_dtype = x[0].dtype - x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x] - op_args = [] - op_args_flatten = [] - for w in weights: - if isinstance(w, tuple): # from QuantizedLinear - op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) - op_args_flatten.append(op_args[-1][0]) - op_args_flatten.append(op_args[-1][1]) - else: - op_args.append(set_contiguous(w).numpy()) - op_args_flatten.append(op_args[-1]) - - shape_dtype_signature = "_".join( - ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] - ) - key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" - models = _model_cache.get(key, None) - - input_shapes = [elem.shape for elem in x_np] - if models is None: - _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(4)]) - elif len(models) < 1: - _model_cache[key].append(backend_cls(*input_shapes)) - else: - _model_cache[key].rotate(1) - - # Get the model - model = _model_cache[key][0] - - with record_function(f"npu_factory_mul_{key}"): - ret = model.run(*x_np, *op_args, **op_kwargs) - - return adapt_output_tensor(ret, ret.shape, input_dtype) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py b/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py new file mode 100644 index 00000000000..9ae26ac8e12 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py @@ -0,0 +1,123 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is adapted from +# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py + +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +from ipex_llm.transformers.npu_models.runtime import run_model +from intel_npu_acceleration_library.backend.factory import NNFactory +from typing import Optional, Sequence, List +from functools import partial +import numpy as np +import torch +import uuid + + +class LowBitMLP(NNFactory): + """Computing a LowBit MLP with weights prefetching.""" + + def __init__( + self, + input_shape: Sequence[int], + intermediate_size: int, + activation: str = "swiglu", + bias: Optional[bool] = False, + dtype: np.dtype = np.int8, + profile: bool = False, + device: str = "NPU", + **additional_args + ): + """Initialize the LowBitMLP class. + + Args: + input_shape (Sequence[int]): input shape channels + intermediate_size (int): intermediate_size + activation (str): activation function to use + bias (Optional[bool]): Enable/Disable bias. Defaults to False. + dtype (np.dtype): parameter type np.int8, np.uint8 and np.float16 supported. Defaults to np.int8. Unit8 represents packed i4 dtypes. + profile (bool): Enable/Disable profiling. Defaults to False. + device (str): Target device, default to "NPU". + additional_args: additional arguments + """ + super().__init__(profile, device) + self.intermediate_size = intermediate_size + self.batch, self.hidden_size = input_shape + input = self.parameter((self.batch, self.hidden_size)) + + mm1 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) + + if activation == "swiglu": + mm2 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + elif activation == "clamp": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("min"), additional_args.get("max")) + elif activation == "elu": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("alpha", 1.0)) + elif activation == "grn": + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1, additional_args.get("grn_bias")) + else: + atc_fn = getattr(self, activation) + mm1 = atc_fn(mm1) + + _ = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=bias, wt_dtype=dtype) + self.compile() + + +class FusedLlamaLowBitMLP(torch.nn.Module): + """LLAMA LowBit MLP operation NPU backend.""" + + def __init__( + self, + parameters: List[torch.Tensor], + ): + """Initialize LLAMA LowBit MLP operation. + + Args: + parameters (List[torch.Tensor]): model weights + """ + super().__init__() + self.op_parameters = parameters + self.op_id = str(uuid.uuid4()) + if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + intermediate_size, _ = parameters[0][0].shape + else: # FP16 Linear + np_dtype = np.float16 + intermediate_size, _ = parameters[0].shape + self.backend_cls = partial(LowBitMLP, intermediate_size=intermediate_size, dtype=np_dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Torch module forward method. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: result + """ + # Handle 3D input shape (similarly done in run_matmul) + original_shape = x.shape + if len(x.shape) > 2: + x = x.view([-1, x.shape[-1]]) + output = run_model(x, self.op_parameters, self.backend_cls, self.op_id) + return output.view(original_shape) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/runtime.py b/python/llm/src/ipex_llm/transformers/npu_models/runtime.py new file mode 100644 index 00000000000..a1175df76ac --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/runtime.py @@ -0,0 +1,91 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is adapted from +# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py + +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function, adapt_output_tensor, _model_cache +from typing import Optional, List, Union, Any +from collections import deque +import torch + +NUM_REPLICAS = 4 # TODO: make it an environment variable? + +@torch.no_grad() +def run_model( + x: Union[torch.Tensor, List[torch.Tensor]], + weights: List[torch.Tensor], + backend_cls: Any, + op_id: Optional[str] = None, +) -> torch.Tensor: + """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation. + + Args: + x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16 + weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8 + backend_cls (Any): Backend class to run + op_id (Optional[str], optional): Operation ID. Defaults to None. + + Returns: + torch.Tensor: result + """ + global _model_cache + + # Use or not op_id depending on the class used + op_kwargs = {"op_id": op_id} if op_id else {} + + if not isinstance(x, (list, tuple)): + x = [x] + + # Reshape input + input_dtype = x[0].dtype + x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x] + op_args = [] + op_args_flatten = [] + for w in weights: + if isinstance(w, tuple): # from QuantizedLinear + op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) + op_args_flatten.append(op_args[-1][0]) + op_args_flatten.append(op_args[-1][1]) + else: + op_args.append(set_contiguous(w).numpy()) + op_args_flatten.append(op_args[-1]) + + shape_dtype_signature = "_".join( + ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] + ) + key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" + models = _model_cache.get(key, None) + + input_shapes = [elem.shape for elem in x_np] + if models is None: + _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(NUM_REPLICAS)]) + elif len(models) < 1: + _model_cache[key].append(backend_cls(*input_shapes)) + else: + _model_cache[key].rotate(1) + + # Get the model + model = _model_cache[key][0] + + with record_function(f"npu_factory_mul_{key}"): + ret = model.run(*x_np, *op_args, **op_kwargs) + + return adapt_output_tensor(ret, ret.shape, input_dtype) From 8152e0f81189c35c65a3ac30546a8b9f2546ce66 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Fri, 12 Jul 2024 18:40:15 +0800 Subject: [PATCH 4/4] fix style --- .../llm/src/ipex_llm/transformers/npu_model.py | 3 ++- .../ipex_llm/transformers/npu_models/convert.py | 4 +++- .../transformers/npu_models/lowbitmlp.py | 15 +++++++++------ .../ipex_llm/transformers/npu_models/runtime.py | 16 ++++++++++------ 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 3c42c1436e3..0eec9919909 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -150,7 +150,8 @@ def from_pretrained(cls, @classmethod def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg): - from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_LowBitMLP + from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear + from ipex_llm.transformers.npu_models.convert import replace_with_LowBitMLP replace_with_QuantizedLinear(optimize_model, q_k, device=device) replace_with_LowBitMLP(optimize_model, q_k, device=device) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 29382af176a..6871202c73b 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -66,7 +66,9 @@ def replace_with_LowBitMLP(layer, qtype, device): from transformers.models.llama.modeling_llama import LlamaMLP from ipex_llm.transformers.npu_models.lowbitmlp import FusedLlamaLowBitMLP if isinstance(layer, LlamaMLP): - weights = [(layer.gate_proj.weight, layer.gate_proj.scale), (layer.up_proj.weight, layer.up_proj.scale), (layer.down_proj.weight, layer.down_proj.scale)] + weights = [(layer.gate_proj.weight, layer.gate_proj.scale), + (layer.up_proj.weight, layer.up_proj.scale), + (layer.down_proj.weight, layer.down_proj.scale)] return FusedLlamaLowBitMLP(weights) # TODO: handle bias diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py b/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py index 9ae26ac8e12..9bf2c0fc249 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py @@ -47,11 +47,12 @@ def __init__( """Initialize the LowBitMLP class. Args: - input_shape (Sequence[int]): input shape channels - intermediate_size (int): intermediate_size - activation (str): activation function to use + input_shape (Sequence[int]): input shape channels. + intermediate_size (int): intermediate_size of the MLP. + activation (str): activation function to use. bias (Optional[bool]): Enable/Disable bias. Defaults to False. - dtype (np.dtype): parameter type np.int8, np.uint8 and np.float16 supported. Defaults to np.int8. Unit8 represents packed i4 dtypes. + dtype (np.dtype): parameter type np.int8, np.uint8 and np.float16 supported. + Defaults to np.int8. Unit8 represents packed i4 dtypes. profile (bool): Enable/Disable profiling. Defaults to False. device (str): Target device, default to "NPU". additional_args: additional arguments @@ -61,10 +62,12 @@ def __init__( self.batch, self.hidden_size = input_shape input = self.parameter((self.batch, self.hidden_size)) - mm1 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) + mm1 = self.linear(input, self.intermediate_size, self.hidden_size, + bias=bias, wt_dtype=dtype) if activation == "swiglu": - mm2 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype) # type: ignore[attr-defined] + mm2 = self.linear(input, self.intermediate_size, self.hidden_size, + bias=bias, wt_dtype=dtype) # type: ignore[attr-defined] mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] elif activation == "clamp": atc_fn = getattr(self, activation) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/runtime.py b/python/llm/src/ipex_llm/transformers/npu_models/runtime.py index a1175df76ac..3f0177eb13f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/runtime.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/runtime.py @@ -21,13 +21,15 @@ # SPDX-License-Identifier: Apache 2.0 # -from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function, adapt_output_tensor, _model_cache +from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function +from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache from typing import Optional, List, Union, Any from collections import deque import torch NUM_REPLICAS = 4 # TODO: make it an environment variable? + @torch.no_grad() def run_model( x: Union[torch.Tensor, List[torch.Tensor]], @@ -35,12 +37,14 @@ def run_model( backend_cls: Any, op_id: Optional[str] = None, ) -> torch.Tensor: - """Run a factory operation. Depending on the datatype of the weights it runs a float or quantized operation. + """Run a factory operation. + Depending on the datatype of the weights it runs a float or quantized operation. Args: - x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). Its dtype must be torch.float16 - weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8 - backend_cls (Any): Backend class to run + x (Union[torch.Tensor, List[torch.Tensor]]): Activation tensor(s). + Its dtype must be torch.float16. + weights (torch.Tensor): Weights tensor. Its dtype can be torch.float16 or torch.int8. + backend_cls (Any): Backend class to run. op_id (Optional[str], optional): Operation ID. Defaults to None. Returns: @@ -60,7 +64,7 @@ def run_model( op_args = [] op_args_flatten = [] for w in weights: - if isinstance(w, tuple): # from QuantizedLinear + if isinstance(w, tuple): # from QuantizedLinear op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) op_args_flatten.append(op_args[-1][0]) op_args_flatten.append(op_args[-1][1])