Fix WOQ Linear pack slow issue (#1828)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored Jun 3, 2024
1 parent 4dbf71e commit da1ada2
Showing 1 changed file with 53 additions and 2 deletions.
neural_compressor/torch/algorithms/weight_only/modules.py
@@ -19,6 +19,7 @@
 # since the model classes inherit torch.nn.Module.
 import math

+import numpy as np
 import torch
 from torch.autograd import Function
 from torch.nn import functional as F
@@ -270,7 +271,7 @@ def recover(self):
             fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]]
         return fp32_weight

-    def pack_tensor(self, raw_tensor):
+    def pack_tensor_with_torch(self, raw_tensor):
         target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
         packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
         mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
@@ -285,7 +286,7 @@ def pack_tensor(self, raw_tensor):
         accelerator.synchronize()
         return packed_tensor

-    def unpack_tensor(self, packed_tensor):
+    def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
         unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
@@ -302,6 +303,56 @@ def unpack_tensor(self, packed_tensor):
         accelerator.synchronize()
         return unpacked_tensor

+    def pack_tensor_with_numpy(self, raw_tensor):
+        raw_array = raw_tensor.cpu().numpy()
+        target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
+        target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
+        packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_array[:, start:end].astype(target_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e)
+                packed_array[:, j] |= tmp[:, e]
+        accelerator.synchronize()
+        packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device)
+        return packed_tensor

+    def unpack_tensor_with_numpy(self, packed_tensor):
+        packed_array = packed_tensor.cpu().numpy()
+        target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8
+        target_len = packed_array.shape[1] * self.n_pack
+        unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_array[:, j]
+                tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1))
+                tmp = np.right_shift(tmp, self.compress_bits - self.bits)
+                if target_dtype == np.uint8:
+                    tmp &= mask
+                unpacked_array[:, index] = tmp.astype(target_dtype)
+        accelerator.synchronize()
+        unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device)
+        return unpacked_tensor

+    def pack_tensor(self, raw_tensor):
+        if "cuda" in self.device:
+            return self.pack_tensor_with_torch(raw_tensor)
+        else:
+            return self.pack_tensor_with_numpy(raw_tensor)

+    def unpack_tensor(self, packed_tensor):
+        if "cuda" in self.device:
+            return self.unpack_tensor_with_torch(packed_tensor)
+        else:
+            return self.unpack_tensor_with_numpy(packed_tensor)

     def forward(self, input):
         if not hasattr(self, "weight"):
             weight = self.recover()
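A note on the packing scheme these helpers implement: n_pack low-bit values are packed into each word of the compression dtype by masking each value to its bit-width, shifting it into its slot, and OR-ing it into the word; unpacking recovers a value with a left shift that discards the higher slots followed by a right shift that discards the lower bits. Below is a minimal standalone round-trip sketch of that idea. It is not part of the commit; it assumes bits = 4 and an int32 compression dtype (so n_pack = 8), and the names raw, packed, and unpacked are illustrative only.

import numpy as np

bits = 4                        # assumed weight bit-width
compress_bits = 32              # assumed width of the compression dtype (int32)
n_pack = compress_bits // bits  # eight 4-bit values per int32 word

rng = np.random.default_rng(0)
raw = rng.integers(0, 2**bits, size=(2, 16), dtype=np.uint8)  # toy int4 weights

# Pack: mask each value to `bits` bits, shift it into its slot, OR it into the word.
packed = np.zeros((raw.shape[0], raw.shape[1] // n_pack), dtype=np.int32)
for j in range(packed.shape[1]):
    chunk = raw[:, n_pack * j : n_pack * (j + 1)].astype(np.int32) & (2**bits - 1)
    for e in range(chunk.shape[1]):
        packed[:, j] |= chunk[:, e] << (bits * e)

# Unpack: shift left to drop the higher slots, then right to drop the lower bits.
unpacked = np.zeros_like(raw)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        tmp = packed[:, j] << (compress_bits - bits * (e + 1))
        tmp = tmp >> (compress_bits - bits)
        unpacked[:, j * n_pack + e] = (tmp & (2**bits - 1)).astype(np.uint8)

assert np.array_equal(raw, unpacked)  # the round trip recovers every value

The right shift on a signed dtype is arithmetic and therefore sign-extends, which is presumably why unpack_tensor_with_numpy applies the mask only for the unsigned np.uint8 target and keeps the sign extension when producing signed int8 outputs.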
