[TorchFX] Weights Compression Support (#2891)
### Changes

1. Added a weights compression implementation derived from the template weights
compression algorithm.
2. Modified the Torch FX graph builder to handle an edge case where an
embedding node's weight was not placed on the correct port.
3. Updated the Torch weights compression tests to include the FX embedding
metatype, so that some Torch test functions can be reused in the FX tests.

### Reason for changes

To support `nncf.compress_weights()` for Torch FX models.

### Tests

Added tests at `tests/torch/fx/test_compress_weights.py`. Reused the
models and some tests from the Torch implementation, and included some
extra checks, such as verifying that the compressed model is smaller
than the original model.
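
A hedged sketch of the added size check (illustrative, not the exact test code; `fx_model` and `compressed_model` are assumed to be the original and compressed `torch.fx.GraphModule` instances):

```python
import io

import torch


def serialized_size(model: torch.nn.Module) -> int:
    # Serialize the state dict into an in-memory buffer and measure it.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes


# INT8 weights should occupy roughly 4x less space than FP32 weights.
assert serialized_size(compressed_model) < serialized_size(fx_model)
```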

### Performance

tinyllama-1.1b-step-50k-105b inference speeds:

- Torch FX compressed: 0.963s
- Torch FX compiled with OV backend: 0.074s
- Torch FX compiled with OV backend and compressed: 0.04s
- OV FP32: 0.079s
- OV INT8: 0.039s

### Constraints

Currently, weights compression only supports Torch FX representations extracted using `torch._export.capture_pre_autograd_graph()`. #2987 outlines the request to support weights compression for FX models extracted using `torch.export.export`.
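
A minimal usage sketch under these constraints (the toy model, shapes, and variable names are illustrative, not from this PR):

```python
import torch
from torch._export import capture_pre_autograd_graph

import nncf

# Illustrative toy model; any torch.nn.Module can be captured the same way.
model = torch.nn.Sequential(torch.nn.Embedding(100, 64), torch.nn.Linear(64, 64))
example_input = torch.randint(0, 100, (1, 8))

# Capture the FX representation (the only capture path supported by this PR).
fx_model = capture_pre_autograd_graph(model, (example_input,))

# Compress weights in place; returns a torch.fx.GraphModule.
compressed_model = nncf.compress_weights(fx_model)

# Optionally compile with the OpenVINO backend, as in the numbers above
# (assuming the `openvino` package registers the torch.compile backend):
# import openvino.torch
# compiled = torch.compile(compressed_model, backend="openvino")
```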
anzr299 authored Sep 26, 2024
1 parent f274cf5 commit c93676d
Showing 18 changed files with 951 additions and 72 deletions.
@@ -1,6 +1,6 @@
## Weights Compression

[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with, and PyTorch is also supported.
[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with. PyTorch and Torch FX are also supported.

### The algorithm description

@@ -800,7 +800,7 @@ Accuracy/footprint trade-off for `microsoft/Phi-3-mini-4k-instruct`:

### Limitations

- The algorithm is supported for OpenVINO and PyTorch models.
- The algorithm is supported for OpenVINO, PyTorch and Torch FX models.
- The compression applies in-place.
- The compressed model is not trainable.
- INT4_SYM, INT4_ASYM, NF4 and E2M1 modes, grouped quantization and mixed precision selection are available for the OpenVINO backend only.
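
On the Torch FX side this means the INT8 modes are the ones available; a hedged sketch using the public `nncf.compress_weights` API (`fx_model` is a captured `torch.fx.GraphModule`, as in the capture sketch above):

```python
import nncf

# INT8_ASYM / INT8_SYM are the modes supported outside the OpenVINO backend.
compressed = nncf.compress_weights(fx_model, mode=nncf.CompressWeightsMode.INT8_ASYM)
```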
30 changes: 24 additions & 6 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter
from typing import Tuple

import torch.fx
@@ -64,6 +65,22 @@ def _get_layer_attributes(
)
return None

def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
"""
Attempts to retrieve the correct subtype for the given node.
:param node: Given node.
:param metatype: Given node metatype.
:return: Correct FX metatype of the given node if it exists, or the original node metatype otherwise.
"""
if metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
if weight_node.op == "get_attr":
return om.PTAtenEmbeddingMetatype

return metatype

@staticmethod
def _get_node_type_and_metatype(
node: torch.fx.Node, model: torch.fx.GraphModule
@@ -115,16 +132,18 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
:param model: torch fx GraphModule.
:return: NNCFGraph.
"""

nncf_graph = PTNNCFGraph()

const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
)

nncf_graph.add_nncf_node(
node_name=source_node.name,
node_type=node_type,
node_metatype=node_metatype,
node_name=source_node.name, node_type=node_type, node_metatype=node_metatype, is_shared=is_shared_node
)

for source_node in model.graph.nodes:
@@ -134,7 +153,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
model, source_node, source_nncf_node, dist_node, idx
)

nncf_graph.add_edge_between_nncf_nodes(
source_nncf_node.node_id,
dist_node_id,
@@ -160,7 +178,7 @@ def get_edge_params(
:param source_node: Source node in format of torch.fx.Node.
:param source_nncf_node: Source node in format of NNCFNode.
:param dist_node: Destination node in format of torch.fx.Node.
:param output_idx: Output indes of the source_node.
:param output_idx: Output index of the source_node.
:return: Tuple of edge parameters: edge input port id, edge output port id and
edge tensor shape.
"""
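For context on the builder change above: after `capture_pre_autograd_graph`, an `nn.Embedding` lowers to `aten.embedding(weight, indices)`, so the weight constant arrives on input port 0 rather than port 1 as in linear or convolution nodes. A hedged sketch of the pattern (toy model and names are illustrative):

```python
import torch
from torch._export import capture_pre_autograd_graph


class ToyEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 4)

    def forward(self, indices):
        return self.emb(indices)


fx_model = capture_pre_autograd_graph(ToyEmbedding(), (torch.randint(0, 10, (2,)),))
for node in fx_model.graph.nodes:
    if node.op == "call_function" and "embedding" in str(node.target):
        # node.args[0] is the get_attr weight node, node.args[1] the indices;
        # the opposite of linear/conv, where the weight sits on port 1.
        print(node.args)
```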
50 changes: 49 additions & 1 deletion nncf/experimental/torch/fx/quantization/quantize_model.py
@@ -28,11 +28,16 @@
from nncf.data import Dataset
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import SensitivityMetric
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
from nncf.scopes import IgnoredScope

DEFAULT_RANGE_TYPE = "mean_min_max"
@@ -49,7 +54,7 @@ def quantize_impl(
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
) -> torch.nn.Module:
) -> torch.fx.GraphModule:
"""
Implementation of the `quantize()` method for the Torch FX backend.
"""
@@ -103,3 +108,46 @@ def quantize_impl(
quantized_model = _disallow_eval_train(quantized_model)

return quantized_model


def compress_weights_impl(
model: torch.fx.GraphModule,
dataset: Dataset,
mode: CompressWeightsMode,
ratio: float,
group_size: int,
ignored_scope: IgnoredScope,
all_layers: bool,
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> torch.fx.GraphModule:
"""
Implementation of the `compress_weights()` method for the Torch FX backend.
"""

compression_algorithm = WeightCompression(
mode,
ratio,
group_size,
ignored_scope,
all_layers,
sensitivity_metric,
awq,
subset_size,
scale_estimation,
gptq,
lora_correction,
advanced_parameters,
)
shared_constants_unification_transformation(model)
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
compressed_model = GraphModule(compressed_model, compressed_model.graph)
compressed_model = _disallow_eval_train(compressed_model)

return compressed_model
60 changes: 46 additions & 14 deletions nncf/experimental/torch/fx/transformations.py
@@ -28,7 +28,9 @@
TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
def _set_new_node_meta(
new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module, model: torch.fx.GraphModule
):
"""
Sets correct meta \"val\" value to the new node.
@@ -37,7 +39,11 @@ def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target
New node expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = (
prev_node.meta["val"]
if prev_node.op not in ["get_attr"]
else get_tensor_constant_from_node(prev_node, model).data
)
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
@@ -71,16 +77,16 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
_set_new_node_meta(new_node, target_node, module_to_insert, model)
with graph.inserting_after(target_node):
for user in target_node.users:
for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)

else:
prev_node = target_node.args[target_point.input_port_id]
_set_new_node_meta(new_node, prev_node, module_to_insert)
_set_new_node_meta(new_node, prev_node, module_to_insert, model)
target_node.replace_input_with(prev_node, new_node)

return module_insertion_transformation
@@ -136,17 +142,42 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def shared_constants_unification_transformation(model: torch.fx.GraphModule):
"""
Checks the FX graph for shared constants and eliminates redundant
shared constants, keeping only the first instance of each constant node.
This unification transformation is crucial since the current algorithms (min_max, solver, BC, etc.)
for Torch FX do not utilize the is_shared attribute of nodes for shared constants.
:param model: Target Torch FX GraphModule
"""
prev_targets = {}

for source_node in model.graph.nodes:
dist_node = list(source_node.users)
if source_node.target in prev_targets and source_node.op in ("get_attr",):
dist_node[0].replace_input_with(source_node, prev_targets[source_node.target])
else:
prev_targets[source_node.target] = source_node

model.graph.eliminate_dead_code()
model.recompile()
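
A hedged sketch of what this pass does, on a hand-built graph with two `get_attr` nodes sharing one target (illustrative, not from the test suite):

```python
import torch
import torch.fx

from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation


class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))


# Two distinct get_attr nodes point at the same target "w", mimicking
# the duplicated constants that graph capture can produce.
graph = torch.fx.Graph()
x = graph.placeholder("x")
w1 = graph.get_attr("w")
w2 = graph.get_attr("w")
a = graph.call_function(torch.add, (x, w1))
b = graph.call_function(torch.mul, (a, w2))
graph.output(b)
gm = torch.fx.GraphModule(Holder(), graph)

shared_constants_unification_transformation(gm)
# Only one get_attr node for "w" survives; both consumers now read from it.
print([n for n in gm.graph.nodes if n.op == "get_attr"])
```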


def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int = 1
) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.
:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:param input_port_id: Port Id of the constant.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id)

return constant_update_transformation

@@ -161,9 +192,6 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value:
:param input_port_id: Target constant input port id.
"""
graph = model.graph
with graph.inserting_before(node):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args = list(node.args)
# A bias node is supposed to have a constant on the second input port.
if args[input_port_id].op != "get_attr":
@@ -174,11 +202,14 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value:

# Update metadata of the new constant node.
previous_const = args[input_port_id]
new_constant.meta = copy(previous_const.meta)
new_constant.meta["val"] = value
consumer_nodes = list(previous_const.users)
# This list of consumer nodes is always topologically sorted.
# To ensure the updated constant node is ordered correctly,
# we insert it before the earliest consumer in topological order.
with graph.inserting_before(consumer_nodes[0]):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args[input_port_id] = new_constant
node.args = tuple(args)
previous_const.replace_all_uses_with(new_constant, propagate_meta=True)
graph.eliminate_dead_code()
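
A hedged sketch of the new `input_port_id` parameter in use; an embedding weight would sit on port 0, while the matmul constant below sits on port 1 (hand-built toy graph, illustrative names):

```python
import torch
import torch.fx

from nncf.experimental.torch.fx.transformations import constant_update_fn


class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))


graph = torch.fx.Graph()
x = graph.placeholder("x")
w = graph.get_attr("w")
mm = graph.call_function(torch.matmul, (x, w))
graph.output(mm)
gm = torch.fx.GraphModule(Holder(), graph)

# Swap the matmul constant on port 1; for an embedding node the same call
# would use input_port_id=0.
constant_update_fn(gm, mm, torch.zeros(4, 4), input_port_id=1)
```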


@@ -509,6 +540,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)
shared_constants_unification_transformation(model)


def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
6 changes: 5 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -135,7 +135,7 @@

@property
def available_backends(self) -> List[BackendType]:
return [BackendType.OPENVINO, BackendType.TORCH]
return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]

def _set_backend_entity(self, model: TModel) -> None:
"""
@@ -152,6 +152,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend

self._backend_entity = PTWeightCompressionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend

self._backend_entity = FXWeightCompressionAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
@@ -50,7 +50,7 @@ class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend):
TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
}
MATMUL_METATYPES = [om.PTLinearMetatype, om.PTMatMulMetatype, om.PTAddmmMetatype]
EMBEDDING_METATYPES = [om.PTEmbeddingMetatype]
EMBEDDING_METATYPES = [om.PTEmbeddingMetatype, om.PTAtenEmbeddingMetatype]
CONVOLUTION_METATYPES = [
om.PTConv1dMetatype,
om.PTConv2dMetatype,