From e773ff1f9353dc7d11c9ae84b251371a60b6955f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 27 Feb 2024 11:42:36 +0800 Subject: [PATCH] support torchscript backend --- frontend/bytecode_writter.py | 1 - frontend/fx_graph.py | 29 +++++++++++++++++++++++++++++ frontend/guard_tracker.py | 3 +-- frontend/tracer.py | 11 +++++++---- frontend/variables/__init__.py | 1 - test/test_model_deberta.py | 1 - 6 files changed, 37 insertions(+), 9 deletions(-) diff --git a/frontend/bytecode_writter.py b/frontend/bytecode_writter.py index 859d25058991..d29ae63ac882 100644 --- a/frontend/bytecode_writter.py +++ b/frontend/bytecode_writter.py @@ -234,7 +234,6 @@ def fix_constants(instructions: List[Instruction], const_list.append(entry) const_set.add(entry) inst.arg = const_list.index(entry) - print("const_list", const_list) code_options["co_consts"] = tuple((x[1] for x in const_list)) diff --git a/frontend/fx_graph.py b/frontend/fx_graph.py index dd2f4a7edefd..607e5d5b61b7 100644 --- a/frontend/fx_graph.py +++ b/frontend/fx_graph.py @@ -70,6 +70,35 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool: elif backend == 'xla': return torch._dynamo.backends.torchxla.aot_torchxla_trace_once( gm, example_inputs) + elif backend == 'script': + import os, importlib, re, random + random_number = str(random.randint(0, 1000000)) + os.makedirs('tmp/fx_module_' + random_number, exist_ok=True) + gm.to_folder('tmp/fx_module_' + random_number) + + module = importlib.import_module('tmp.fx_module_' + random_number) + model = module.FxModule().cuda().eval() + real_inputs = [] + for x in example_inputs: + if x.dtype == torch.float32: + real_inputs.append( + torch.rand(*x.shape, + dtype=x.dtype, + layout=x.layout, + device=x.device)) + elif x.dtype == torch.int64: + real_inputs.append( + torch.randint(0, + 2, + size=x.shape, + dtype=x.dtype, + layout=x.layout, + device=x.device)) + else: + raise NotImplementedError + with torch.no_grad(): + script_model = torch.jit.trace(model, real_inputs) + return script_model else: raise RuntimeError(f"Unknown backend: {backend}") diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 19cae2dae0ae..b4f0d8125769 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -143,10 +143,9 @@ def add_submodule(self, module: torch.nn.Module) -> None: # self.written = True # not mark as written as graph break may happen def add_subparam(self, param: torch.nn.Parameter) -> None: - new_param_name = "__external_param__" + str(len(self.subparam_paths)) + new_param_name = "external_param__" + str(len(self.subparam_paths)) self.root.register_parameter(new_param_name, param) self.subparam_paths[param] = new_param_name - # self.written = True # not mark as written as graph break may happen def as_node_args_kwargs( self, args: list[Any], kwargs: dict[str, Any] diff --git a/frontend/tracer.py b/frontend/tracer.py index e9ecac95572b..4ad43340db74 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -87,8 +87,9 @@ def preprocess_frame( frame_cache = get_frame_cache(frame_id) frame_cache.update_code(frame.f_code, frame_id, is_callee) new_code, code_map = frame_cache.get_new_code(is_callee) - print("bytecode to run:") - print(format_insts(code_map.guard_insts)) + if is_debug: + print("bytecode to run:") + print(format_insts(code_map.guard_insts)) trace_func = get_trace_func(frame_id) except Exception as e: @@ -102,12 +103,14 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: from .bytecode_writter import SHOULD_NOT_CALL_REWRITE if SHOULD_NOT_CALL_REWRITE: raise ValueError("should not call postprocess") - print(f"postprocess frame {frame.f_code.co_filename}") + if is_debug: + print(f"postprocess frame {frame.f_code.co_filename}") set_frame_root(frame_id, f) frame_cache = get_frame_cache(frame_id) frame_cache.update_code(frame.f_code, frame_id, is_callee) except Exception as e: - print("exception in postprocess:", e, type(e)) + if is_debug: + print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) raise e return diff --git a/frontend/variables/__init__.py b/frontend/variables/__init__.py index 06c41b104ca0..c09f79d5ac38 100644 --- a/frontend/variables/__init__.py +++ b/frontend/variables/__init__.py @@ -92,7 +92,6 @@ def make_var_from_value( fx_graph, extract_code_at_start) else: # NOTE: use any instead of iteartor_var to represent iterator with unknown source due to the hardness of getting iterable and num_iters - print("generate any for", value, type(value), extract_code_at_start) return AnyVar.from_value(value, need_guard_check, helper_functions, fx_graph, extract_code_at_start) diff --git a/test/test_model_deberta.py b/test/test_model_deberta.py index fb4394ec0645..a5b216ef73bb 100644 --- a/test/test_model_deberta.py +++ b/test/test_model_deberta.py @@ -1157,7 +1157,6 @@ def get_model(): config.num_hidden_layers = 2 config.return_dict = False model = DebertaModel(config).to(device) - print("model type", type(model)) return model