Commit

refactor and meet review

hkvision committed Jul 12, 2024
1 parent 722df73 commit 51f7326

Showing 5 changed files with 219 additions and 170 deletions.
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -150,9 +150,9 @@ def from_pretrained(cls,

     @classmethod
     def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg):
-        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_QuantizedMLP
+        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear, replace_with_LowBitMLP
         replace_with_QuantizedLinear(optimize_model, q_k, device=device)
-        replace_with_QuantizedMLP(optimize_model, q_k, device=device)
+        replace_with_LowBitMLP(optimize_model, q_k, device=device)

     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
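For context, load_convert appears to be invoked from the NPU AutoModel wrapper's from_pretrained (the hunk above sits under that method in npu_model.py). A minimal usage sketch follows; the model id, the load_in_low_bit keyword and the "sym_int4" qtype string are assumptions about the surrounding ipex_llm API, not something this diff shows:

    # Hypothetical sketch: loading a Llama checkpoint through the NPU wrapper, which
    # triggers load_convert() and therefore replace_with_QuantizedLinear /
    # replace_with_LowBitMLP on the model's Linear and LlamaMLP modules.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # any Llama checkpoint; id assumed for illustration
        load_in_low_bit="sym_int4",        # assumed low-bit qtype accepted by the wrapper
        trust_remote_code=True,
    )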
8 changes: 3 additions & 5 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -62,12 +62,12 @@ def replace_with_QuantizedLinear(layer, qtype, device):


 @module_optimization
-def replace_with_QuantizedMLP(layer, qtype, device):
+def replace_with_LowBitMLP(layer, qtype, device):
     from transformers.models.llama.modeling_llama import LlamaMLP
-    from ipex_llm.transformers.npu_models.fusedmlp import FusedLlamaQuantizedMLP
+    from ipex_llm.transformers.npu_models.lowbitmlp import FusedLlamaLowBitMLP
     if isinstance(layer, LlamaMLP):
         weights = [(layer.gate_proj.weight, layer.gate_proj.scale), (layer.up_proj.weight, layer.up_proj.scale), (layer.down_proj.weight, layer.down_proj.scale)]
-        return FusedLlamaQuantizedMLP(weights)  # TODO: handle bias
+        return FusedLlamaLowBitMLP(weights)  # TODO: handle bias


 def convert_forward(m, target_m, new_forward):
@@ -83,7 +83,6 @@ def optimize_llm(model: torch.nn.Module):
         from ipex_llm.transformers.npu_models.llama import merge_qkv
         from ipex_llm.transformers.npu_models.llama import merge_mlp
         model.apply(merge_qkv)
-        # model.apply(merge_mlp)

         from ipex_llm.transformers.npu_models.llama import llama_model_forward
         from ipex_llm.transformers.npu_models.llama import llama_attention_forward
@@ -93,7 +92,6 @@ def optimize_llm(model: torch.nn.Module):
         from transformers.models.llama.modeling_llama import LlamaMLP
         convert_forward(model, LlamaModel, llama_model_forward)
         convert_forward(model, LlamaAttention, llama_attention_forward)
-        # convert_forward(model, LlamaMLP, llama_mlp_forward)

     elif model.config.model_type == "mistral":
         from ipex_llm.transformers.npu_models.mistral import merge_qkv
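The renamed replace_with_LowBitMLP follows the usual recursive module-swapping pattern. A rough, self-contained sketch of that pattern is below; it is not the library's actual @module_optimization decorator, and replace_modules, predicate and build_replacement are illustrative names:

    # Recursively swap every child module matching `predicate` for the module returned by
    # `build_replacement`, mirroring what replace_with_LowBitMLP does for LlamaMLP layers.
    import torch

    def replace_modules(model: torch.nn.Module, predicate, build_replacement) -> None:
        for name, child in model.named_children():
            if predicate(child):
                # e.g. build_replacement(child) -> FusedLlamaLowBitMLP(weights)
                setattr(model, name, build_replacement(child))
            else:
                replace_modules(child, predicate, build_replacement)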
163 changes: 0 additions & 163 deletions python/llm/src/ipex_llm/transformers/npu_models/fusedmlp.py

This file was deleted.

123 changes: 123 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/lowbitmlp.py
@@ -0,0 +1,123 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is adapted from
# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py

#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from ipex_llm.transformers.npu_models.runtime import run_model
from intel_npu_acceleration_library.backend.factory import NNFactory
from typing import Optional, Sequence, List
from functools import partial
import numpy as np
import torch
import uuid


class LowBitMLP(NNFactory):
"""Computing a LowBit MLP with weights prefetching."""

def __init__(
self,
input_shape: Sequence[int],
intermediate_size: int,
activation: str = "swiglu",
bias: Optional[bool] = False,
dtype: np.dtype = np.int8,
profile: bool = False,
device: str = "NPU",
**additional_args
):
"""Initialize the LowBitMLP class.
Args:
            input_shape (Sequence[int]): input shape as (batch, hidden_size)
intermediate_size (int): intermediate_size
activation (str): activation function to use
bias (Optional[bool]): Enable/Disable bias. Defaults to False.
            dtype (np.dtype): parameter type; np.int8, np.uint8 and np.float16 are supported. Defaults to np.int8. Uint8 represents packed i4 dtypes.
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
additional_args: additional arguments
"""
super().__init__(profile, device)
self.intermediate_size = intermediate_size
self.batch, self.hidden_size = input_shape
input = self.parameter((self.batch, self.hidden_size))

        # first projection: hidden_size -> intermediate_size (gate_proj when activation is swiglu)
        mm1 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype)

if activation == "swiglu":
            # SwiGLU: a second (up) projection, combined as swish(gate) * up
            mm2 = self.linear(input, self.intermediate_size, self.hidden_size, bias=bias, wt_dtype=dtype)  # type: ignore[attr-defined]
            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
elif activation == "clamp":
atc_fn = getattr(self, activation)
mm1 = atc_fn(mm1, additional_args.get("min"), additional_args.get("max"))
elif activation == "elu":
atc_fn = getattr(self, activation)
mm1 = atc_fn(mm1, additional_args.get("alpha", 1.0))
elif activation == "grn":
atc_fn = getattr(self, activation)
mm1 = atc_fn(mm1, additional_args.get("grn_bias"))
else:
atc_fn = getattr(self, activation)
mm1 = atc_fn(mm1)

        # down projection: intermediate_size -> hidden_size
        _ = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=bias, wt_dtype=dtype)
self.compile()


class FusedLlamaLowBitMLP(torch.nn.Module):
"""LLAMA LowBit MLP operation NPU backend."""

def __init__(
self,
parameters: List[torch.Tensor],
):
"""Initialize LLAMA LowBit MLP operation.
Args:
parameters (List[torch.Tensor]): model weights
"""
super().__init__()
self.op_parameters = parameters
self.op_id = str(uuid.uuid4())
if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
intermediate_size, _ = parameters[0][0].shape
else: # FP16 Linear
np_dtype = np.float16
intermediate_size, _ = parameters[0].shape
self.backend_cls = partial(LowBitMLP, intermediate_size=intermediate_size, dtype=np_dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Torch module forward method.
Args:
x (torch.Tensor): Input tensor
Returns:
torch.Tensor: result
"""
# Handle 3D input shape (similarly done in run_matmul)
original_shape = x.shape
if len(x.shape) > 2:
x = x.view([-1, x.shape[-1]])
output = run_model(x, self.op_parameters, self.backend_cls, self.op_id)
return output.view(original_shape)
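
A minimal, hypothetical smoke test for the new FusedLlamaLowBitMLP follows. It assumes an NPU and the intel_npu_acceleration_library runtime are available; the tensor shapes, the fake_qweight helper and the per-channel scale layout are illustrative assumptions, not part of the committed code:

    import torch
    from ipex_llm.transformers.npu_models.lowbitmlp import FusedLlamaLowBitMLP

    hidden_size, intermediate_size = 4096, 11008   # Llama-2-7B sizes, assumed for the example

    def fake_qweight(out_features, in_features):
        # (weight, scale) tuple standing in for a QuantizedLinear's parameters
        weight = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
        scale = torch.rand(out_features, dtype=torch.float16)
        return weight, scale

    mlp = FusedLlamaLowBitMLP([
        fake_qweight(intermediate_size, hidden_size),   # gate_proj
        fake_qweight(intermediate_size, hidden_size),   # up_proj
        fake_qweight(hidden_size, intermediate_size),   # down_proj
    ])
    x = torch.randn(1, 16, hidden_size, dtype=torch.float16)
    out = mlp(x)                 # 3D input is flattened internally and reshaped back
    print(out.shape)             # expected: torch.Size([1, 16, 4096])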