use inplace=True mode for WOQ (#1557)
* use inplace mode for WOQ

Signed-off-by: xin3he <[email protected]>
xin3he authored Jan 25, 2024
1 parent fa8e66a commit 31743fe
Showing 9 changed files with 594 additions and 535 deletions.
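To see what the switch to in-place mode buys, here is a minimal sketch (illustrative shapes and variable names, not code from this commit): an out-of-place expression such as torch.round(weight / scale) allocates fresh weight-sized tensors for its intermediates, while the in-place chain div_().round_() rewrites the existing storage, so peak memory during weight-only quantization stays close to a single copy of each weight.

import torch

# Sketch only: out-of-place vs. in-place quantization of one weight tensor.
w = torch.randn(4096, 4096)          # stand-in for a large Linear weight
scale = torch.rand(4096, 1) + 0.5    # per-row scale, broadcast over columns

q_copy = torch.round(w / scale)      # out-of-place: extra weight-sized tensors are allocated

w.div_(scale).round_()               # in-place: w's own storage is reused
assert torch.equal(q_copy, w)        # same values, lower peak memory

The diffs below apply this pattern to rtn_quantize, search_clip, quant_weight_w_scale, and the transposes in WeightOnlyLinear.pack()/recover(). Removed lines are marked with "-" and added lines with "+".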
neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py (43 changes: 22 additions, 21 deletions)
@@ -127,23 +127,20 @@ def __init__(
dtype=self.float_type,
).to(device),
)
- self.scales = self.scales.T
self.register_buffer(
"qweight",
torch.zeros(
(math.ceil(in_features / self.n_pack), out_features),
dtype=self.compression_dtype,
).to(device),
)
- self.qweight = self.qweight.T
self.register_buffer(
"qzeros",
torch.zeros(
(math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)),
dtype=self.compression_dtype,
).to(device),
)
- self.qzeros = self.qzeros.T
self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
else:
self.compression_dtype = compression_dtype
@@ -193,6 +190,10 @@ def __init__(
self.bias = None

def pack(self, int_weight, scale, zp, bias):
+ if self.use_optimum_format:
+ self.scales = self.scales.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
- int_weight = int_weight.T
- self.qweight = self.qweight.T
+ int_weight = int_weight.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
- self.qweight = self.qweight.T
+ self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
- self.qzeros = self.qzeros.T
+ zp = zp.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
- self.qzeros = self.qzeros.T
+ self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
- self.scales = self.scales.T
- self.qweight = self.qweight.T
- self.qzeros = self.qzeros.T
+ self.scales = self.scales.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
- scales = self.scales.T if self.use_optimum_format else self.scales
- qweight = self.qweight.T if self.use_optimum_format else self.qweight
+ scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+ qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -264,8 +265,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
- weight = weight.T
- qweight = qweight.T
+ weight = weight.t_().contiguous()
+ qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -280,7 +281,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
- weight = weight.T
+ weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -290,10 +291,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
- qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+ qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
- qzeros = qzeros.T
+ zp = zp.t_().contiguous()
+ qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
@@ -307,7 +308,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
+ zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
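The pack() loops above compress n_pack low-bit integers into each lane of the wider compression_dtype by shifting and OR-ing, and recover() reverses this with shifts and a mask. A self-contained sketch of that idea, assuming 4-bit values packed into int32 and ignoring the optimum-format transposes (names and shapes are illustrative):

import torch

bits = 4
n_pack = 32 // bits                                   # 8 four-bit values per int32 lane
int_weight = torch.randint(0, 2**bits, (16, 32), dtype=torch.int32)

packed = torch.zeros(16, 32 // n_pack, dtype=torch.int32)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        # shift each value into its bit slot and OR it into the lane,
        # mirroring the tmp[:, e] << (self.bits * e) / |= loop in pack()
        packed[:, j] |= int_weight[:, j * n_pack + e] << (bits * e)

mask = (1 << bits) - 1                                # recover() uses the same mask to drop upper bits
assert torch.equal(packed[:, 0] & mask, int_weight[:, 0])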
neural_compressor/adaptor/torch_utils/model_wrapper.py (42 changes: 21 additions, 21 deletions)
@@ -327,9 +327,9 @@ def __init__(

def pack(self, int_weight, scale, zp, bias, g_idx=None):
if self.use_optimum_format:
- self.scales = self.scales.T
- self.qweight = self.qweight.T
- self.qzeros = self.qzeros.T
+ self.scales = self.scales.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
@@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
assert scale.shape == self.scales.shape, "Scale shape is mismatched."
self.scales = scale.type(self.float_type).to(self.device)
if not self.use_optimum_format and self.compression_dim == 0:
- int_weight = int_weight.T
- self.qweight = self.qweight.T
+ int_weight = int_weight.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
origin_shape = int_weight.shape
target_shape = self.qweight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
@@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qweight[:, j] |= tmp[:, e]
if not self.use_optimum_format and self.compression_dim == 0:
- self.qweight = self.qweight.T
+ self.qweight = self.qweight.t_().contiguous()

if zp is not None:
zp = zp.to(self.device)
if self.use_optimum_format:
zp -= 1
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
- self.qzeros = self.qzeros.T
+ zp = zp.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()
assert hasattr(self, "qzeros"), "zp is not set when initializing."
target_shape = self.qzeros.shape
for j in range(target_shape[1]):
@@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
tmp[:, e] = tmp[:, e] << (self.bits * e)
self.qzeros[:, j] |= tmp[:, e]
if self.use_optimum_format or self.compression_dim == 0:
- self.qzeros = self.qzeros.T
+ self.qzeros = self.qzeros.t_().contiguous()
if self.use_optimum_format:
- self.scales = self.scales.T
- self.qweight = self.qweight.T
- self.qzeros = self.qzeros.T
+ self.scales = self.scales.t_().contiguous()
+ self.qweight = self.qweight.t_().contiguous()
+ self.qzeros = self.qzeros.t_().contiguous()

def recover(self):
logger.debug(f"Recovering {self} weight")
- scales = self.scales.T if self.use_optimum_format else self.scales
- qweight = self.qweight.T if self.use_optimum_format else self.qweight
+ scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
+ qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight

device = scales.device
fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
@@ -411,8 +411,8 @@ def recover(self):
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
if not self.use_optimum_format and self.compression_dim == 0:
- weight = weight.T
- qweight = qweight.T
+ weight = weight.t_().contiguous()
+ qweight = qweight.t_().contiguous()
origin_shape = weight.shape
target_shape = qweight.shape
for j in range(target_shape[1]):
@@ -427,7 +427,7 @@ def recover(self):
tmp &= mask # remove sign bit
weight[:, index] = tmp.type(weight_dtype)
if not self.use_optimum_format and self.compression_dim == 0:
- weight = weight.T
+ weight = weight.t_().contiguous()
if "int" not in self.dtype:
new_weight = torch.zeros(self.out_features, self.in_features).to(device)
for k, v in self.int2float_mapping.items():
@@ -437,10 +437,10 @@ def recover(self):
if hasattr(self, "qzeros"):
zp_dtype = self.compression_dtype # to avoid overflow when weight-zp
zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
- qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+ qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
- qzeros = qzeros.T
+ zp = zp.t_().contiguous()
+ qzeros = qzeros.t_().contiguous()
origin_shape = zp.shape
target_shape = qzeros.shape
for j in range(target_shape[1]):
@@ -454,7 +454,7 @@ def recover(self):
tmp &= mask
zp[:, index] = tmp.type(zp_dtype)
if self.use_optimum_format or self.compression_dim == 0:
- zp = zp.T
+ zp = zp.t_().contiguous()
if self.use_optimum_format:
# zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
zp += 1
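The recurring edit in both wrappers, x = x.T becoming x = x.t_().contiguous(), swaps a lazy transposed view for an in-place dimension swap followed by a single materialization of the new layout. A small illustration of the difference in semantics (illustrative only, not code from this commit):

import torch

buf = torch.arange(6).reshape(2, 3)

view = buf.T                   # out-of-place: a new view object; buf itself still reports (2, 3)
print(view.shape, buf.shape)   # torch.Size([3, 2]) torch.Size([2, 3])

buf.t_()                       # in-place: the same tensor now reports (3, 2) ...
print(buf.is_contiguous())     # False, because t_() only swaps strides; no data moves
dense = buf.contiguous()       # ... and contiguous() materializes that layout once, when needed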
neural_compressor/adaptor/torch_utils/weight_only.py (40 changes: 22 additions, 18 deletions)
@@ -330,16 +330,17 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en
history = []
for i_s in range(int(max_shrink * n_grid)):
ratio = 1 - i_s / n_grid # 1, 0.805-1.0
- cur_weight = quant_weight(
- m.weight.data,
+ quant_weight(
+ m.weight.data, # in-place mode
num_bits=num_bits,
group_size=group_size,
scheme=scheme,
data_type=data_type,
full_range=enable_full_range,
quantile=ratio,
)
- loss = (org_weight - cur_weight).float().pow(2).mean().item()
+ loss = (org_weight - m.weight.data).float().pow(2).mean().item()
+ m.weight.data.copy_(org_weight)
history.append(loss)
is_best = loss < best_error
if is_best:
@@ -429,14 +430,17 @@ def rtn_quantize(
if num_bits <= 0:
logger.info(f"Skip {name}")
continue
- weight = m.weight.T if group_dim == 0 else m.weight
+ # contiguous is not an in-place op and returns Tensor instead of Parameter, so set it back to m.weight.data.
+ # transpose should be executed on Parameter level because Param.data.t_() is not an in-place op.
+ # Parameter.T is an in-place op while Tensor.T is not.
+ m.weight.data = m.weight.t_().data.contiguous() if group_dim == 0 else m.weight.data
if enable_mse_search:
quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
if return_int:
from .model_wrapper import WeightOnlyLinear

_, scale, zp = quant_weight(
- weight,
+ m.weight.data,
num_bits,
group_size,
scheme,
@@ -446,9 +450,9 @@ def rtn_quantize(
full_range=enable_full_range,
)
if group_dim == 0:
- weight.transpose_(0, 1)
- scale = scale.T if group_dim == 0 else scale
- zp = zp.T if group_dim == 0 and zp is not None else zp
+ m.weight.t_()
+ scale = scale.t_().contiguous() if group_dim == 0 else scale
+ zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
@@ -463,14 +467,14 @@ def rtn_quantize(
device=device,
use_optimum_format=use_optimum_format,
)
- new_module.pack(weight, scale, zp, m.bias)
+ new_module.pack(m.weight.data, scale, zp, m.bias)
if name == "":
return new_module
else:
set_module(model, name, new_module)
else:
quant_weight(
- weight,
+ m.weight.data,
num_bits,
group_size,
scheme,
@@ -479,7 +483,7 @@ def rtn_quantize(
full_range=enable_full_range,
)
if group_dim == 0:
- weight.transpose_(0, 1)
+ m.weight.t_()
if orig_dtype != torch.float:
m = m.to(orig_dtype)
return model
@@ -651,18 +655,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
if zp is not None:
zp = zp.to(device)
if group_size == -1:
- return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
+ return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()
int_weight = torch.zeros(weight.shape).to(device)
leng = weight.shape[1] // group_size
tail_flag = False if weight.shape[1] % group_size == 0 else True
for i in range(leng):
- int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
+ int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
if zp is not None:
- int_weight_tmp += zp[:, i].unsqueeze(1)
- int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
+ int_weight_tmp.add_(zp[:, i].unsqueeze(1))
+ int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_())
if tail_flag:
- int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
+ int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1))
if zp is not None:
- int_weight_tmp += zp[:, -1].unsqueeze(1)
- int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
+ int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
+ int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
return int_weight
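quant_weight_w_scale now divides, shifts by the zero point, and rounds each column group in place instead of building weight / scale temporaries. A sketch of the same grouped pattern, with illustrative shapes and a scale/zero point chosen only for demonstration:

import torch

weight = torch.randn(8, 64)
group_size = 32
n_groups = weight.shape[1] // group_size

# per-group scale and zero point along the input dimension (demo values, not a real calibration)
scale = weight.reshape(8, n_groups, group_size).abs().amax(dim=2) / 15.0
zp = torch.full((8, n_groups), 8.0)

for g in range(n_groups):
    block = weight[:, g * group_size : (g + 1) * group_size]   # a view into weight
    # div_/add_/round_ mutate the view, and therefore weight itself, with no temporary copy
    block.div_(scale[:, g].unsqueeze(1)).add_(zp[:, g].unsqueeze(1)).round_()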
neural_compressor/torch/algorithms/weight_only/__init__.py (4 changes: 4 additions, 0 deletions)
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+ from .utility import *
+ from .rtn import rtn_quantize
+ from .gptq import gptq_quantize
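With these re-exports in place, the weight-only entry points become importable from the torch algorithms namespace. Only the import surface is shown in this diff, so no call signatures are assumed here:

from neural_compressor.torch.algorithms.weight_only import rtn_quantize, gptq_quantize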