diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py
index 309ff807c..019d3b2c3 100644
--- a/torchchat/utils/gguf_loader.py
+++ b/torchchat/utils/gguf_loader.py
@@ -24,6 +24,9 @@
     pack_scales_and_zeros,
 )
 
+from torchao.dtypes.utils import is_device
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -122,12 +125,20 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
             input.dtype
         )  # cast back to input.dtype
     else:
-        c = torch.ops.aten._weight_int4pack_mm(
-            input,
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros,
-        )
+        if TORCH_VERSION_AT_LEAST_2_6:
+            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                input,
+                weight_int4pack,
+                groupsize,
+                scales_and_zeros,
+            )
+        else:
+            c = torch.ops.aten._weight_int4pack_mm(
+                input,
+                weight_int4pack,
+                groupsize,
+                scales_and_zeros,
+            )
     new_shape = origin_input_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
@@ -178,16 +189,27 @@ def __init__(
         ), "must specify both weights and scales_and_zeros, or neither"
 
         if weight is None:
-            weight = torch.empty(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-                device=device,
-            )
+            if is_device(device, "cpu"):
+                weight = torch.empty(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
+                )
+            else:
+                weight = torch.empty(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                )
+
             scales_and_zeros = torch.empty(
                 (in_features // groupsize, out_features, 2),
                 dtype=get_precision(),
@@ -223,12 +245,17 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
-            torch.uint8
-        )
-        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            weight_uint8, inner_k_tiles
-        )
+        if is_device(weight_int32.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                weight_int32, inner_k_tiles
+            )
+        else:
+            weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
+                torch.uint8
+            )
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_uint8, inner_k_tiles
+            )
         return weight_int4pack, scales_and_zeros
 
     @classmethod
@@ -608,10 +635,15 @@ def load_model_and_state_dict(
             if load_state_dict:
                 q, s, z = Q4_0.unpack(t)
                 scales_and_zeros = pack_scales_and_zeros(s, z)
-                q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    q_uint8, inner_k_tiles
-                )
+                if is_device(q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q, inner_k_tiles
+                    )
+                else:
+                    q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        q_tmp, inner_k_tiles
+                    )
                 state_dict[f"{fqn}.weight"] = weight_int4pack
                 state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
 
@@ -623,7 +655,7 @@
                     in_features=in_features,
                     out_features=out_features,
                     bias=False,
-                    device="meta",
+                    device="cpu",
                     groupsize=Q4_0.groupsize,
                     inner_k_tiles=inner_k_tiles,
                 ),
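
For reference, here is a minimal standalone sketch of the PyTorch >= 2.6 CPU path this diff switches to. The sizes, group size, and random data below are illustrative only (not from the patch); the sketch assumes `torch.ops.aten._convert_weight_to_int4pack_for_cpu` and `torch.ops.aten._weight_int4pack_mm_for_cpu` are available, which is what the `TORCH_VERSION_AT_LEAST_2_6` guard in the diff checks. The key point is that the CPU packing op consumes the int32 quantized weight of shape `[N, K]` directly and returns a uint8 tensor of shape `[N, K // 2]`, which is why the `__init__` change allocates `(out_features, in_features // 2)` uint8 on CPU and why the nibble pre-packing step is skipped on that path.

```python
import torch

# Illustrative sizes only (not from the patch). K must be even and a
# multiple of the group size; the CPU kernel supports group sizes of
# 32/64/128/256.
N, K, groupsize, inner_k_tiles = 32, 128, 32, 8

# 4-bit quantized weight stored as int32 values in [0, 15], shape [N, K]
# -- the layout group_quantize_tensor produces.
weight_int32 = torch.randint(0, 16, (N, K), dtype=torch.int32)

# Per-group (scale, zero-point) pairs, shape [K // groupsize, N, 2].
scales_and_zeros = torch.rand(K // groupsize, N, 2, dtype=torch.bfloat16)

if hasattr(torch.ops.aten, "_weight_int4pack_mm_for_cpu"):
    # CPU packing consumes the int32 tensor directly; no uint8 nibble
    # pre-packing is needed on this path.
    packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
        weight_int32, inner_k_tiles
    )
    print(packed.shape, packed.dtype)  # torch.Size([32, 64]) torch.uint8

    x = torch.randn(4, K, dtype=torch.bfloat16)
    out = torch.ops.aten._weight_int4pack_mm_for_cpu(
        x, packed, groupsize, scales_and_zeros
    )
    print(out.shape)  # torch.Size([4, 32])
```

This also explains the `device="meta"` to `device="cpu"` change in the last hunk: presumably the module is now constructed with real CPU storage so the CPU-packed weight can be materialized into it, rather than being allocated on the meta device first.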
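
The `else` branches keep the pre-2.6 behavior: `(q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)` fuses each pair of 4-bit values along the last dimension into one byte (even index in the high nibble, odd index in the low nibble) before handing the result to `_convert_weight_to_int4pack`. A tiny self-contained illustration of that nibble packing, with toy values not taken from the patch:

```python
import torch

# Toy 4-bit values (each in [0, 15]) stored one per int32 element.
q = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)

# Fallback packing from the diff: two 4-bit values per uint8 byte.
packed = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
print(packed)  # tensor([[18, 52]], dtype=torch.uint8) -> 0x12, 0x34

# Unpacking reverses the nibble split.
high = packed >> 4     # even-indexed values
low = packed & 0x0F    # odd-indexed values
restored = torch.stack([high, low], dim=-1).reshape(q.shape)
assert torch.equal(restored.to(torch.int32), q)
```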