# [TorchFX][MicroFix] Folded constants do not require grad (#3128)
### Changes

Constant folding is now performed under `torch.no_grad()`, so folded constants do not require gradients.

### Reason for changes

* To unify all model constants and buffers
* To make the compressed model deepcopy-able (see the sketch after this list)
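
A minimal sketch of the deepcopy motivation, using plain PyTorch only (the tensor names here are illustrative, not taken from the NNCF code): a value computed from a `requires_grad` parameter outside `torch.no_grad()` is a non-leaf tensor, and `copy.deepcopy` refuses such tensors, while the same value computed under `torch.no_grad()` is an ordinary grad-free constant.

```python
import copy

import torch

weight = torch.ones(3, requires_grad=True)  # stands in for a model parameter

# Folded outside no_grad(): the result is a non-leaf tensor that still
# requires grad, and deepcopy of non-leaf tensors raises a RuntimeError.
folded = weight * 2
assert folded.requires_grad
try:
    copy.deepcopy(folded)
except RuntimeError:
    pass  # "Only Tensors created explicitly by the user ... support the deepcopy protocol"

# Folded under no_grad(): the result is a plain grad-free constant, so a
# module holding it (and hence the whole compressed model) stays deepcopy-able.
with torch.no_grad():
    folded = weight * 2
assert not folded.requires_grad
copy.deepcopy(folded)  # fine
```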

### Related tickets

#2766 

### Tests

`test_constant_folding` is updated to check that the folded constant does not require gradients.
daniil-lyakhov authored Dec 4, 2024
1 parent 1e6891d commit c1cb354
Showing 2 changed files with 28 additions and 23 deletions.
**`nncf/experimental/torch/fx/constant_folding.py`** (24 additions, 23 deletions)
```diff
@@ -246,27 +246,28 @@ def constant_fold(
     :param constraint_fn: Constraint function which takes a node and returns the constraint:
       should the node be constant folded or not.
     """
-    with torch.utils._python_dispatch._disable_current_modes():
-        cf = ConstantFolder(gm)
-        cf.run()
+    with torch.no_grad():
+        with torch.utils._python_dispatch._disable_current_modes():
+            cf = ConstantFolder(gm)
+            cf.run()
 
-        for node, constant in cf.node_replacements.items():
-            if constraint_fn is not None and not constraint_fn(node):
-                continue
-            _replace_node_with_constant(gm, node, constant)
-
-        erased_params = []
-        for node in gm.graph.find_nodes(op="get_attr"):
-            if len(node.users) == 0:
-                if hasattr(gm, node.target):
-                    delattr(gm, node.target)
-                erased_params.append(node)
-
-        for node in erased_params:
-            gm.graph.erase_node(node)
-
-        # The custom _is_impure function allows eliminating all layers with zero
-        # users, including inplace ops like relu_, besides outputs and placeholders.
-        gm.graph.eliminate_dead_code(_is_impure)
-        gm.graph.lint()
-        gm.recompile()
+            for node, constant in cf.node_replacements.items():
+                if constraint_fn is not None and not constraint_fn(node):
+                    continue
+                _replace_node_with_constant(gm, node, constant)
+
+            erased_params = []
+            for node in gm.graph.find_nodes(op="get_attr"):
+                if len(node.users) == 0:
+                    if hasattr(gm, node.target):
+                        delattr(gm, node.target)
+                    erased_params.append(node)
+
+            for node in erased_params:
+                gm.graph.erase_node(node)
+
+            # The custom _is_impure function allows eliminating all layers with zero
+            # users, including inplace ops like relu_, besides outputs and placeholders.
+            gm.graph.eliminate_dead_code(_is_impure)
+            gm.graph.lint()
+            gm.recompile()
```
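
The custom `_is_impure` predicate referenced in the comment above lies outside this hunk. A hypothetical sketch of such a predicate, inferred from the comment rather than taken from the NNCF sources: only placeholders and outputs are treated as impure, so any zero-user node, including inplace calls such as `relu_` that `eliminate_dead_code` would normally keep, becomes eligible for removal.

```python
import torch.fx


def _is_impure(node: torch.fx.Node) -> bool:
    # Hypothetical sketch: keep only graph inputs and outputs unconditionally;
    # everything else with zero users -- even side-effecting inplace ops such
    # as relu_ -- may be erased by gm.graph.eliminate_dead_code(_is_impure).
    return node.op in ("placeholder", "output")
```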
**`tests/torch/fx/test_model_transformer.py`** (4 additions, 0 deletions)
```diff
@@ -537,6 +537,10 @@ def test_constant_folding():
     captured_model = get_torch_fx_model(model, ex_inputs)
     folded_model = deepcopy(captured_model)
     constant_fold(folded_model)
+
+    # Check that the folded constant does not require gradients
+    assert not folded_model._frozen_param0.requires_grad
+
     assert torch.allclose(captured_model(*ex_inputs), folded_model(*ex_inputs))
 
     nncf_graph = GraphConverter.create_nncf_graph(folded_model)
```
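
As a direct check of the deepcopy motivation, a sketch assuming the variables from the test above (`_frozen_param0` is the attribute name the test itself checks): before this change the folded constants were non-leaf tensors with `requires_grad=True`, which made deepcopy of the folded model fail.

```python
import copy

# Deepcopy of the folded model should now succeed, since folded constants
# are grad-free leaves rather than tensors on the autograd graph.
model_copy = copy.deepcopy(folded_model)
assert not model_copy._frozen_param0.requires_grad
```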
