Refactor apply function
daniil-lyakhov committed Apr 26, 2024
1 parent 140c31f commit fb883db
Showing 1 changed file with 147 additions and 114 deletions.
261 changes: 147 additions & 114 deletions nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -96,6 +96,112 @@ def _set_backend_entity(self, model: TModel) -> None:
"Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value)
)

def _get_importance_and_x(self, stats, original_weight, zero_mask, reduction_axis, config, eps):
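# Per-weight importance derived from activations: each input channel gets the maximum
# absolute activation observed across the collected statistics; the importance is zeroed
# where the quantized weight equals its zero point and normalized within every group.
# The (optionally subsampled) activations X are returned reshaped into quantization groups.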
X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
X_full = fns.transpose(X)

# prevent high memory and time consumption
if X_full.shape[1] > self._subset_size:
lens = [stat.shape[0] for stat in stats]
step = X_full.shape[1] // self._subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X_full[:, idxs]
else:
X = X_full

s = fns.max(fns.abs(X_full), axis=1)
# every weight in a group gets an importance based on its corresponding input activations

s = fns.unsqueeze(s, 0)
s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size)

importance = fns.ones_like(original_weight)
importance = importance * s
importance = fns.where(zero_mask, 0.0, importance)

# normalize importances within every group of weights so that they sum to 1.0
denum = fns.sum(importance, axis=2, keepdims=True)
importance = importance / (denum + eps)
X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size)
return importance, X

def _get_compress_decompress_model(self, wp, q_weights, scale, zp, compress_decompress_cashe):
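# Compiled compress/decompress pipelines depend only on the compression mode, bit width
# and the shapes of the quantized weight, scale and zero point, so they are cached and
# reused for matching configurations instead of being rebuilt for every node.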
key = (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
if key in compress_decompress_cashe:
compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
compress_model = compress_decompress_cashe[key]["compress_model"]
else:
compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
wp, q_weights.shape, scale.shape, zp.shape
)
compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
compress_decompress_cashe[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}
return compress_model, compress_decompress_model

def _get_min_max_scale_diffs(self, fp_outs, q_outs, original_weight, q_weights):
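# Error of quantization with the initial min-max scale; used as the baseline that the
# rectified scales have to improve on.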
# metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE
min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
if self._weight_penalty > 0.0:
min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)
return min_max_scale_diffs

def _get_weight(self, wp, model, graph):
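# Fetch the node's weight tensor as float32; returns None when the node does not have
# exactly one weight, which the algorithm does not support.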
weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
return None
_, weight_port_id = weight_data[0]

weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
weight = weight.astype(TensorDataType.float32)
return weight

def _rectification(
self,
result_scale,
best_diffs,
X,
original_weight,
target,
fp_outs,
zero_mask,
importance,
min_max_scale_diffs,
compress_decompress_model,
zp,
scale,
):
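# One rectification step: compute the importance-weighted average of the per-weight
# "ideal" scales, re-quantize with it, and keep the new scale only for the channels and
# groups whose (optionally weight-penalized) output error improves on the best seen so far.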
ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)

out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
if self._weight_penalty > 0.0:
ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)

if best_diffs is None:
best_diffs = min_max_scale_diffs

mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)

best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs

mask = fns.unsqueeze(mask, axis=2)

if result_scale is None:
near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
else:
near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
result_scale = near_to_ideal_scale
return result_scale, best_diffs

def apply(
self,
model: TModel,
@@ -129,34 +235,17 @@ def apply(
continue

stats = self._activations[k]
reduction_axis = wp.reduction_axes[0]

cur_config = deepcopy(config)
cur_config.group_size = -1

weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
if len(weight_data) != 1: # not supported by the algorithm
weight = self._get_weight(wp, model, graph)
if weight is None:
continue
_, weight_port_id = weight_data[0]

X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
X_full = fns.transpose(X)

# prevent high memory and time consumption
if X_full.shape[1] > self._subset_size:
lens = [stat.shape[0] for stat in stats]
step = X_full.shape[1] // self._subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X_full[:, idxs]
else:
X = X_full

s = fns.max(fns.abs(X_full), axis=1)

weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
weight = weight.astype(TensorDataType.float32)
eps = fns.finfo(weight).eps

reduction_axis = wp.reduction_axes[0]
if reduction_axis == 0:
weight = fns.transpose(weight)
reduction_axis = 1
@@ -166,94 +255,49 @@ def apply(
compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config)
zp = zp.astype(scale.dtype)

q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)

s = fns.unsqueeze(s, 0)
s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size)

target = compressed_weights.astype(dtype=zp.dtype) - zp
original_weight, _ = reshape_weight_for_grouped_quantization(
original_weight, reduction_axis, config.group_size
)

# all weight in group has importance based on corresponding input activations
importance = fns.ones_like(original_weight)
importance = importance * s

target = compressed_weights.astype(dtype=zp.dtype) - zp
zero_mask = compressed_weights == zp
importance, X = self._get_importance_and_x(stats, original_weight, zero_mask, reduction_axis, config, eps)

importance = fns.where(zero_mask, 0.0, importance)

# normalize importances for every group of weights to make sum of them equal to 1.0
denum = fns.sum(importance, axis=2, keepdims=True)
importance = importance / (denum + eps)

X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size)
q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)
q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, config.group_size)
best_diffs = None
result_scale = None
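# best_diffs tracks the lowest error reached so far, result_scale the scale that achieved it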

fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X)
q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X)
min_max_scale_diffs = self._get_min_max_scale_diffs(fp_outs, q_outs, original_weight, q_weights)

# metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE
min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
if self._weight_penalty > 0.0:
min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)

key = (
(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
compress_model, compress_decompress_model = self._get_compress_decompress_model(
wp, q_weights, scale, zp, compress_decompress_cashe
)
if key in compress_decompress_cashe:
compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
compress_model = compress_decompress_cashe[key]["compress_model"]
else:
compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
wp, q_weights.shape, scale.shape, zp.shape
)
compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
compress_decompress_cashe[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}

zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
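# zero_mask now holds a small epsilon where the quantized weight equals the zero point,
# keeping ideal_scale = |w| / (|target| + zero_mask) finite during rectification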

best_diffs = None
result_scale = None

# iterative rectification of initial scale
for i in range(self._initial_steps):
ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance

near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)

out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
if self._weight_penalty > 0.0:
ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)

if best_diffs is None:
best_diffs = min_max_scale_diffs

mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)

best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs

mask = fns.unsqueeze(mask, axis=2)

if result_scale is None:
near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
else:
near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
result_scale = near_to_ideal_scale

result_scale, best_diffs = self._rectification(
result_scale,
best_diffs,
X,
original_weight,
target,
fp_outs,
zero_mask,
importance,
min_max_scale_diffs,
compress_decompress_model,
zp,
scale,
)
if i < self._initial_steps - 1:
out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
out = compress_model(original_weight.data, result_scale.data, zp.data)
compressed_weights = fns.zeros_like(original_weight) + out
target = compressed_weights - zp
zero_mask = compressed_weights == zp
@@ -271,31 +315,20 @@ def apply(
zero_mask = compressed_weights == zp
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)

ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
weighted_scale = ideal_scale * importance
near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)

out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
q_weights_ = fns.zeros_like(original_weight) + out

q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
if self._weight_penalty > 0.0:
ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)

mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)

best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs

mask = fns.unsqueeze(mask, axis=2)

if result_scale is None:
near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
else:
near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
result_scale = near_to_ideal_scale

result_scale, best_diffs = self._rectification(
result_scale,
best_diffs,
X,
original_weight,
target,
fp_outs,
zero_mask,
importance,
min_max_scale_diffs,
compress_decompress_model,
zp,
scale,
)
res[k] = result_scale

return res
