Commit d8b8a60

Constant folding empty input test case
daniil-lyakhov committed Nov 25, 2024
1 parent b7c9e2e commit d8b8a60
Showing 6 changed files with 65 additions and 52 deletions.
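
The commit title is the only description, so a short note on what it covers: the constant-folding test model gains a second, deliberately unused ("empty") input, the FX export helper learns to accept tuples of example inputs, and the reference graphs are regenerated with the resulting `dummy_disconnected_input` placeholder. A minimal sketch of the scenario under test (hypothetical module, not the commit's code; assumes a PyTorch version that provides `torch.export.export_for_training`):

```python
import torch


class TwoInputModel(torch.nn.Module):
    """Hypothetical stand-in for ConstantFoldingTestModel."""

    def forward(self, x, dummy_disconnected_input):
        # The second argument is intentionally never used.
        return x + 1


gm = torch.export.export_for_training(
    TwoInputModel(), args=(torch.ones(1, 3, 3, 3), torch.ones(1))
).module()

# The unused argument survives export as a placeholder with no users,
# i.e. exactly the disconnected node that constant folding must tolerate.
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
assert any(len(n.users) == 0 for n in placeholders)
```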
21 changes: 11 additions & 10 deletions tests/torch/data/reference_graphs/fx/transformed/folded_model.dot
@@ -2,14 +2,15 @@ strict digraph {
"0 linear_act_weight" [id=0, type=get_attr];
"1 linear_act_bias" [id=1, type=get_attr];
"2 x" [id=2, type=input];
"3 _frozen_param0" [id=3, type=get_attr];
"4 linear_1" [id=4, type=linear];
"5 add" [id=5, type=add];
"6 output" [id=6, type=output];
"0 linear_act_weight" -> "4 linear_1" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "4 linear_1" [label="(3,)", style=solid];
"2 x" -> "4 linear_1" [label="(1, 3, 3, 3)", style=solid];
"3 _frozen_param0" -> "5 add" [label="(3, 3)", style=solid];
"4 linear_1" -> "5 add" [label="(1, 3, 3, 3)", style=solid];
"5 add" -> "6 output" [label="(1, 3, 3, 3)", style=solid];
"3 dummy_disconnected_input" [id=3, type=input];
"4 _frozen_param0" [id=4, type=get_attr];
"5 linear_1" [id=5, type=linear];
"6 add" [id=6, type=add];
"7 output" [id=7, type=output];
"0 linear_act_weight" -> "5 linear_1" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "5 linear_1" [label="(3,)", style=solid];
"2 x" -> "5 linear_1" [label="(1, 3, 3, 3)", style=solid];
"4 _frozen_param0" -> "6 add" [label="(3, 3)", style=solid];
"5 linear_1" -> "6 add" [label="(1, 3, 3, 3)", style=solid];
"6 add" -> "7 output" [label="(1, 3, 3, 3)", style=solid];
}
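
The only structural change in this reference graph is the new `3 dummy_disconnected_input` node: an input that neither feeds nor is fed by any other node. A hedged sketch of how one could verify that property straight from the .dot file (assumes the `pydot` package; node-name quoting varies between pydot versions, hence the `strip`):

```python
import pydot

(graph,) = pydot.graph_from_dot_file(
    "tests/torch/data/reference_graphs/fx/transformed/folded_model.dot"
)


def norm(name: str) -> str:
    # pydot may or may not keep the surrounding quotes in node names
    return name.strip('"')


touching = [
    e
    for e in graph.get_edges()
    if "3 dummy_disconnected_input" in (norm(e.get_source()), norm(e.get_destination()))
]
assert not touching  # the dummy input has no incident edges
```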

@@ -2,24 +2,25 @@ strict digraph {
"0 linear_act_weight" [id=0, type=get_attr];
"1 linear_act_bias" [id=1, type=get_attr];
"2 x" [id=2, type=input];
"3 _frozen_param0" [id=3, type=get_attr];
"4 scale_node0" [id=4, type=get_attr];
"5 weight_node0" [id=5, type=get_attr];
"6 quantize_per_channel_default" [id=6, type=quantize_per_channel];
"7 dequantize_per_channel_default" [id=7, type=dequantize_per_channel];
"8 linear_1" [id=8, type=linear];
"9 add" [id=9, type=add];
"10 output_1" [id=10, type=output];
"0 linear_act_weight" -> "6 quantize_per_channel_default" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "8 linear_1" [label="(3,)", style=solid];
"2 x" -> "8 linear_1" [label="(1, 3, 3, 3)", style=solid];
"3 _frozen_param0" -> "9 add" [label="(3, 3)", style=solid];
"4 scale_node0" -> "6 quantize_per_channel_default" [label="(3,)", style=solid];
"4 scale_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"5 weight_node0" -> "6 quantize_per_channel_default" [label="(3,)", style=solid];
"5 weight_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"6 quantize_per_channel_default" -> "7 dequantize_per_channel_default" [label="(3, 3)", style=solid];
"7 dequantize_per_channel_default" -> "8 linear_1" [label="(3, 3)", style=solid];
"8 linear_1" -> "9 add" [label="(1, 3, 3, 3)", style=solid];
"9 add" -> "10 output_1" [label="(1, 3, 3, 3)", style=solid];
"3 dummy_disconnected_input" [id=3, type=input];
"4 _frozen_param0" [id=4, type=get_attr];
"5 scale_node0" [id=5, type=get_attr];
"6 weight_node0" [id=6, type=get_attr];
"7 quantize_per_channel_default" [id=7, type=quantize_per_channel];
"8 dequantize_per_channel_default" [id=8, type=dequantize_per_channel];
"9 linear_1" [id=9, type=linear];
"10 add" [id=10, type=add];
"11 output_1" [id=11, type=output];
"0 linear_act_weight" -> "7 quantize_per_channel_default" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "9 linear_1" [label="(3,)", style=solid];
"2 x" -> "9 linear_1" [label="(1, 3, 3, 3)", style=solid];
"4 _frozen_param0" -> "10 add" [label="(3, 3)", style=solid];
"5 scale_node0" -> "7 quantize_per_channel_default" [label="(3,)", style=solid];
"5 scale_node0" -> "8 dequantize_per_channel_default" [label="(3,)", style=solid];
"6 weight_node0" -> "7 quantize_per_channel_default" [label="(3,)", style=solid];
"6 weight_node0" -> "8 dequantize_per_channel_default" [label="(3,)", style=solid];
"7 quantize_per_channel_default" -> "8 dequantize_per_channel_default" [label="(3, 3)", style=solid];
"8 dequantize_per_channel_default" -> "9 linear_1" [label="(3, 3)", style=solid];
"9 linear_1" -> "10 add" [label="(1, 3, 3, 3)", style=solid];
"10 add" -> "11 output_1" [label="(1, 3, 3, 3)", style=solid];
}
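
This per-channel variant gets the same new input node; the rest of the graph is the familiar weight-only quantize/dequantize pair in front of `linear_1`. Numerically, that pair is a fake-quantize round trip, sketched here with PyTorch's eager-mode ops as a stand-in for the decomposed `quantize_per_channel`/`dequantize_per_channel` FX nodes (not the commit's code):

```python
import torch

weight = torch.randn(3, 3)                 # matches the (3, 3) weight above
scales = weight.abs().amax(dim=1) / 127.0  # one scale per output channel
zero_points = torch.zeros(3, dtype=torch.int64)

q = torch.quantize_per_channel(weight, scales, zero_points, axis=0, dtype=torch.qint8)
fake_quant_weight = q.dequantize()         # this is what linear_1 consumes

# The round trip is lossy, but the error is bounded by the per-channel scale.
assert torch.allclose(weight, fake_quant_weight, atol=scales.max().item())
```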

@@ -2,18 +2,19 @@ strict digraph {
"0 linear_act_weight" [id=0, type=get_attr];
"1 linear_act_bias" [id=1, type=get_attr];
"2 x" [id=2, type=input];
"3 _frozen_param0" [id=3, type=get_attr];
"4 quantize_per_tensor_default" [id=4, type=quantize_per_tensor];
"5 dequantize_per_tensor_default" [id=5, type=dequantize_per_tensor];
"6 linear_1" [id=6, type=linear];
"7 add" [id=7, type=add];
"8 output_1" [id=8, type=output];
"0 linear_act_weight" -> "4 quantize_per_tensor_default" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "6 linear_1" [label="(3,)", style=solid];
"2 x" -> "6 linear_1" [label="(1, 3, 3, 3)", style=solid];
"3 _frozen_param0" -> "7 add" [label="(3, 3)", style=solid];
"4 quantize_per_tensor_default" -> "5 dequantize_per_tensor_default" [label="(3, 3)", style=solid];
"5 dequantize_per_tensor_default" -> "6 linear_1" [label="(3, 3)", style=solid];
"6 linear_1" -> "7 add" [label="(1, 3, 3, 3)", style=solid];
"7 add" -> "8 output_1" [label="(1, 3, 3, 3)", style=solid];
"3 dummy_disconnected_input" [id=3, type=input];
"4 _frozen_param0" [id=4, type=get_attr];
"5 quantize_per_tensor_default" [id=5, type=quantize_per_tensor];
"6 dequantize_per_tensor_default" [id=6, type=dequantize_per_tensor];
"7 linear_1" [id=7, type=linear];
"8 add" [id=8, type=add];
"9 output_1" [id=9, type=output];
"0 linear_act_weight" -> "5 quantize_per_tensor_default" [label="(3, 3)", style=solid];
"1 linear_act_bias" -> "7 linear_1" [label="(3,)", style=solid];
"2 x" -> "7 linear_1" [label="(1, 3, 3, 3)", style=solid];
"4 _frozen_param0" -> "8 add" [label="(3, 3)", style=solid];
"5 quantize_per_tensor_default" -> "6 dequantize_per_tensor_default" [label="(3, 3)", style=solid];
"6 dequantize_per_tensor_default" -> "7 linear_1" [label="(3, 3)", style=solid];
"7 linear_1" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"8 add" -> "9 output_1" [label="(1, 3, 3, 3)", style=solid];
}
15 changes: 12 additions & 3 deletions tests/torch/fx/helpers.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from pathlib import Path
+from typing import Tuple, Union
 
 import torch.fx
 import torch.nn.parallel
@@ -122,7 +123,9 @@ def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
     g.get_dot_graph().write_svg(output_svg_path)
 
 
-def get_torch_fx_model(model: torch.nn.Module, ex_input: torch.Tensor) -> torch.fx.GraphModule:
+def get_torch_fx_model(
+    model: torch.nn.Module, ex_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
+) -> torch.fx.GraphModule:
     """
     Converts given module to GraphModule.
@@ -138,11 +141,17 @@ def get_torch_fx_model(model: torch.nn.Module, ex_input: torch.Tensor) -> torch.fx.GraphModule:
     else:
         device = named_param[1].device
 
-    ex_input = ex_input.to(device)
+    if isinstance(ex_input, torch.Tensor):
+        ex_input = (ex_input,)
+    device_ex_input = []
+    for inp in ex_input:
+        device_ex_input.append(inp.to(device))
+    device_ex_input = tuple(device_ex_input)
 
     model.eval()
     with torch.no_grad():
         with disable_patching():
-            return torch.export.export_for_training(model, args=(ex_input,)).module()
+            return torch.export.export_for_training(model, args=device_ex_input).module()
 
 
 def get_torch_fx_model_q_transformed(model: torch.nn.Module, ex_input: torch.Tensor) -> torch.fx.GraphModule:
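
A usage sketch for the updated helper (hypothetical calls, not from the commit): a bare tensor still works as before, and a tuple now covers multi-input models such as the updated test model; the shapes follow the reference graphs above.

```python
import torch


class OneInput(torch.nn.Module):  # hypothetical single-input module
    def forward(self, x):
        return x * 2


gm_single = get_torch_fx_model(OneInput(), torch.ones(1, 3))  # tensor, wrapped to a 1-tuple
gm_multi = get_torch_fx_model(
    ConstantFoldingTestModel(), (torch.ones(1, 3, 3, 3), torch.ones(1))
)  # tuple, passed straight through to torch.export
```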
9 changes: 5 additions & 4 deletions tests/torch/fx/test_model_transformer.py
@@ -533,19 +533,20 @@ def test_compress_post_quantize_transformation(is_per_channel: bool):

 def test_constant_folding():
     model = ConstantFoldingTestModel()
-    captured_model = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE))
+    ex_inputs = (torch.ones(model.INPUT_SIZE), torch.ones((1,)))
+    captured_model = get_torch_fx_model(model, ex_inputs)
     folded_model = deepcopy(captured_model)
     constant_fold(folded_model)
-    ex_input = torch.ones(model.INPUT_SIZE)
-    assert torch.allclose(captured_model(ex_input), folded_model(ex_input))
+    assert torch.allclose(captured_model(*ex_inputs), folded_model(*ex_inputs))
 
     nncf_graph = GraphConverter.create_nncf_graph(folded_model)
     check_graph(nncf_graph, "folded_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True)
 
 
 def test_constant_folding_with_constraints(is_per_channel):
     model = ConstantFoldingTestModel()
-    model_with_correct_pattern = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE))
+    ex_inputs = (torch.ones(model.INPUT_SIZE), torch.ones((1,)))
+    model_with_correct_pattern = get_torch_fx_model(model, ex_inputs)
 
     insert_qdq_nodes(
         model_with_correct_pattern,
2 changes: 1 addition & 1 deletion tests/torch/test_models/synthetic.py
@@ -617,7 +617,7 @@ def __init__(self):

         self.param = nn.Parameter(4 * torch.ones((3, 3)))
 
-    def forward(self, x):
+    def forward(self, x, dummy_disconnected_input):
         y = self.linear_w(self.param)
         # Inplace relu to check
         # that inplace operations are
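
Putting the pieces together: a hypothetical minimal analogue of `ConstantFoldingTestModel`, reconstructed only from the lines visible in this diff and the reference graphs above (the real model may differ):

```python
import torch
from torch import nn


class FoldingAnalogue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_w = nn.Linear(3, 3)    # consumes only the constant parameter
        self.linear_act = nn.Linear(3, 3)  # weight/bias survive folding in the graphs
        self.param = nn.Parameter(4 * torch.ones((3, 3)))

    def forward(self, x, dummy_disconnected_input):
        y = self.linear_w(self.param)      # constant subgraph, folds to _frozen_param0
        y = torch.relu_(y)                 # inplace op, folded away with it
        return self.linear_act(x) + y      # the "add" node in the folded graphs
```

After export and `constant_fold`, `linear_w` and `param` collapse into the single `_frozen_param0` get_attr node, while `dummy_disconnected_input` remains as an input connected to nothing, which is exactly the shape of the updated reference graphs.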