Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update int4pack related in torchchat gguf #1404

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions torchchat/utils/gguf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
pack_scales_and_zeros,
)

from torchao.dtypes.utils import is_device
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchchat locks onto a specific torch version, so we don't need to check

Assume > 2.6

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CI failures seem that torchao version is not that new, because TORCH_VERSION_AT_LEAST_2_6 is a new one. And I saw you pin pytorch nightly to 20241013, which is also not new, and this nightly does not have pytorch/pytorch#139611 inside. This is my question, because the nightly used in the CI is 20241126.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, working on the bump here: #1367

We'll test your fixes on there

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!



logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
input.dtype
) # cast back to input.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
if TORCH_VERSION_AT_LEAST_2_6:
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
else:
c = torch.ops.aten._weight_int4pack_mm(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_input_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
Expand Down Expand Up @@ -178,16 +189,27 @@ def __init__(
), "must specify both weights and scales_and_zeros, or neither"

if weight is None:
weight = torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
)
if is_device(device, "cpu"):
weight = torch.empty(
(
out_features,
in_features // 2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

),
dtype=torch.uint8,
device=device,
)
else:
weight = torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
)

scales_and_zeros = torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
Expand Down Expand Up @@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
torch.uint8
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_uint8, inner_k_tiles
)
if is_device(weight_int32.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
weight_int32, inner_k_tiles
)
else:
weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
torch.uint8
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_uint8, inner_k_tiles
)
return weight_int4pack, scales_and_zeros

@classmethod
Expand Down Expand Up @@ -608,10 +635,15 @@ def load_model_and_state_dict(
if load_state_dict:
q, s, z = Q4_0.unpack(t)
scales_and_zeros = pack_scales_and_zeros(s, z)
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q_uint8, inner_k_tiles
)
if is_device(q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
q, inner_k_tiles
)
else:
q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q_tmp, inner_k_tiles
)
state_dict[f"{fqn}.weight"] = weight_int4pack
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

Expand All @@ -623,7 +655,7 @@ def load_model_and_state_dict(
in_features=in_features,
out_features=out_features,
bias=False,
device="meta",
device="cpu",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this as a meta device as long as we can

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groupsize=Q4_0.groupsize,
inner_k_tiles=inner_k_tiles,
),
Expand Down
Loading