[NPU] initial support of asym_int4_rtn #12484

Merged (21 commits) on Dec 5, 2024
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/ggml/quantize.py
@@ -52,6 +52,7 @@
"fp6_k": 30,
"sym_int4_rtn": 31,
"sym_int8_rtn": 32,
"asym_int4_rtn": 33,
}

# mixed precision from llama.cpp
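The new `asym_int4_rtn` entry registers qtype id 33 for asymmetric 4-bit round-to-nearest quantization: each group keeps a scale and a zero point instead of a scale alone. Below is a minimal sketch of the per-group math, assuming an unsigned 0..15 code range and a float zero point; the real quantization runs in the native `ggml_quantize_tensor_rtn` kernel, so the group size and packing here are illustrative only.

```python
import torch

def asym_int4_rtn_reference(w: torch.Tensor, group_size: int = 64):
    """Reference (non-native) asymmetric round-to-nearest int4 quantization."""
    w = w.reshape(-1, group_size)
    w_min = w.min(dim=1, keepdim=True).values
    w_max = w.max(dim=1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / 15.0   # 4-bit unsigned codes: 0..15
    zero = w_min                                      # dequant: w_hat = q * scale + zero
    q = torch.clamp(torch.round((w - zero) / scale), 0, 15).to(torch.uint8)
    return q, scale.squeeze(1), zero.squeeze(1)

q, scale, zero = asym_int4_rtn_reference(torch.randn(256))
print(q.shape, scale.shape, zero.shape)  # (4, 64), (4,), (4,)
```

The `q * scale + zero` convention matches the `+ zero` term added to `DequantizedLinear` later in this diff.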
16 changes: 11 additions & 5 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -84,8 +84,10 @@
FP6_K = ggml_tensor_qtype["fp6_k"]
SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
RTN_DTYPE = {
SYM_INT4_RTN: torch.uint8,
ASYM_INT4_RTN: torch.uint8,
SYM_INT8_RTN: torch.int8,
}

@@ -223,12 +225,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
device=device)
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
if qtype == ASYM_INT4_RTN:
scale = torch.empty((n // k) * 2, dtype=torch.float32,
device=device)
else:
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
elif qtype == NF4:
# Deepspeed zero3 requires unified dtype,
# thus here uses bfloat16 consistent to other layers
@@ -244,7 +250,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
if imatrix is None:
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
@@ -269,7 +275,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
return dst_tensor, scale.type(torch.float16)
else:
return dst_tensor
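The key change in `ggml_convert_qtype` is buffer sizing: for `asym_int4_rtn` the float buffer handed to the native RTN kernel is allocated with `(n // k) * 2` entries, presumably so the kernel can write per-group zero points alongside the scales in the same contiguous tensor (the split back into two halves happens later in `convert.py`). A rough sketch of the sizing logic, with illustrative names:

```python
def rtn_output_sizes(n: int, k: int, QK: int, block_size_in_bytes: int, asym: bool):
    dst_size = (n // QK) * block_size_in_bytes      # bytes of packed quantized weight
    n_groups = n // k                               # one scale per quantization group
    scale_len = n_groups * 2 if asym else n_groups  # asym also stores a zero per group
    return dst_size, scale_len

# e.g. a 4096x4096 weight with per-row groups (k = 4096) and 64-element blocks
print(rtn_output_sizes(n=4096 * 4096, k=4096, QK=64, block_size_in_bytes=32, asym=True))
```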
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -103,6 +103,7 @@ def from_pretrained(cls, *args, **kwargs):
qtype_map = {
"sym_int4": "sym_int4_rtn",
"sym_int8": "sym_int8_rtn",
"asym_int4": "asym_int4_rtn",
}

invalidInputError(
@@ -154,7 +155,7 @@ def from_pretrained(cls, *args, **kwargs):
f"but got {quantization_group_size}"
)
)
_args = copy.deepcopy(args)

_kwargs = copy.deepcopy(kwargs)

try:
@@ -270,6 +271,7 @@ def optimize_npu_model(cls, *args, **kwargs):
with torch.no_grad():
model.config.update({"mixed_precision": mixed_precision})
model.config.update({"group_size": quantization_group_size})
model.config.update({"asym": qtype == "asym_int4_rtn"})
optimize_llm_pre(model, qtype, mixed_precision,
quantization_group_size=quantization_group_size)
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
@@ -416,9 +418,9 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
)

invalidInputError(
qtype in ["sym_int8_rtn", "sym_int4_rtn"],
qtype in ["sym_int8_rtn", "sym_int4_rtn", "asym_int4_rtn"],
f"Unknown bigdl_transformers_low_bit value: {qtype},"
f" expected: sym_int8_rtn, sym_int4_rtn. "
f" expected: sym_int8_rtn, sym_int4_rtn, asym_int4_rtn. "
)

if enable_cpp_backend:
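With the `qtype_map` addition above, a user-facing `load_in_low_bit="asym_int4"` request is translated to `asym_int4_rtn` on the NPU path. A hedged usage sketch modeled on the existing NPU `from_pretrained` flow; the checkpoint id and extra kwargs are illustrative and may differ from the shipped examples:

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# "asym_int4" is remapped to "asym_int4_rtn" by qtype_map in from_pretrained
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",     # illustrative checkpoint
    load_in_low_bit="asym_int4",
    optimize_model=True,
    trust_remote_code=True,
)
```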
28 changes: 22 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -88,19 +88,26 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
if qtype == "sym_int4_rtn":
if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
# workaround for qwen2-7B & int4
if (layer.in_features == 3584 and layer.out_features == 152064) or \
(layer.in_features == 18944 and layer.out_features == 3584):
if (layer.in_features == 3584 and layer.out_features == 152064):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
if qtype == "sym_int4_rtn":
if (layer.in_features == 18944 and layer.out_features == 3584):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return QuantizedLinear(qweights, scale, layer.bias,
group_size=group_size)
zero = None
# split the combined buffer into scale & zero
if qtype == "asym_int4_rtn":
scale, zero = torch.split(scale, scale.shape[0] // 2)
return QuantizedLinear(qweights, scale, zero, layer.bias,
group_size=group_size, qtype=qtype)


@module_optimization
@@ -111,12 +118,21 @@ def replace_with_DequantizedLinear(layer, qtype, device, modules_to_not_convert,
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
if qtype in ["sym_int4_rtn", "asym_int4_rtn"]:
# workaround for qwen2-7B & int4
if (layer.in_features == 3584 and layer.out_features == 152064):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return DequantizedLinear(qweights, scale, layer.bias)
zero = None
# split the combined buffer into scale & zero
if qtype == "asym_int4_rtn":
scale, zero = torch.split(scale, scale.shape[0] // 2)
return DequantizedLinear(qweights, scale, zero, layer.bias, qtype)


@module_optimization
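Both replacement paths now split the single float tensor returned by `ggml_convert_qtype` in half for `asym_int4_rtn`, treating the first half as per-group scales and the second half as zero points before handing them to `QuantizedLinear` / `DequantizedLinear`. A tiny sketch of that split; the half-and-half layout is inferred from the `torch.split` calls above:

```python
import torch

n_groups = 128
combined = torch.randn(2 * n_groups)   # stand-in for the buffer returned for asym_int4_rtn

scale, zero = torch.split(combined, combined.shape[0] // 2)
assert scale.shape == zero.shape == (n_groups,)
```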
17 changes: 11 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -128,7 +128,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
from ipex_llm.transformers.npu_models.common import split_linears
if quantization_group_size == 0:
n_splits_linear = 1
if qtype == "sym_int8_rtn":
if qtype in ["sym_int8_rtn", "asym_int4_rtn"]:
# do not split mlp down_proj for Qwen2-7B with sym_int8 or asym_int4
n_splits_down_proj = 1
else:
@@ -154,18 +154,21 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
# workaround for MiniCPM-2B
new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
bias=model.lm_head_0.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head_0
model.lm_head_0 = new_lm_head_0
new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
bias=model.lm_head_1.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head_1
model.lm_head_1 = new_lm_head_1
else:
new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
bias=model.lm_head.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=(qtype == "asym_int4_rtn"))
del model.lm_head
model.lm_head = new_lm_head

@@ -176,11 +179,13 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
# Do not split lm_head and use sym_int8 instead when mixed_precision is True
if quantization_group_size == 0:
# Do not split lm_head and use sym_int8 instead when mixed_precision is True
is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
is_split = (not mixed_precision) and qtype in ["sym_int4_rtn", "asym_int4_rtn"]
split_num = 14 if is_split else 1
new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
bias=model.lm_head.bias, use_split=True,
group_size=quantization_group_size)
group_size=quantization_group_size,
asym=((qtype == "asym_int4_rtn") and
(not mixed_precision)))
del model.lm_head
model.lm_head = new_lm_head

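The lm_head slicing decision now also covers `asym_int4_rtn`: when `group_size == 0` and mixed precision is off, int4 heads (sym or asym) are sliced into 14 splits, while mixed precision keeps the head whole and uses sym_int8 instead. A compact restatement of that decision, with illustrative names:

```python
def lm_head_split_plan(qtype: str, mixed_precision: bool) -> int:
    # sliced only for the int4 RTN qtypes when mixed precision is off
    is_split = (not mixed_precision) and qtype in ("sym_int4_rtn", "asym_int4_rtn")
    return 14 if is_split else 1

print(lm_head_split_plan("asym_int4_rtn", mixed_precision=False))  # 14
print(lm_head_split_plan("asym_int4_rtn", mixed_precision=True))   # 1
```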
19 changes: 18 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -129,16 +129,20 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
zero: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
group_size: int = 0,
):
"""Initialize the QuantizedLinear class.

Args:
weight (torch.Tensor): Linear operation weight
scale (torch.Tensor): Quantization scale
zero (Optional[torch.Tensor], optional): Quantization zero point for asym_int4_rtn. Defaults to None.
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): Quantization type of this linear layer. Defaults to "sym_int4_rtn".

Raises:
RuntimeError: Quantized weight must be in torch.int8 format
@@ -155,14 +159,19 @@ def __init__(
)
)
self.outC, self.inC = self.weight.shape
self.zero = None
if group_size != 0:
self.scale = Parameter(scale, requires_grad=False)
self.zero = Parameter(zero, requires_grad=False)
else:
if self.weight.dtype == torch.uint8:
# Int4 we need to double the input channels because weights are compressed
self.inC *= 2
self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
if zero is not None:
self.zero = Parameter(zero * math.sqrt(self.inC), requires_grad=False)
self.bias = bias
self.qtype = qtype
self.op_id = str(uuid.uuid4())

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -195,7 +204,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
)

out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
zero_data = self.zero.data if self.zero is not None else None
out = run_matmul(x, self.weight.data, self.scale.data, zero_data, self.op_id)

if self.bias is None:
return out
@@ -209,14 +219,18 @@ def __init__(
self,
weight: torch.Tensor,
scale: torch.Tensor,
zero: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
qtype: Optional[str] = "sym_int4_rtn",
):
"""Initialize the DequantizedLinear class.
Args:
weight (torch.Tensor): Linear operation quantized weight
scale (torch.Tensor): Quantization scale
zero (Optional[torch.Tensor], optional): Quantization zero point for asym_int4_rtn. Defaults to None.
bias (Optional[torch.Tensor], optional): Linear operation optional bias.
Defaults to None.
qtype (Optional[str], optional): Quantization type of this linear layer. Defaults to "sym_int4_rtn".
Raises:
RuntimeError: Quantized weight must be in torch.int8 format
"""
@@ -240,6 +254,9 @@ def __init__(
decompressed_weight = combined_weight.view(combined_weight.size(0), -1)
dequantized_weight = decompressed_weight.to(torch.float32) * \
torch.unsqueeze(scale.to(torch.float32), dim=1)
if qtype == "asym_int4_rtn" and zero is not None:
dequantized_weight = dequantized_weight + torch.unsqueeze(zero.to(torch.float32),
dim=1)
self.weight = Parameter(dequantized_weight, requires_grad=False).contiguous()
else:
dequantized_weight = weight.to(torch.float32) * \
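For the uint8 (int4) case, `DequantizedLinear` now reconstructs weights as `q * scale + zero` per output channel when a zero point is present. A hedged sketch of that dequantization step; the exact nibble packing is not visible in this hunk, so the unpacking order below is an assumption:

```python
import torch

packed = torch.randint(0, 256, (4, 8), dtype=torch.uint8)       # (outC, inC // 2), two codes per byte
low = (packed & 0x0F).to(torch.float32)
high = (packed >> 4).to(torch.float32)
q = torch.stack((low, high), dim=-1).view(packed.size(0), -1)   # (outC, inC)

scale = torch.full((4,), 0.1)
zero = torch.full((4,), -0.8)
w_hat = q * scale.unsqueeze(1) + zero.unsqueeze(1)              # matches the added "+ zero" above
print(w_hat.shape)  # torch.Size([4, 16])
```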
35 changes: 25 additions & 10 deletions python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
@@ -36,6 +36,7 @@ def __init__(
dtype: np.dtype = np.int8,
use_split: bool = False,
group_size: int = 0,
asym: bool = False,
):
"""Initialize the LMHeadLinear class.

@@ -54,11 +55,10 @@ def __init__(
self.batch = batch

self.split_num = split_num

if use_split:
input = self.parameter((1, self.batch, self.inC))
res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
scale_factor=(group_size == 0))
scale_factor=(group_size == 0), asym=asym)
else:
input = self.parameter((self.batch, self.inC))
split_size = self.inC // split_num // 2 * 2
@@ -69,7 +69,7 @@ def __init__(
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False,
wt_dtype=dtype)
wt_dtype=dtype, asym=asym)
if i == 0:
res = linear_slice
else:
@@ -109,7 +109,7 @@ def run(


class SlicedLMHead(nn.Module):
def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
def __init__(self, weight, bias, split_num, use_split=False, group_size=0, asym=False):
super().__init__()
self.split_num = split_num
self.outC, self.inC = weight.shape
@@ -128,6 +128,7 @@ def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
self.lm_heads.append(new_linear)
self.bias = bias
self.use_split = use_split
self.asym = asym

def forward(self, hidden_states):
if hidden_states.size(0) * hidden_states.size(1) == 1:
@@ -162,19 +163,33 @@ def get_fused_lm_head(self):
np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
False, "NPU", dtype=np_dtype, use_split=self.use_split,
group_size=self.group_size)
group_size=self.group_size, asym=self.asym)
if self.use_split:
weights = []
scales = []
zeros = []
for i in range(self.split_num):
weights.append(self.lm_heads[i].weight)
scales.append(self.lm_heads[i].scale)
fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy())
if self.lm_heads[i].zero is not None:
zeros.append(self.lm_heads[i].zero)
if len(zeros):
fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy(),
torch.stack(zeros, axis=0).numpy())]
else:
fused_lm_head_weights = [(torch.stack(weights, axis=0).numpy(),
torch.stack(scales, axis=0).numpy())]
else:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy())
for i in range(self.split_num)]
if self.asym:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy(),
self.lm_heads[i].zero.data.numpy())
for i in range(self.split_num)]
else:
fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
self.lm_heads[i].scale.data.numpy())
for i in range(self.split_num)]

self.fused_lm_head.set_weights(self.lm_heads[0].op_id,
fused_lm_head_weights)
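`get_fused_lm_head` now appends a stacked zero tensor to the weight tuple when the sliced heads are asymmetric, in both the `use_split` and non-split branches. A small sketch of the two tuple layouts passed to `set_weights`; the shapes are illustrative:

```python
import numpy as np

split_num, out_per_split, in_half = 2, 8, 16
weights = [np.zeros((out_per_split, in_half), dtype=np.uint8) for _ in range(split_num)]
scales = [np.ones(out_per_split, dtype=np.float16) for _ in range(split_num)]
zeros = [np.zeros(out_per_split, dtype=np.float16) for _ in range(split_num)]

sym_fused = [(np.stack(weights, axis=0), np.stack(scales, axis=0))]
asym_fused = [(np.stack(weights, axis=0), np.stack(scales, axis=0), np.stack(zeros, axis=0))]
print(asym_fused[0][0].shape, asym_fused[0][2].shape)  # (2, 8, 16) (2, 8)
```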