Skip to content

Commit

Permalink
[PT] add metatype for torch2 (#3107)
Browse files Browse the repository at this point in the history
### Changes

Add metatype to NNCFGraph
Add "NNCF_EXPERIMENTAL_TORCH_TRACING" env variable to disable patching,
and select experimental tracing in algorithms in feature
Add `get_reference_graph` function to collect metatype in reference
graph

### Related tickets

152996
  • Loading branch information
AlexanderDokuchaev authored Dec 5, 2024
1 parent e1802b7 commit 5189aab
Show file tree
Hide file tree
Showing 20 changed files with 14,907 additions and 14,837 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@


from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import networkx as nx # type: ignore
import torch
from torch import nn

import nncf
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
Expand Down Expand Up @@ -76,6 +77,20 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
return Dtype.INTEGER


def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype:
"""
Converts the node type and metadata into a PTOperatorMetatype object.
:param node_type: The type of the node.
:param meta: The metadata associated with the node.
:return: The PTOperatorMetatype object.
"""
node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type))
node_sub_meta_type: Optional[om.PTOperatorMetatype] = None
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
return node_sub_meta_type or node_metatype


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
"""
Converts a graph to an NNCFGraph.
Expand All @@ -88,12 +103,14 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"]
node_name = get_name_of_node(meta)
node_type = get_node_type(data["type"], meta)
node_metatype = None # TODO(AlexanderDokuchaev): add node_metatype
meta_type = get_meta_type(node_type, meta)

nncf_node = nncf_graph.add_nncf_node(
node_name=get_name_of_node(meta),
node_name=node_name,
node_type=node_type,
node_metatype=node_metatype, # type: ignore[arg-type]
node_metatype=meta_type, # type: ignore[arg-type]
)
map_nx_node_to_nncf_node[node] = nncf_node

Expand Down
4 changes: 3 additions & 1 deletion nncf/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Base subpackage for NNCF PyTorch functionality.
"""

import os
from nncf import nncf_logger
from nncf.common.logging.logger import warn_bkc_version_mismatch

Expand Down Expand Up @@ -76,4 +77,5 @@
if torch.__version__ >= "2.5.0":
from torch._dynamo.polyfills import loader

patch_torch_operators()
if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None:
patch_torch_operators()
7 changes: 7 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ class PTDepthwiseConvOperatorSubtype(PTOperatorSubtype):
def matches(
cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None
) -> bool:
if layer_attributes is None and function_args is not None and functions_kwargs is not None:
# Used for torch2
weight_meta = functions_kwargs.get("weight", function_args[0])
in_channels = weight_meta.shape[1]
groups = functions_kwargs.get("groups", function_args[6] if len(function_args) > 6 else 1)
return in_channels > 1 and groups == in_channels

if _is_called_inside_nncf_module(functions_kwargs):
return False
if not isinstance(layer_attributes, ConvolutionLayerAttributes):
Expand Down
3 changes: 3 additions & 0 deletions tests/torch2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def pytest_configure(config: Config) -> None:
if nncf_debug:
set_log_level(logging.DEBUG)

# Disable patching of torch functions
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"


@pytest.fixture
def regen_ref_data(request: FixtureRequest):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
strict digraph {
"0 x" [id=0, type=nncf_model_input];
"1 conv.weight" [id=1, type=nncf_model_const];
"2 conv.bias" [id=2, type=nncf_model_const];
"3 conv/conv2d/0" [id=3, type=conv2d];
"4 __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" [id=4, type=nncf_model_const];
"5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [id=5, type=add];
"6 /relu/0" [id=6, type=relu];
"7 output" [id=7, type=nncf_model_output];
"0 x" -> "3 conv/conv2d/0" [label="(1, 1, 3, 3)", style=solid];
"1 conv.weight" -> "3 conv/conv2d/0" [label="(1, 1, 1, 1)", style=solid];
"2 conv.bias" -> "3 conv/conv2d/0" [label="(1,)", style=solid];
"3 conv/conv2d/0" -> "5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [label="(1, 1, 3, 3)", style=solid];
"4 __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" -> "5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [label="(1,)", style=solid];
"5 conv/post_hook__conv-conv2d-0__0[0]/add/0" -> "6 /relu/0" [label="(1, 1, 3, 3)", style=solid];
"6 /relu/0" -> "7 output" [label="(1, 1, 3, 3)", style=solid];
x [id=0, metatype=PTInputNoopMetatype, type=nncf_model_input];
"conv.weight" [id=1, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv.bias" [id=2, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv/conv2d/0" [id=3, metatype=PTConv2dMetatype, type=conv2d];
"__nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" [id=4, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv/post_hook__conv-conv2d-0__0[0]/add/0" [id=5, metatype=PTAddMetatype, type=add];
"/relu/0" [id=6, metatype=PTRELUMetatype, type=relu];
output [id=7, metatype=PTOutputNoopMetatype, type=nncf_model_output];
x -> "conv/conv2d/0" [dtype=float, shape="(1, 1, 3, 3)"];
"conv.weight" -> "conv/conv2d/0" [dtype=float, shape="(1, 1, 1, 1)"];
"conv.bias" -> "conv/conv2d/0" [dtype=float, shape="(1,)"];
"conv/conv2d/0" -> "conv/post_hook__conv-conv2d-0__0[0]/add/0" [dtype=float, shape="(1, 1, 3, 3)"];
"__nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" -> "conv/post_hook__conv-conv2d-0__0[0]/add/0" [dtype=float, shape="(1,)"];
"conv/post_hook__conv-conv2d-0__0[0]/add/0" -> "/relu/0" [dtype=float, shape="(1, 1, 3, 3)"];
"/relu/0" -> output [dtype=float, shape="(1, 1, 3, 3)"];
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
strict digraph {
"0 x" [id=0, type=nncf_model_input];
"1 /add/0" [id=1, type=add];
"2 output" [id=2, type=nncf_model_output];
"0 x" -> "1 /add/0" [label="parallel_input_port_ids [1], shape (1, 1)", style=solid];
"1 /add/0" -> "2 output" [label="(1, 1)", style=solid];
x [id=0, metatype=PTInputNoopMetatype, type=nncf_model_input];
"/add/0" [id=1, metatype=PTAddMetatype, type=add];
output [id=2, metatype=PTOutputNoopMetatype, type=nncf_model_output];
x -> "/add/0" [dtype=float, parallel_input_port_ids="[1]", shape="(1, 1)"];
"/add/0" -> output [dtype=float, shape="(1, 1)"];
}
Loading

0 comments on commit 5189aab

Please sign in to comment.