Commit 29d4ab7

Update int4pack related for gguf

yanbing-j committed Dec 9, 2024
1 parent fff956c commit 29d4ab7
Showing 1 changed file with 59 additions and 27 deletions.

86 changes: 59 additions & 27 deletions torchchat/utils/gguf_loader.py
@@ -24,6 +24,9 @@
     pack_scales_and_zeros,
 )
 
+from torchao.dtypes.utils import is_device
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
             input.dtype
         )  # cast back to input.dtype
     else:
-        c = torch.ops.aten._weight_int4pack_mm(
-            input,
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros,
-        )
+        if TORCH_VERSION_AT_LEAST_2_6:
+            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                input,
+                weight_int4pack,
+                groupsize,
+                scales_and_zeros,
+            )
+        else:
+            c = torch.ops.aten._weight_int4pack_mm(
+                input,
+                weight_int4pack,
+                groupsize,
+                scales_and_zeros,
+            )
     new_shape = origin_input_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
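A minimal sketch (not part of this commit) of the PyTorch >= 2.6 CPU path that linear_int4 now takes, with toy sizes assumed only for illustration:

import torch

# Toy sizes, assumed only for this sketch.
m, k, n, groupsize, inner_k_tiles = 4, 64, 32, 32, 2

x = torch.randn(m, k, dtype=torch.float32)
# Unpacked 4-bit codes, one int32 per value, as group quantization produces them.
w_int32 = torch.randint(0, 16, (n, k), dtype=torch.int32)
# One (scale, zero) pair per group per output channel, in the input's dtype.
scales_and_zeros = torch.rand(k // groupsize, n, 2, dtype=torch.float32)

# CPU packing (PyTorch >= 2.6): returns uint8 of shape (n, k // 2).
w_packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int32, inner_k_tiles)
out = torch.ops.aten._weight_int4pack_mm_for_cpu(x, w_packed, groupsize, scales_and_zeros)
print(out.shape)  # torch.Size([4, 32])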
@@ -178,16 +189,27 @@ def __init__(
), "must specify both weights and scales_and_zeros, or neither"

if weight is None:
weight = torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
)
if is_device(device, "cpu"):
weight = torch.empty(
(
out_features,
in_features // 2,
),
dtype=torch.uint8,
device=device,
)
else:
weight = torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
)

scales_and_zeros = torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
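The two placeholder layouts __init__ can now allocate differ substantially; a quick sketch (sizes illustrative only) of the shapes each branch produces:

import torch

out_features, in_features, inner_k_tiles = 32, 64, 2

# CPU branch (PyTorch >= 2.6): plain uint8, two 4-bit codes per byte.
cpu_weight = torch.empty((out_features, in_features // 2), dtype=torch.uint8)

# CUDA branch: tensor-core friendly int32 tiling.
cuda_weight = torch.empty(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)

print(cpu_weight.shape, cuda_weight.shape)
# torch.Size([32, 32]) torch.Size([4, 2, 32, 1])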
@@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
-            torch.uint8
-        )
-        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            weight_uint8, inner_k_tiles
-        )
+        if is_device(weight_int32.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                weight_int32, inner_k_tiles
+            )
+        else:
+            weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
+                torch.uint8
+            )
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_uint8, inner_k_tiles
+            )
         return weight_int4pack, scales_and_zeros
 
     @classmethod
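The pre-2.6 fallback folds two adjacent 4-bit codes into one byte before calling the generic packing op; a tiny worked example of that expression:

import torch

w = torch.tensor([[0x1, 0x2, 0x3, 0x4]], dtype=torch.int32)
packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)
# Even columns land in the high nibble, odd columns in the low nibble.
print([hex(v) for v in packed.flatten().tolist()])  # ['0x12', '0x34']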
@@ -608,10 +635,15 @@ def load_model_and_state_dict(
             if load_state_dict:
                 q, s, z = Q4_0.unpack(t)
                 scales_and_zeros = pack_scales_and_zeros(s, z)
-                q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    q_uint8, inner_k_tiles
-                )
+                if is_device(q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q, inner_k_tiles
+                    )
+                else:
+                    q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        q_tmp, inner_k_tiles
+                    )
                 state_dict[f"{fqn}.weight"] = weight_int4pack
                 state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

@@ -623,7 +655,7 @@ def load_model_and_state_dict(
                     in_features=in_features,
                     out_features=out_features,
                     bias=False,
-                    device="meta",
+                    device="cpu",
                     groupsize=Q4_0.groupsize,
                     inner_k_tiles=inner_k_tiles,
                 ),
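On the device="meta" -> "cpu" change: a plausible reading (not stated in the commit) is that __init__ now branches on the device string, so a meta-device module would take the CUDA-layout branch while the state_dict holds CPU-packed uint8 weights. A small check of the gate:

from torchao.dtypes.utils import is_device

print(is_device("meta", "cpu"))  # False -> int32 CUDA-tiling placeholder
print(is_device("cpu", "cpu"))   # True  -> uint8 CPU placeholder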