
Commit bd5c1e1

1) Added conformance metrics for scale estimation.
2) Updated OV compression/decompression pipeline.
andreyanufr committed Mar 25, 2024
1 parent f5111a8 commit bd5c1e1
Showing 4 changed files with 37 additions and 43 deletions.
39 changes: 14 additions & 25 deletions nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -200,25 +200,20 @@ def get_compress_decompress_pipeline(
         weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape
     ):
         (
-            input_node_w,
-            input_node_s,
-            input_node_zp,
-            node_compression_clamp,
-            result1,
+            w,
+            s,
+            zp,
+            clamp,
         ) = OVWeightCompressionAlgoBackend.get_compress_pipeline(
             weight_compression_parameter, w_shape, s_shape, z_p_shape, True
         )
 
-        node_decompression_add = opset.subtract(node_compression_clamp, input_node_zp)
-        node_decompression_mul = opset.multiply(node_decompression_add, input_node_s)
-        result2 = opset.result(node_decompression_mul, name="q_weights")
-        result2.get_output_tensor(0).set_names(set(["q_weights"]))
-
-        model = ov.Model([result1, result2], [input_node_w, input_node_s, input_node_zp])
+        result = (clamp - zp) * s
+        model = ov.Model([result], [w, s, zp])
 
         compiled_model = ov.compile_model(model)
 
-        return compiled_model
+        return lambda w, s, zp: compiled_model([w, s, zp])[0]
 
     @staticmethod
     def get_compress_pipeline(
@@ -232,26 +227,20 @@ def get_compress_pipeline(
         level_low = 0
         level_high = 2**num_bits - 1
 
-        input_node_w = opset.parameter(w_shape, name="w")
-        input_node_s = opset.parameter(s_shape, name="s")
-        input_node_zp = opset.parameter(z_p_shape, name="zp")
-
-        node_compression_div = opset.divide(input_node_w, input_node_s)
-        node_compression_add = opset.add(node_compression_div, input_node_zp)
-        node_compression_round = opset.round(node_compression_add)
-        node_compression_clamp = opset.clamp(node_compression_round, level_low, level_high)
+        w = opset.parameter(w_shape, name="w")
+        s = opset.parameter(s_shape, name="s")
+        zp = opset.parameter(z_p_shape, name="zp")
 
-        result1 = opset.result(node_compression_clamp, name="compressed_weights")
-        result1.get_output_tensor(0).set_names(set(["compressed_weights"]))
+        result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights")
 
         if return_nodes:
-            return input_node_w, input_node_s, input_node_zp, node_compression_clamp, result1
+            return w, s, zp, result
 
-        model = ov.Model([result1], [input_node_w, input_node_s, input_node_zp])
+        model = ov.Model([result], [w, s, zp])
 
         compiled_model = ov.compile_model(model)
 
-        return compiled_model
+        return lambda w, s, zp: compiled_model([w, s, zp])[0]
 
 
 class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend):
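Both pipelines now build a single-output graph and wrap the compiled model in a lambda, so callers pass the three arrays positionally and receive the sole output directly instead of indexing a named result (this matches the updated call sites in scale_estimation.py below). The math itself is unchanged; here is a minimal NumPy sketch of the round trip these graphs implement, with illustrative values and an assumed num_bits=4:

import numpy as np

def compress(w, s, zp, num_bits=4):
    # clamp(round(w / s + zp), 0, 2**num_bits - 1), as in get_compress_pipeline;
    # np.round rounds half to even, matching OV Round's default mode
    return np.clip(np.round(w / s + zp), 0, 2**num_bits - 1)

def decompress(q, s, zp):
    # (q - zp) * s, the decompression half of get_compress_decompress_pipeline
    return (q - zp) * s

w = np.array([[-0.84, 0.1, 0.7]], dtype=np.float32)  # hypothetical weight row
s = np.array([[0.12]], dtype=np.float32)             # hypothetical scale
zp = np.array([[8.0]], dtype=np.float32)             # hypothetical zero point

q = compress(w, s, zp)        # -> [[ 1.  9. 14.]]
print(decompress(q, s, zp))   # -> [[-0.84  0.12  0.72]]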
29 changes: 15 additions & 14 deletions nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -155,6 +155,7 @@ def apply(
             original_weight = fns.zeros_like(weight) + weight
 
             compressed_weighs, scale, zp = do_integer_quantization(original_weight, reduction_axis, config)
+            th = 5.0 * fns.max(scale)
             zp = zp.astype(scale.dtype)
 
             q_weights = do_dequantization(compressed_weighs, scale, zp, reduction_axis)
@@ -187,16 +188,16 @@ def apply(
             min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
             ideal_scale_diffs = fns.zeros_like(min_max_scale_diffs)
 
-            k = (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
-            if k in compress_decompress_cashe:
-                compress_decompress_model = compress_decompress_cashe[k]["compress_decompress_model"]
-                compress_model = compress_decompress_cashe[k]["compress_model"]
+            key = (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
+            if key in compress_decompress_cashe:
+                compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
+                compress_model = compress_decompress_cashe[key]["compress_model"]
             else:
                 compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
                     wp, q_weights.shape, scale.shape, zp.shape
                 )
                 compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
-                compress_decompress_cashe[k] = {
+                compress_decompress_cashe[key] = {
                     "compress_decompress_model": compress_decompress_model,
                     "compress_model": compress_model,
                 }
@@ -210,8 +211,8 @@ def apply(
 
                 near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
 
-                out = compress_decompress_model([original_weight.data, near_to_ideal_scale.data, zp.data])
-                q_weights_ = fns.zeros_like(original_weight) + out["q_weights"]
+                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                q_weights_ = fns.zeros_like(original_weight) + out
                 q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
 
                 ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
@@ -233,8 +234,8 @@ def apply(
                     result_scale = near_to_ideal_scale
 
                 if i < self._initial_steps - 1:
-                    out = compress_model([original_weight.data, near_to_ideal_scale.data, zp.data])
-                    compressed_weights = fns.zeros_like(original_weight) + out["compressed_weights"]
+                    out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                    compressed_weights = fns.zeros_like(original_weight) + out
                     target = compressed_weights - zp
                     zero_mask = compressed_weights == zp
                     zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
@@ -243,8 +244,8 @@ def apply(
                 factor = 1.0 - 0.05 * scale_steps
                 scaled_scale = factor * scale
 
-                out = compress_model([original_weight.data, scaled_scale.data, zp.data])
-                compressed_weights = fns.zeros_like(original_weight) + out["compressed_weights"]
+                out = compress_model(original_weight.data, scaled_scale.data, zp.data)
+                compressed_weights = fns.zeros_like(original_weight) + out
 
                 target = compressed_weights - zp
                 zero_mask = compressed_weights == zp
@@ -254,8 +255,8 @@ def apply(
                 weighted_scale = ideal_scale * importance
                 near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
 
-                out = compress_decompress_model([original_weight.data, near_to_ideal_scale.data, zp.data])
-                q_weights_ = fns.zeros_like(original_weight) + out["q_weights"]
+                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                q_weights_ = fns.zeros_like(original_weight) + out
 
                 q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
                 ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
@@ -272,7 +273,7 @@ def apply(
                 else:
                     near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
                 result_scale = near_to_ideal_scale
-
+            result_scale = fns.clip(result_scale, a_min=None, a_max=th)
             wp.precomputed_scale = result_scale
         return model
 
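The scale-estimation change computes an upper bound th from the scales of the initial integer quantization and clips the final result_scale from above before it is stored as wp.precomputed_scale, so the iterative search cannot return outlier scales. A small NumPy sketch of just this guard, with made-up shapes and values standing in for the nncf tensors:

import numpy as np

scale = np.array([[0.02], [0.05], [0.04]], dtype=np.float32)  # initial per-group scales
th = 5.0 * scale.max()  # upper bound: 5x the largest initial scale (here 0.25)

# hypothetical output of the iterative scale search, with one outlier
result_scale = np.array([[0.03], [0.31], [0.045]], dtype=np.float32)

# clip only from above, mirroring fns.clip(result_scale, a_min=None, a_max=th)
result_scale = np.clip(result_scale, None, th)
print(result_scale.ravel())  # -> [0.03  0.25  0.045]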
6 changes: 5 additions & 1 deletion tests/post_training/data/wc_reference_data.yaml
@@ -5,4 +5,8 @@ tinyllama_data_aware_backend_OV:
 tinyllama_data_aware_awq_backend_OV:
   metric_value: 0.81237
 tinyllama_data_aware_awq_stateful_backend_OV:
-  metric_value: 0.81237
+  metric_value: 0.81237
+tinyllama_data_aware_awq_scale_estimation_backend_OV:
+  metric_value: 0.765396
+tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
+  metric_value: 0.765396
6 changes: 3 additions & 3 deletions tests/post_training/pipelines/lm_weight_compression.py
@@ -91,10 +91,10 @@ def prepare_preprocessor(self) -> None:
         self.preprocessor = AutoTokenizer.from_pretrained(self.model_id)
 
     def get_transform_calibration_fn(self):
-        def transform_fn(data):
+        def transform_fn(data, max_tokens=128):
             tokenized_text = self.preprocessor(data["text"], return_tensors="np")
-            input_ids = tokenized_text["input_ids"]
-            attention_mask = tokenized_text["attention_mask"]
+            input_ids = tokenized_text["input_ids"][:, :max_tokens]
+            attention_mask = tokenized_text["attention_mask"][:, :max_tokens]
 
             inputs = {}
             inputs["input_ids"] = input_ids
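The calibration transform now truncates every sample to at most 128 tokens, which bounds the cost of each inference during data-aware compression. Since return_tensors="np" yields arrays of shape (1, seq_len), the slice has to run along the sequence axis ([:, :max_tokens]). A standalone sketch of the same slicing, assuming the transformers library and a hypothetical tokenizer checkpoint:

from transformers import AutoTokenizer

# hypothetical checkpoint for illustration; the pipeline uses self.model_id
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

text = "a long calibration sample " * 200
tokenized = tokenizer(text, return_tensors="np")
print(tokenized["input_ids"].shape)  # (1, N) with N in the hundreds: full sequence

max_tokens = 128
input_ids = tokenized["input_ids"][:, :max_tokens]
attention_mask = tokenized["attention_mask"][:, :max_tokens]
print(input_ids.shape)  # (1, 128): capped for calibration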
