From fcfe0056f9cc4de5149aacc8e4d2d706ce99678e Mon Sep 17 00:00:00 2001 From: heheda Date: Fri, 22 Mar 2024 14:34:24 +0800 Subject: [PATCH 1/3] support branch to onnx --- .gitignore | 3 +- frontend/config.py | 1 + frontend/control_flow.py | 3 -- frontend/fx_graph.py | 72 +++++++++++++++++++++++---------------- frontend/guard_tracker.py | 15 ++++++-- test/test_model_lstm.py | 2 +- test/test_nnmodule.py | 25 ++++++++++++++ 7 files changed, 84 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 4b6d83ad904a..8d0750110b3d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ build __pycache__ *.so -test/simple.py \ No newline at end of file +test/simple.py +tmp \ No newline at end of file diff --git a/frontend/config.py b/frontend/config.py index 390ed060de8b..85894f6dc455 100644 --- a/frontend/config.py +++ b/frontend/config.py @@ -5,6 +5,7 @@ "debug": True, "miss_threshold": 3, "dynshape": False, + "model_name": "" } diff --git a/frontend/control_flow.py b/frontend/control_flow.py index 0087b5986767..82e5d3008565 100644 --- a/frontend/control_flow.py +++ b/frontend/control_flow.py @@ -296,6 +296,3 @@ def if_stmt(cond: bool, if_true: Callable[..., Any], break_at_callsite() recover() return if_run_branch() - - -torch.Tensor.__iter__ diff --git a/frontend/fx_graph.py b/frontend/fx_graph.py index 607e5d5b61b7..fcdbab2e07dc 100644 --- a/frontend/fx_graph.py +++ b/frontend/fx_graph.py @@ -34,6 +34,41 @@ NodeArgs = Union[BaseArgumentTypes, torch.fx.Node] +def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any: + target_atoms = target.split('.') + attr_itr = gm + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def generate_real_tensors( + fake_tensors: list[torch.Tensor]) -> list[torch.Tensor]: + real_tensors = [] + for x in fake_tensors: + if x.dtype == torch.float32: + real_tensors.append( + torch.rand(*x.shape, + dtype=x.dtype, + layout=x.layout, + device=x.device)) + elif x.dtype == torch.int64: + real_tensors.append( + torch.randint(0, + 2, + size=x.shape, + dtype=x.dtype, + layout=x.layout, + device=x.device)) + else: + raise NotImplementedError + return real_tensors + + def backend_compile(gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]) -> Any: backend = config.get_config('backend') @@ -43,17 +78,6 @@ def backend_compile(gm: torch.fx.GraphModule, return gm elif backend == 'inductor': - def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any: - target_atoms = target.split('.') - attr_itr = gm - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError( - f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" - ) - attr_itr = getattr(attr_itr, atom) - return attr_itr - def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool: if node.op == 'call_module': @@ -78,27 +102,15 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool: module = importlib.import_module('tmp.fx_module_' + random_number) model = module.FxModule().cuda().eval() - real_inputs = [] - for x in example_inputs: - if x.dtype == torch.float32: - real_inputs.append( - torch.rand(*x.shape, - dtype=x.dtype, - layout=x.layout, - device=x.device)) - elif x.dtype == torch.int64: - real_inputs.append( - torch.randint(0, - 2, - size=x.shape, - dtype=x.dtype, - layout=x.layout, - device=x.device)) - else: - raise NotImplementedError + real_inputs = generate_real_tensors(example_inputs) with torch.no_grad(): - script_model = torch.jit.trace(model, real_inputs) + script_model = torch.jit.script(model, real_inputs) return script_model + elif backend == 'nnf': + model_name = config.get_config('model_name') + from fx2onnx import compile_with_nnf # type: ignore[import] + real_inputs = generate_real_tensors(example_inputs) + return compile_with_nnf(model_name, gm, real_inputs) else: raise RuntimeError(f"Unknown backend: {backend}") diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index af9cd7653c50..d7e02f7c8109 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -144,7 +144,7 @@ def get_name(prefix: str, name: str) -> str: self.subparam_paths[param] = get_name(prefix, name) def add_submodule(self, module: torch.nn.Module) -> None: - new_module_name = "__external_module__" + str(len(self.submodule_paths)) + new_module_name = "external_module__" + str(len(self.submodule_paths)) self.root.add_module(new_module_name, module) self.update_subpath(module, new_module_name) # self.written = True # not mark as written as graph break may happen @@ -1238,6 +1238,15 @@ def commit(self) -> None: (name, x) for x, name in self.state.fx_graph.example_inputs ]) print("graph", self.state.fx_graph.result_graph) + from .control_flow import CondModule + for node in self.state.fx_graph.result_graph.nodes: + if node.op == 'call_module' and '.' not in node.target: + mod = getattr(self.state.root, node.target) + if isinstance(mod, CondModule): + print("CondModule:", node.target) + print("true_body:", mod.true_body.graph) + print("false_body:", mod.false_body.graph) + graph_code = graph_codegen.get_code() compiled_graph = self.state.fx_graph.compile() @@ -1760,7 +1769,9 @@ def set_if_inplace_return() -> None: inplace_ref=inplace_ref, force_new_value=(func in (float, int, min, max) or (hasattr(func, '__name__') and - func.__name__ == 'contiguous'))) + func.__name__ == 'contiguous') or + (isinstance(func, torch.nn.Module) and + hasattr(func, 'inplace') and func.inplace))) return elif self.all_scalar_arg(args, kwargs) and self.all_static_arg( args, kwargs): diff --git a/test/test_model_lstm.py b/test/test_model_lstm.py index 033e799a33b6..781d2f3d01b7 100644 --- a/test/test_model_lstm.py +++ b/test/test_model_lstm.py @@ -349,7 +349,7 @@ def test_lstm_loop(caplog): hidden_size, device='cuda') expect_result = model(inputs) - for_iter_pc = 193 + for_iter_pc = 32 mark_dynamic_pc(get_next_frame_id(), for_iter_pc, DynamicControlFlow(for_iter_pc, "FOR_ITER")) compiled = compile(model) diff --git a/test/test_nnmodule.py b/test/test_nnmodule.py index ce6716330137..92f8694a3025 100644 --- a/test/test_nnmodule.py +++ b/test/test_nnmodule.py @@ -127,6 +127,31 @@ def test_map_module(caplog): run_and_check(compiled, [HIT], 1, caplog, expect_result, x) +class InplaceRelu(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.bn = torch.nn.BatchNorm2d(3) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + 1.0 + + +def test_inplace_relu(caplog): + reset() + model = InplaceRelu().eval() + compiled = compile(model) + x = torch.randn(1, 3, 3, 3) + expect_result = model(x) + run_and_check(compiled, [MISS], 1, caplog, expect_result, x) + run_and_check(compiled, [HIT], 1, caplog, expect_result, x) + + if __name__ == "__main__": caplog = logging.getLogger(__name__) test_call_method(caplog) From fb2d7ba812a2df2b881472a6eb4829facb14e6b6 Mon Sep 17 00:00:00 2001 From: heheda Date: Sun, 24 Mar 2024 16:15:41 +0800 Subject: [PATCH 2/3] modify loop op --- frontend/control_flow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/control_flow.py b/frontend/control_flow.py index 82e5d3008565..ad11fef7d967 100644 --- a/frontend/control_flow.py +++ b/frontend/control_flow.py @@ -40,8 +40,7 @@ def forward(self, *values: Any) -> Any: loop_carry = values[self.num_read_only_param:] while iter_num < self.num_iter: # and cond.item(): - loop_carry = self.body(torch.tensor(iter_num), *read_only, - *loop_carry) + loop_carry = self.body(iter_num, *read_only, *loop_carry) # cond, *loop_carry = self.body(iter_num, cond, *read_only, # *loop_carry) iter_num += 1 From a943b9015ec86a41ad39bfddb7081e0b6893a4ac Mon Sep 17 00:00:00 2001 From: heheda Date: Sat, 11 May 2024 22:58:10 +0800 Subject: [PATCH 3/3] remove a backend --- frontend/fx_graph.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/frontend/fx_graph.py b/frontend/fx_graph.py index fcdbab2e07dc..246eb3b730d4 100644 --- a/frontend/fx_graph.py +++ b/frontend/fx_graph.py @@ -106,11 +106,6 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool: with torch.no_grad(): script_model = torch.jit.script(model, real_inputs) return script_model - elif backend == 'nnf': - model_name = config.get_config('model_name') - from fx2onnx import compile_with_nnf # type: ignore[import] - real_inputs = generate_real_tensors(example_inputs) - return compile_with_nnf(model_name, gm, real_inputs) else: raise RuntimeError(f"Unknown backend: {backend}")