Compress first and last layer more correctly (#2282)
### Changes

- Compress all embeddings to 8 bit.
As a result, a more accurate compression scheme is applied for several models: gpt-2,
stable-diffusion-v1-5, stable-diffusion-v2-1, opt-6.7b.

- More correct search for the last layer when it shares weights with the
embedding (see the sketch after this list).
As a result, a faster compression scheme is available for falcon-7b,
bloomz-7b1, opt-6.7b.
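
Conceptually, the new selection of mixed-precision candidates can be sketched as follows. This is a minimal illustration, not the actual NNCF code: `WeightParams`, `is_embedding`, and `last_layer_shares_embedding` are simplified stand-ins for `WeightNodeParams`, the `OVEmbeddingMetatype` check, and the `is_last_layer_compressed` flag.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WeightParams:
    name: str
    num_weights: int
    is_embedding: bool  # stand-in for "metatype is OVEmbeddingMetatype"


def select_internal_params(
    all_params: List[WeightParams], last_layer_shares_embedding: bool
) -> List[WeightParams]:
    # Embeddings are always kept at 8 bit, so they are excluded from the
    # mixed-precision candidates.
    internal = [p for p in all_params if not p.is_embedding]
    # The last layer also stays at 8 bit, unless its weight was already
    # compressed together with the shared embedding.
    if not last_layer_shares_embedding:
        internal = internal[:-1]
    return internal
```

With this selection, mixed precision only ever considers the internal layers, so neither the token embedding nor the last layer can end up at 4 bit regardless of the graph traversal order.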

### Reason for changes

The token embedding and the last layer should always be compressed to 8 bit
in order to preserve accuracy.
The previous logic for finding these layers relied on the topological sort,
but in practice the order can change.
As a result, for at least 3 models in the mixed-precision setup the positional
embedding was quantized to 8 bit while the token embedding was quantized to
4 bit, which is not expected.

Moreover, the last layer can share its weight with the embedding (sketched
below). The old logic did not handle this case correctly: one extra matmul
was quantized to 8 bit.
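
The shared-weight case can be illustrated with a small, purely hypothetical sketch. The node names and `object()` placeholders are illustrative; the real backend works with `ov.Node` constants and tracks them by `id()`.

```python
shared_weight = object()  # stand-in for a constant used by both nodes

nodes_to_compress = [
    ("token_embedding", shared_weight),   # first node, always 8 bit
    ("transformer_block", object()),      # internal node, mixed-precision candidate
    ("lm_head", shared_weight),           # last node reuses the embedding weight
]

quantized_ids = set()
is_last_layer_compressed = False
for i, (name, weight) in enumerate(nodes_to_compress):
    if id(weight) in quantized_ids:
        # The weight was already compressed via the embedding; skip it, but
        # remember that the last layer is effectively handled already.
        if i == len(nodes_to_compress) - 1:
            is_last_layer_compressed = True
        continue
    quantized_ids.add(id(weight))
    print(f"compress weight of {name}")
```

Tracking constants by identity ensures the shared weight is compressed only once.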

### Related tickets

125162

### Tests

test_shared_gather

The accuracy should be better, while the performance is not significantly
affected:

opt-125m, lambada_openai | ppl | ms/token
-- | -- | --
All embeddings and last layer - 8 bit | 29.12 | 11.47
Positional embedding - 8 bit <br /> Token embedding and last layer - 4 bit | 29.20 | 11.53
Positional embedding - 4 bit <br /> Token embedding and last layer - 8 bit | 29.59 | 11.57
ljaljushkin authored Nov 22, 2023
1 parent 8a45acb commit 5eee3bc
Showing 6 changed files with 267 additions and 339 deletions.
70 changes: 43 additions & 27 deletions nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass
from typing import List, Optional, Tuple, TypeVar

@@ -58,14 +59,18 @@ def do_compression(

friendly_name_to_op_map = {op.get_friendly_name(): op for op in model.get_ops()}

for nncf_node in nodes_to_compress:
is_last_layer_compressed = False
n = len(nodes_to_compress)
for i, nncf_node in enumerate(nodes_to_compress):
weight_port_ids = nncf_node.layer_attributes.get_const_port_ids()
for weight_port_id in weight_port_ids:
weight_op_friendly_name = nncf_node.layer_attributes.constant_attributes[weight_port_id]["name"]
weight_node = friendly_name_to_op_map[weight_op_friendly_name]
if weight_node is None:
continue
if id(weight_node) in quantized_nodes_ids:
if i == n - 1:
is_last_layer_compressed = True
continue
weight_output = weight_node.output(0)

@@ -87,15 +92,24 @@ def do_compression(
fq_name = f"{weight_op_friendly_name}/fq_weights_{weight_port_id}"
num_weights = np.prod(const_shape)
weight_params = WeightNodeParams(
reduction_axis, num_weights, fq_name, weight_node, original_weight_dtype
reduction_axis,
num_weights,
fq_name,
weight_node,
original_weight_dtype,
metatype=nncf_node.metatype,
)
all_weight_params.append(weight_params)
quantized_nodes_ids.add(id(weight_node))

internal_weight_params = all_weight_params
if mode != CompressWeightsMode.INT8:
internal_weight_params = list(filter(lambda wp: wp.metatype != OVEmbeddingMetatype, all_weight_params))
if not is_last_layer_compressed:
internal_weight_params = internal_weight_params[:-1]
primary_config = WeightCompressionConfig(mode=mode, group_size=group_size)
_assign_mixed_precision(all_weight_params, ratio, primary_config)

nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params))
_assign_mixed_precision(internal_weight_params, ratio, primary_config)
nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params, internal_weight_params))

for wp in track(all_weight_params, description="Applying Weight Compression"):
weight_node = wp.weight_node
@@ -180,6 +194,7 @@ class WeightNodeParams:
:param weight_node: The weight node itself.
:param original_weight_dtype: Type of elements in the weight array.
:param compression_config: Configuration of weight compression for the weight node.
:param metatype: Metatype of the corresponding operation with weight.
"""

reduction_axis: int
@@ -188,6 +203,7 @@
weight_node: ov.Node
original_weight_dtype: TWeightType
compression_config = WeightCompressionConfig()
metatype: OperatorMetatype = None


def _do_integer_quantization(
@@ -325,29 +341,31 @@ def _proportion_str(num_weights_list: List[int], total_num_weights: int, total_n
return f"{percentage:.0f}% ({len(num_weights_list)} / {total_num_params})"


def _get_bitwidth_distribution_str(all_weight_params: List[WeightNodeParams]) -> str:
def _get_bitwidth_distribution_str(all_params: List[WeightNodeParams], internal_params: List[WeightNodeParams]) -> str:
"""
Generates a table that shows the ratio of weights quantized to different number of bits.
:param all_weight_params: List of information about each weight node.
:param all_params: List of information about each weight node.
:param internal_params: List of information about weight nodes that are considered for mixed precision.
:return: A string containing the table.
"""
total_num_weights = sum(ws.num_weights for ws in all_weight_params)
num_internal_weights = 0
num_params = len(all_weight_params)
num_internal_params = 0
if num_params > 2:
num_internal_params = num_params - 2
not_internal_params = [wp for wp in all_params if wp not in internal_params]
num_bits_vs_num_weights_map = {}
for i, data in enumerate(all_weight_params):
for data in internal_params:
num_bits = data.compression_config.num_bits
n_internal, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], []))
n_internal.append(data.num_weights)
num_bits_vs_num_weights_map[num_bits] = (n_internal, n_internal)
for data in not_internal_params:
num_bits = data.compression_config.num_bits
n_total, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], []))
if i not in (0, num_params - 1):
n_internal.append(data.num_weights)
num_internal_weights += data.num_weights
n_total.append(data.num_weights)
num_bits_vs_num_weights_map[num_bits] = (n_total, n_internal)

num_internal_weights = sum(ws.num_weights for ws in internal_params)
num_internal_params = len(internal_params)
total_num_weights = num_internal_weights + sum(ws.num_weights for ws in not_internal_params)
num_params = len(all_params)
num_bits_vs_num_weights_map = OrderedDict(sorted(num_bits_vs_num_weights_map.items(), reverse=True))
# Table creation
header = ["Num bits (N)", "% all parameters (layers)", "% internal parameters (layers)"]
rows = []
@@ -366,25 +384,25 @@ def _get_bitwidth_distribution_str(all_weight_params: List[WeightNodeParams]) ->


def _assign_mixed_precision(
all_weight_params: List[WeightNodeParams], ratio: float, primary_config: WeightCompressionConfig
internal_weight_params: List[WeightNodeParams], ratio: float, primary_config: WeightCompressionConfig
) -> None:
"""
Assigns mixed quantization scheme (e.g. uniform int8 or non-uniform nf4) for weights based on some criteria.
:param all_weight_params: List of information about each weight node. The quantization scheme is added to this info.
:param internal_weight_params: List of information about internal weight nodes. Only internal nodes are considered
for mixed precision. The quantization scheme is added to this info.
:param ratio: The ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
and the rest to INT8).
:param primary_config: Information on how to compress (quantize) weights to primary precision.
:return: None.
"""
if ratio == 1:
for weight_param in all_weight_params[1:-1]:
for weight_param in internal_weight_params:
weight_param.compression_config = primary_config
return
errors = []
num_internal_weights = 0
# NOTE: first and last layers are always in 8 bit: no need to calculate error for them
for weight_param in track(all_weight_params[1:-1], description="Searching for Mixed-Precision Configuration"):
for weight_param in track(internal_weight_params, description="Searching for Mixed-Precision Configuration"):
weight = get_const_value(weight_param.weight_node)
backup_config = weight_param.compression_config
reduction_axis = weight_param.reduction_axis
@@ -393,14 +411,12 @@ def _assign_mixed_precision(
error = 1 / (backup_error + eps)
errors.append(error)
num_internal_weights += weight_param.num_weights
# NOTE: index is defined in the array of all weight params by taking into account that errors were not
# calculated for first and last layers.
indexes_of_layers_in_ascending_order_of_errors = [
i[0] + 1 for i in sorted(enumerate(errors), reverse=False, key=lambda x: x[1])
i[0] for i in sorted(enumerate(errors), reverse=False, key=lambda x: x[1])
]
num_weights_in_4bit = 0
for index in indexes_of_layers_in_ascending_order_of_errors:
weight_param = all_weight_params[index]
weight_param = internal_weight_params[index]
current_ratio = (num_weights_in_4bit + weight_param.num_weights) / num_internal_weights
if current_ratio >= ratio:
break
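
For illustration, the ratio-based assignment performed on the internal layers can be reduced to the following toy sketch. The names are hypothetical and the per-layer score is taken as given; the real implementation derives it from the quantization error of the backup 8-bit configuration.

```python
from typing import Dict, List, Tuple


def assign_mixed_precision(
    internal_layers: List[Tuple[str, int, float]],  # (name, num_weights, sensitivity score)
    ratio: float,
) -> Dict[str, str]:
    """Walk layers in ascending order of the score and switch them to the primary
    (e.g. 4-bit) precision until the requested share of internal weights is reached;
    the remaining layers stay in the backup 8-bit precision."""
    total = sum(n for _, n, _ in internal_layers)
    precisions = {name: "int8" for name, _, _ in internal_layers}
    in_primary = 0
    for name, num_weights, _ in sorted(internal_layers, key=lambda x: x[2]):
        if (in_primary + num_weights) / total >= ratio:
            break
        precisions[name] = "int4"
        in_primary += num_weights
    return precisions


# With ratio=0.5, only the lowest-scoring layers covering under 50% of the
# internal weights are switched to 4 bit; the rest remain 8 bit.
print(assign_mixed_precision([("a", 100, 0.1), ("b", 300, 0.2), ("c", 600, 0.9)], 0.5))
```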
@@ -1,99 +1,35 @@
{
"matmul_2_data": {
"compressed_weight": [
[
115,
51,
154,
255,
79,
18,
139
],
[
59,
27,
174,
89,
201,
60,
255
],
[
110,
32,
189,
255,
132,
255,
150
],
[
190,
255,
255,
255,
206,
255,
223
],
[
165,
245,
129,
229,
222,
255,
36
],
[
192,
245,
255,
4,
228,
255,
253
]
],
"zero_point": [
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]
],
"scale": [
[
0.0029188350308686495
[
0.04962019622325897
]
],
[
0.0033386670984327793
[
0.05675733834505081
]
],
[
0.003329785307869315
[
0.05660634860396385
]
],
[
0.0022347758058458567
[
0.03799118846654892
]
],
[
0.003204419743269682
[
0.05447513610124588
]
],
[
0.0037901517935097218
[
0.06443258374929428
]
]
]
},
@@ -190,30 +126,52 @@
]
},
"gather_2_data": {
"compressed_weight": [
[
181,
77,
12,
5,
231,
255
],
[
166,
200,
149,
255,
223,
1
],
[
255,
10,
224,
54,
255,
166
]
],
"zero_point": [
[
0
],
[
0
],
[
0
]
],
"scale": [
[
[
0.039732541888952255
],
[
0.05974852666258812
]
0.0035146193113178015
],
[
[
0.012391435913741589
],
[
0.062155596911907196
]
0.003656211541965604
],
[
[
0.05492125079035759
],
[
0.04583488777279854
]
0.003253307193517685
]
]
}
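
For context on the reference values above: each entry stores the quantized weights together with per-channel (or per-group) `zero_point` and `scale` arrays. A generic asymmetric 8-bit quantization sketch that produces such fields might look like this; it is illustrative only and not the exact NNCF kernel.

```python
import numpy as np


def quantize_8bit_asym(weight: np.ndarray, eps: float = 1e-12):
    """Per-output-channel asymmetric 8-bit quantization (illustration only)."""
    w_min = weight.min(axis=-1, keepdims=True)
    w_max = weight.max(axis=-1, keepdims=True)
    levels = 2**8 - 1
    scale = (w_max - w_min) / levels + eps
    zero_point = np.clip(np.round(-w_min / scale), 0, levels)
    compressed = np.clip(np.round(weight / scale) + zero_point, 0, levels).astype(np.uint8)
    return compressed, zero_point.astype(np.uint8), scale
```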