Skip to content

Commit

Permalink
[TorchFX] Export to torch.export.export_for_training (#3075)
Browse files Browse the repository at this point in the history
### Changes

* TorchFX Unit tests are moved from
`torch._export.capture_pre_autograd_graph` to
`torch.export.export_for_training`
ALL REFERENCE GRAPHS WERE VALIDATED MANUALLY 
* BC types for `fuse_bn_node` are updated
* NNCFGraphBuilder is updated to support a batch-norm type with only one
output node (instead of three)
* Model extractor does not traverse down from constants to prevent
redundant nodes in the extracted model when the constant is shared
* `shared_constants_unification_transformation` is removed
* Tests which require `capture_pre_autograd_graph` are removed

### Reason for changes

* To migrate to the latest and recommended export method for the TorchFX
backend

### Related tickets

#2766 

### Tests

test_shared_constants_unification_not_connected_const
post_training_quantization/540/ is finished successfully
  • Loading branch information
daniil-lyakhov authored Nov 14, 2024
1 parent 3d8b0e0 commit 90d15a6
Show file tree
Hide file tree
Showing 78 changed files with 27,846 additions and 28,687 deletions.
3 changes: 3 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def _traverse_graph(
continue

visited.add(in_node.name)
# Any constant is a stop op during the traversing procedure.
if in_node.op == "get_attr":
continue
input_nodes.extend(in_node.all_input_nodes)
input_nodes.extend(list(in_node.users))

Expand Down
4 changes: 3 additions & 1 deletion nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def get_edge_params(
if source_node.op in ("get_attr",):
tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
if source_nncf_node.metatype is om.PTBatchNormMetatype and isinstance(
source_node.meta["val"], (tuple, list)
):
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
tensor = source_node.meta["val"][output_idx]
Expand Down
2 changes: 0 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.experimental.torch.fx.transformations import fq_weights_transformation
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation
from nncf.parameters import BackupMode
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
Expand Down Expand Up @@ -158,7 +157,6 @@ def compress_weights_impl(
backup_mode,
advanced_parameters,
)
shared_constants_unification_transformation(model)
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
compressed_model = GraphModule(compressed_model, compressed_model.graph)
Expand Down
24 changes: 1 addition & 23 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,28 +187,6 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


def shared_constants_unification_transformation(model: torch.fx.GraphModule):
    """
    Checks the FX graph for shared constants and eliminates redundant
    duplicates, keeping only the first instance of each constant node and
    rewiring every consumer of a duplicate to that first instance.

    This unification transformation is crucial since the current algorithms
    (min_max, solver, BC, etc.) for torch fx do not utilize the is_shared
    attribute of nodes for shared constants.

    :param model: Target Torch FX GraphModule
    """
    # Canonical (first seen) get_attr node for each constant target.
    first_instance = {}

    for source_node in model.graph.nodes:
        if source_node.op != "get_attr":
            continue
        canonical = first_instance.get(source_node.target)
        if canonical is None:
            first_instance[source_node.target] = source_node
        else:
            # Rewire ALL users of the duplicate node. The previous code only
            # updated the first user, which left any other consumers attached
            # to the redundant node (keeping it alive through dead-code
            # elimination) and raised IndexError for a user-less duplicate.
            for user in list(source_node.users):
                user.replace_input_with(source_node, canonical)

    # The duplicates now have no users and are dropped here.
    model.graph.eliminate_dead_code()
    model.recompile()


def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int = 1
) -> TransformationFNType:
Expand Down Expand Up @@ -541,6 +519,7 @@ def _is_supported_batch_norm_for_training(node: torch.fx.Node):
Return True if the given node refers to an aten batch norm op QAT supports.
"""
supported_ops = [
torch.ops.aten.batch_norm.default,
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
Expand Down Expand Up @@ -807,7 +786,6 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)
shared_constants_unification_transformation(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
Expand Down
14 changes: 11 additions & 3 deletions tests/torch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def pytest_addoption(parser: Parser):
"reference .dot files will be regenerated "
"using the current state of the repository.",
)
parser.addoption(
"--regen-json",
action="store_true",
default=False,
help="If specified, the "
"reference .json files will be regenerated "
"using the current state of the repository.",
)
parser.addoption(
"--torch-home", type=str, default=None, help="Path to cached test models, downloaded by torchvision"
)
Expand Down Expand Up @@ -116,9 +124,9 @@ def pytest_addoption(parser: Parser):


def pytest_configure(config: "Config"):
    """
    Export an NNCF_TEST_REGEN_* environment variable for every --regen-*
    option passed on the command line (currently dot and json), so test
    helpers can detect reference-regeneration mode.

    :param config: pytest configuration object holding the parsed CLI options.
    """
    # The dot-specific branch that preceded this loop duplicated the loop's
    # first iteration and has been removed; behavior is unchanged.
    for regen_option in ["dot", "json"]:
        if config.getoption(f"--regen-{regen_option}", False):
            os.environ[f"NNCF_TEST_REGEN_{regen_option.upper()}"] = "1"


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"0 _conv_w" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 1, 1)", style=solid];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"0 _conv_w" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
strict digraph {
"0 _param_constant2" [id=0, type=get_attr];
"1 _param_constant3" [id=1, type=get_attr];
"2 conv2d_1_input" [id=2, type=input];
"3 conv2d_1" [id=3, type=conv2d];
"4 _tensor_constant0_1" [id=4, type=get_attr];
"0 conv_b_weight" [id=0, type=get_attr];
"1 conv_b_bias" [id=1, type=get_attr];
"2 bias" [id=2, type=get_attr];
"3 conv2d_1_input" [id=3, type=input];
"4 conv2d_1" [id=4, type=conv2d];
"5 add__1" [id=5, type=add_];
"6 output" [id=6, type=output];
"0 _param_constant2" -> "3 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant3" -> "3 conv2d_1" [label="(3,)", style=solid];
"2 conv2d_1_input" -> "3 conv2d_1" [label=None, style=solid];
"3 conv2d_1" -> "5 add__1" [label="(1, 3, 3, 3)", style=solid];
"4 _tensor_constant0_1" -> "5 add__1" [label="(1,)", style=solid];
"0 conv_b_weight" -> "4 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"1 conv_b_bias" -> "4 conv2d_1" [label="(3,)", style=solid];
"2 bias" -> "5 add__1" [label="(1,)", style=solid];
"3 conv2d_1_input" -> "4 conv2d_1" [label=None, style=solid];
"4 conv2d_1" -> "5 add__1" [label="(1, 3, 3, 3)", style=solid];
"5 add__1" -> "6 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 _tensor_constant0" [id=7, type=get_attr];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 _param_constant4" [id=12, type=get_attr];
"13 _param_constant5" [id=13, type=get_attr];
"14 conv2d_2" [id=14, type=conv2d];
"15 _tensor_constant0_2" [id=15, type=get_attr];
"16 add_1" [id=16, type=add];
"17 output" [id=17, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 _tensor_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" -> "17 output" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "14 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"12 _param_constant4" -> "14 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"13 _param_constant5" -> "14 conv2d_2" [label="(3,)", style=solid];
"14 conv2d_2" -> "16 add_1" [label="(1, 3, 3, 3)", style=solid];
"15 _tensor_constant0_2" -> "16 add_1" [label="(1,)", style=solid];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 conv_c_weight" [id=4, type=get_attr];
"5 conv_c_bias" [id=5, type=get_attr];
"6 bias" [id=6, type=get_attr];
"7 conv2d_input" [id=7, type=input];
"8 conv2d" [id=8, type=conv2d];
"9 conv2d_1" [id=9, type=conv2d];
"10 add_" [id=10, type=add_];
"11 add__1" [id=11, type=add_];
"12 add" [id=12, type=add];
"13 conv2d_2" [id=13, type=conv2d];
"14 add_1" [id=14, type=add];
"15 output" [id=15, type=output];
"0 conv_a_weight" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "8 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "9 conv2d_1" [label="(3,)", style=solid];
"4 conv_c_weight" -> "13 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"5 conv_c_bias" -> "13 conv2d_2" [label="(3,)", style=solid];
"6 bias" -> "10 add_" [label="(1,)", style=solid];
"6 bias" -> "11 add__1" [label="(1,)", style=solid];
"6 bias" -> "14 add_1" [label="(1,)", style=solid];
"7 conv2d_input" -> "8 conv2d" [label=None, style=solid];
"8 conv2d" -> "9 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"8 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid];
"9 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"10 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"11 add__1" -> "15 output" [label="(1, 3, 3, 3)", style=solid];
"12 add" -> "13 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"13 conv2d_2" -> "14 add_1" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 _tensor_constant0" [id=7, type=get_attr];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 bias" [id=4, type=get_attr];
"5 conv2d_input" [id=5, type=input];
"6 conv2d" [id=6, type=conv2d];
"7 conv2d_1" [id=7, type=conv2d];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 output" [id=12, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 _tensor_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"9 add__1" [id=9, type=add_];
"10 add" [id=10, type=add];
"11 output" [id=11, type=output];
"0 conv_a_weight" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "6 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "7 conv2d_1" [label="(3,)", style=solid];
"4 bias" -> "8 add_" [label="(1,)", style=solid];
"4 bias" -> "9 add__1" [label="(1,)", style=solid];
"5 conv2d_input" -> "6 conv2d" [label=None, style=solid];
"6 conv2d" -> "7 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"6 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"7 conv2d_1" -> "9 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "10 add" [label="(1, 3, 3, 3)", style=solid];
"8 add_" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
"9 add__1" -> "10 add" [label="(1, 3, 3, 3)", style=solid];
"10 add" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 _param_constant1" [id=1, type=get_attr];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1_input" [id=6, type=input];
"7 conv2d_1" [id=7, type=conv2d];
"8 _tensor_constant0" [id=8, type=get_attr];
"0 conv_a_weight" [id=0, type=get_attr];
"1 conv_a_bias" [id=1, type=get_attr];
"2 conv_b_weight" [id=2, type=get_attr];
"3 conv_b_bias" [id=3, type=get_attr];
"4 bias" [id=4, type=get_attr];
"5 conv2d_input" [id=5, type=input];
"6 conv2d" [id=6, type=conv2d];
"7 conv2d_1_input" [id=7, type=input];
"8 conv2d_1" [id=8, type=conv2d];
"9 add_" [id=9, type=add_];
"10 _tensor_constant0_1" [id=10, type=get_attr];
"11 add__1" [id=11, type=add_];
"12 output" [id=12, type=output];
"0 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "7 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1_input" -> "7 conv2d_1" [label=None, style=solid];
"7 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 _tensor_constant0" -> "9 add_" [label="(1,)", style=solid];
"9 add_" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"10 _tensor_constant0_1" -> "11 add__1" [label="(1,)", style=solid];
"11 add__1" -> "12 output" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" [id=10, type=add_];
"11 output" [id=11, type=output];
"0 conv_a_weight" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"1 conv_a_bias" -> "6 conv2d" [label="(3,)", style=solid];
"2 conv_b_weight" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"3 conv_b_bias" -> "8 conv2d_1" [label="(3,)", style=solid];
"4 bias" -> "9 add_" [label="(1,)", style=solid];
"4 bias" -> "10 add__1" [label="(1,)", style=solid];
"5 conv2d_input" -> "6 conv2d" [label=None, style=solid];
"6 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"7 conv2d_1_input" -> "8 conv2d_1" [label=None, style=solid];
"8 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"9 add_" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
"10 add__1" -> "11 output" [label="(1, 3, 3, 3)", style=solid];
}
Loading

0 comments on commit 90d15a6

Please sign in to comment.