Commit

test_constant_folding
daniil-lyakhov committed Oct 21, 2024
1 parent 2539929 commit 7a8ab33
Showing 4 changed files with 58 additions and 8 deletions.
16 changes: 8 additions & 8 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -18,7 +18,7 @@
aten = torch.ops.aten


def replace_node_with_constant(
def _replace_node_with_constant(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    constant: torch.Tensor,
@@ -52,13 +52,13 @@ def replace_node_with_constant(
    setattr(gm, qualname, constant)


def is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
def _is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
    return node.op == "get_attr" or (
        node.op == "placeholder" and lifted_constants is not None and node.name in lifted_constants
    )


class ConstantFolder(torch.fx.Interpreter):
class _ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm: torch.fx.GraphModule,
@@ -101,7 +101,7 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
                and len(node.users) == 1
                and is_woq_int8_pattern(next(iter(node.users)))
            )
        ) and is_const_source(
        ) and _is_const_source(
            node.args[0], self.lifted_constants  # type: ignore[arg-type]
        ):
            # Case 1: int8_weight -> dq -> bf16_weight
@@ -180,7 +180,7 @@ def set_env(arg: torch.fx.Node) -> None:
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and not is_const_source(node, self.lifted_constants)
            and not _is_const_source(node, self.lifted_constants)
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value
@@ -193,7 +193,7 @@ def set_env(arg: torch.fx.Node) -> None:
        if out == self.unknown_value:
            return self.unknown_value

        if not is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
        if not _is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

@@ -247,11 +247,11 @@ def constant_fold(
    :param gm: Given graph model.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf = _ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            replace_node_with_constant(gm, node, constant)
            _replace_node_with_constant(gm, node, constant)

        erased_params = []
        for node in gm.graph.find_nodes(op="get_attr"):
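Note: constant_fold stays the public entry point of this module, while the renamed _replace_node_with_constant, _is_const_source and _ConstantFolder become private helpers. A minimal usage sketch (illustration only, not part of this commit; it assumes a GraphModule captured with plain torch.export, whereas the tests below go through the existing _capture_model helper):

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


class _ToyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.weight = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        # `self.weight * 2 + 1` depends only on parameters, so it should fold
        return self.linear(x) + (self.weight * 2 + 1)


example_input = torch.ones(1, 3)
gm = torch.export.export(_ToyModel().eval(), (example_input,)).module()
constant_fold(gm)  # mutates `gm` in place, replacing constant subgraphs with get_attr nodes
print(gm.graph)  # the folded branch shows up as a single frozen constant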
15 changes: 15 additions & 0 deletions tests/torch/data/reference_graphs/fx/transformed/folded_model.dot
@@ -0,0 +1,15 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _frozen_param0" [id=1, type=get_attr];
"2 _param_constant3" [id=2, type=get_attr];
"3 _param_constant4" [id=3, type=get_attr];
"4 linear_1" [id=4, type=linear];
"5 add" [id=5, type=add];
"6 output" [id=6, type=output];
"0 arg0_1" -> "4 linear_1" [label="(1, 3, 3, 3)", style=solid];
"1 _frozen_param0" -> "5 add" [label="(3, 3)", style=solid];
"2 _param_constant3" -> "4 linear_1" [label="(3, 3)", style=solid];
"3 _param_constant4" -> "4 linear_1" [label="(3,)", style=solid];
"4 linear_1" -> "5 add" [label="(1, 3, 3, 3)", style=solid];
"5 add" -> "6 output" [label="(1, 3, 3, 3)", style=solid];
}
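For orientation: _frozen_param0 in this reference graph appears to be the constant produced by folding the input-independent branch of ConstantFoldingTestModel (linear_w applied to the 4 * ones parameter, plus 10), while linear_1 with _param_constant3/_param_constant4 corresponds to linear_act and stays in the graph because it consumes the runtime input arg0_1.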
15 changes: 15 additions & 0 deletions tests/torch/fx/test_model_transformer.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Tuple
@@ -29,6 +30,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
@@ -48,6 +50,7 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from tests.torch.fx.test_sanity import count_q_dq
from tests.torch.test_compressed_graph import check_graph
from tests.torch.test_models.synthetic import ConstantFoldingTestModel
from tests.torch.test_models.synthetic import ConvolutionWithAllConstantInputsModel
from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel
from tests.torch.test_models.synthetic import MultiBranchesConnectedModel
@@ -478,3 +481,15 @@ def test_update_shared_constant():
    fx_node_to_check_const_value = get_tensor_constant_from_node(fx_node_to_check_const, captured_model)

    assert fx_node_to_check_const_value == torch.tensor([100])


def test_constant_folding():
    model = ConstantFoldingTestModel()
    captured_model = _capture_model(model, torch.ones(model.INPUT_SIZE))
    folded_model = deepcopy(captured_model)
    constant_fold(folded_model)
    ex_input = torch.ones(model.INPUT_SIZE)
    assert torch.allclose(captured_model(ex_input), folded_model(ex_input))

    nncf_graph = GraphConverter.create_nncf_graph(folded_model)
    check_graph(nncf_graph, "folded_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True)
20 changes: 20 additions & 0 deletions tests/torch/test_models/synthetic.py
@@ -572,3 +572,23 @@ def forward(self, x):
        x = self.relu(self.conv2(x))
        x = self.bn2(x)
        return x


class ConstantFoldingTestModel(nn.Module):
    INPUT_SIZE = (1, 3, 3, 3)

    def __init__(self):
        super().__init__()
        self.linear_act = nn.Linear(3, 3)
        self.linear_act.weight.data = 2 * torch.ones((3, 3))

        self.linear_w = nn.Linear(3, 3)
        self.linear_w.weight.data = 3 * torch.ones((3, 3))

        self.param = nn.Parameter(4 * torch.ones((3, 3)))

    def forward(self, x):
        y = self.linear_w(self.param)
        y += 10
        x = self.linear_act(x)
        return x + y
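A quick way to see why only the y branch folds (a sketch, not part of this commit, meant to run next to the class above; linear_w.bias keeps its default random initialization, so the reference value is computed from the module's own parameters rather than hard-coded):

model = ConstantFoldingTestModel()
ex_input = torch.ones(model.INPUT_SIZE)

# `y` depends only on parameters, so constant folding can precompute it into
# a single frozen constant (`_frozen_param0` in the reference graph above).
folded_y = model.linear_w(model.param) + 10

# `x` flows through linear_act and therefore has to stay in the folded graph.
expected = model.linear_act(ex_input) + folded_y
assert torch.allclose(model(ex_input), expected)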
