[PT] Extractor for fused convolution (#2559)
### Changes

- Get layer attributes from `args` and `kwargs` for the `conv`, `batch_norm`,
  `group_norm` and `linear` functions.
- Extract submodules for FBC for custom `conv` and `batch_norm` modules.
- Add the `model_graph_manager` module for working with models wrapped with
  `trace_parameters=True`; see the usage sketch after this list.
  (`model_analizer.py` will be deprecated once the PTQ refactoring to use a
  graph with constant nodes is finished.)
- Move helper functions from
  `nncf/quantization/algorithms/weight_compression/torch_backend.py` to
  `model_graph_manager`.
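
A minimal sketch of how the relocated helpers fit together when reading a weight from a model wrapped with `trace_parameters=True`. It follows the same pattern as `get_weight` in the updated `torch_backend.py` (see the diff below); the `read_weight` wrapper and its arguments are hypothetical names used only for illustration.

```python
import torch

from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name


def read_weight(node_with_weight, weight_port_id: int, model: torch.nn.Module, graph):
    # Resolve the constant node that feeds the given weight input port.
    weight_node = get_const_node(node_with_weight, weight_port_id, graph)
    weight_name = weight_node.layer_attributes.name  # e.g. "conv1.weight"
    # Split "conv1.weight" into the module path and the attribute name.
    module_name, weight_attr_name = split_const_name(weight_name)
    # Walk the model's submodules down to the owning module and read the parameter.
    module = get_module_by_name(module_name, model)
    return getattr(module, weight_attr_name)
```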


### Related tickets

129581

### Tests

- tests/torch/test_model_graph_manager.py
- tests/torch/test_extractor.py
AlexanderDokuchaev authored Mar 21, 2024
1 parent 12f4720 commit 81ff8b4
Showing 13 changed files with 1,338 additions and 180 deletions.
17 changes: 10 additions & 7 deletions nncf/common/graph/layer_attributes.py
@@ -150,21 +150,23 @@ def __init__(
dilations: Tuple[int, ...],
groups: int,
transpose: bool,
padding_values: Tuple[int, ...],
padding_values: Union[Tuple[int, ...], int],
with_bias: bool = False,
output_padding_values: Optional[Union[Tuple[int, ...], int]] = None,
):
"""
:param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor,
False otherwise.
:param in_channels: number of input channels in the layer's input.
:param out_channels: number of channels produced by the layer.
:param kernel_size: size of the convolving kernel.
:param stride: stride of the convolution.
:param groups: number of blocked connections from input channels to output channels.
:param in_channels: Number of input channels in the layer's input.
:param out_channels: Number of channels produced by the layer.
:param kernel_size: Size of the convolving kernel.
:param stride: Stride of the convolution.
:param groups: Number of blocked connections from input channels to output channels.
:param transpose: If set to `True`, the layer is a transposed convolution, otherwise an ordinary one.
:param padding_values: defines the amount of padding applied to the layer's input.
:param padding_values: Defines the amount of padding applied to the layer's input.
:param with_bias: Whether the operation includes a bias.
:param output_padding_values: Defines the amount of padding applied to the layer's output; used only by transposed convolutions.
"""
super().__init__(weight_requires_grad=weight_requires_grad, with_bias=with_bias)
self.in_channels = in_channels
@@ -175,6 +177,7 @@ def __init__(
self.groups = groups
self.transpose = transpose
self.padding_values = padding_values
self.output_padding_values = output_padding_values

def get_weight_shape(self) -> List[int]:
if not self.transpose:
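For context on the new `output_padding_values` field: it mirrors the `output_padding` argument of transposed convolutions in PyTorch. Below is a small sketch of filling the extended attributes for a hypothetical `ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1)`; the concrete values are illustrative and not taken from the commit.

```python
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes

# Illustrative attributes for ConvTranspose2d(16, 32, kernel_size=3, stride=2,
# padding=1, output_padding=1); values are examples, not from the commit.
attrs = ConvolutionLayerAttributes(
    weight_requires_grad=True,
    in_channels=16,
    out_channels=32,
    kernel_size=(3, 3),
    stride=(2, 2),
    dilations=(1, 1),
    groups=1,
    transpose=True,           # transposed convolution
    padding_values=1,         # may now be a plain int as well as a tuple
    with_bias=True,
    output_padding_values=1,  # new optional field introduced by this commit
)
# ConvTranspose2d stores weights as [in, out // groups, *kernel], so [16, 32, 3, 3] is expected here.
print(attrs.get_weight_shape())
```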
12 changes: 11 additions & 1 deletion nncf/common/graph/operator_metatypes.py
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Type
from typing import List, Optional, Set, Type

import nncf
from nncf.common.graph.definitions import NNCFGraphNodeType
@@ -187,3 +187,13 @@ class ConstNoopMetatype(OperatorMetatype):
@classmethod
def get_all_aliases(cls) -> List[str]:
return [NNCFGraphNodeType.CONST_NODE]


def get_all_aliases(*metatypes: OperatorMetatype) -> Set[str]:
"""
Returns a set of all unique aliases from the provided metatypes.
:param *metatypes: A list of operator metatypes.
:return: A set containing all unique aliases for metatypes.
"""
return set(a for m in metatypes for a in m.get_all_aliases())
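
A short usage sketch for the new helper. `ConstNoopMetatype` and `NNCFGraphNodeType` appear in this file; the `InputNoopMetatype` and `OutputNoopMetatype` imports are assumed to live in the same module and are shown only to illustrate the variadic call.

```python
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.operator_metatypes import ConstNoopMetatype
from nncf.common.graph.operator_metatypes import InputNoopMetatype   # assumed to exist alongside ConstNoopMetatype
from nncf.common.graph.operator_metatypes import OutputNoopMetatype  # assumed to exist alongside ConstNoopMetatype
from nncf.common.graph.operator_metatypes import get_all_aliases

# One call returns the union of aliases instead of merging per-class lists by hand.
aliases = get_all_aliases(InputNoopMetatype, OutputNoopMetatype, ConstNoopMetatype)
assert NNCFGraphNodeType.CONST_NODE in aliases
```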
64 changes: 12 additions & 52 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple

import torch

@@ -32,56 +32,16 @@
from nncf.torch.graph import operator_metatypes as om
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import WeightsDecompressor
from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector


def split_weight_name(weight_name: str) -> Tuple[str, str]:
index = weight_name.rfind(".")
if index == -1:
return str(), weight_name
module_name = weight_name[:index]
weight_attr_name = weight_name[index + 1 :]
return module_name, weight_attr_name


def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module:
if not module_name:
return model
curr_module = model
for name in module_name.split("."):
for child_name, child_module in curr_module.named_children():
if child_name == name:
curr_module = child_module
break
else:
raise nncf.ModuleNotFoundError(f"Could not find the {module_name} module in the model.")
return curr_module


def find_weight_node_in_constant_subgraph(node: NNCFNode, graph: NNCFGraph) -> Union[NNCFNode, None]:
if node.metatype == om.PTNoopMetatype:
prev_nodes = graph.get_previous_nodes(node)
if len(prev_nodes) != 1:
return None
return find_weight_node_in_constant_subgraph(prev_nodes[0], graph)
if node.metatype in CONST_NOOP_METATYPES:
return node
return None


def get_weight_node(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> NNCFNode:
for prev_node in graph.get_previous_nodes(node_with_weight):
edge = graph.get_edge(prev_node, node_with_weight)
if edge.input_port_id == weight_port_id:
weight_node = find_weight_node_in_constant_subgraph(prev_node, graph)
if weight_node is None:
raise nncf.InternalError("Could not find a constant node in the model graph.")
return weight_node


class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend):
TARGET_TYPE_TO_PT_INS_TYPE_MAP = {
TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
@@ -125,7 +85,7 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool:
edge = graph.get_edge(prev_node, node)
if edge.input_port_id not in node.metatype.weight_port_ids:
continue
weight_node = find_weight_node_in_constant_subgraph(prev_node, graph)
weight_node = find_const_node_in_constant_subgraph(prev_node, graph)
if weight_node is not None:
return True
return False
@@ -134,7 +94,7 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool:
def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Tuple[str, int]]:
weight_port_ids = []
for prev_node in graph.get_previous_nodes(node):
weight_node = find_weight_node_in_constant_subgraph(prev_node, graph)
weight_node = find_const_node_in_constant_subgraph(prev_node, graph)
if weight_node is None:
continue
edge = graph.get_edge(prev_node, node)
@@ -146,7 +106,7 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Tupl
def get_channel_agnostic_reduction_axes(
node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph
) -> Optional[Tuple[int]]:
weight_node = get_weight_node(node_with_weight, weight_port_id, graph)
weight_node = get_const_node(node_with_weight, weight_port_id, graph)

ndims = len(weight_node.layer_attributes.shape)
reduction_axes = None
@@ -200,9 +160,9 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
def get_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
) -> Tensor:
weight_node = get_weight_node(node_with_weight, weight_port_id, graph)
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_weight_name(weight_name)
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):
@@ -229,9 +189,9 @@ def transform_model(
]:
raise ValueError(f"{compression_config.mode.value} is not supported.")

weight_node = get_weight_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
weight_name = weight_node.layer_attributes.name
module_name, weight_attr_name = split_weight_name(weight_name)
module_name, weight_attr_name = split_const_name(weight_name)
module = get_module_by_name(module_name, model)
weight = getattr(module, weight_attr_name)
if weight is None or not isinstance(weight, torch.nn.Parameter):