From b41e1cbb7af5f28d51ed67f1f604ffd26931f569 Mon Sep 17 00:00:00 2001 From: rongchaodong <16302010007@fudan.edu.cn> Date: Fri, 15 Mar 2024 18:59:23 +0800 Subject: [PATCH 1/9] fix str format and tensor.device as output --- frontend/guard_tracker.py | 6 ++++-- frontend/variables/dict_.py | 6 +++++- frontend/variables/tensor.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index af9cd7653c50..e3d8c23439ef 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -286,7 +286,8 @@ def record_function(self, if func in (min, max): scalar = None node = None - assert len(pargs) == 2 + # NOTE: when pargs < 2, it should be a dynamic operation + assert len(pargs) <= 2 for i, obj in enumerate(pargs): if isinstance(obj, (int, float)) and not dyn.contains(obj): scalar = obj @@ -1548,7 +1549,8 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool: def is_builtin_func(self, func: Callable[..., Any]) -> bool: return func in (dict, tuple, set, list, hasattr, slice, range, len, - type, all, str.join, reversed, zip, iter, id, next) + type, all, str.join, reversed, zip, iter, id, next, + collections.OrderedDict, str.format, any) def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool: print(dir(func)) diff --git a/frontend/variables/dict_.py b/frontend/variables/dict_.py index 2c0b03c0ebaa..81d0c9ae81a8 100644 --- a/frontend/variables/dict_.py +++ b/frontend/variables/dict_.py @@ -76,7 +76,11 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, items = [] for key, j in zip(self.value.keys(), range(len(self.vars))): if isinstance(key, str): - key_part = f"'{key}'" + if "\n" not in key: + key_part = f"'{key}'" + else: + key_part = f"'{repr(key)}'" + key_part = key_part.strip("'") else: key_part = key item = f'{key_part}: {name_in_graph_fn}_{j}' diff --git a/frontend/variables/tensor.py b/frontend/variables/tensor.py index 4915636d3a1c..72c8da6253d8 100644 --- a/frontend/variables/tensor.py +++ b/frontend/variables/tensor.py @@ -238,8 +238,8 @@ def make_guard_inner(self, codegen: "GuardFnCodegen", def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, codegen: "GraphFnCodegen", in_return: bool, idx: int) -> None: - codegen.output(name_in_graph_fn, store_pos, f"{self.device}", in_return, - idx) + codegen.output(name_in_graph_fn, store_pos, f"'{self.device}'", + in_return, idx) def as_fx_node(self) -> "NodeArgs": return self.device From 50c4ab9759f25c06bf9e39ff60b048032b154944 Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 19 Mar 2024 14:24:34 +0800 Subject: [PATCH 2/9] nonzero dynamic shape & generic_jump TensorVariable() & tensor item --- frontend/guard_tracker.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 58252d177d6c..bca91499ff05 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1873,6 +1873,13 @@ def gen_by_caller(self, caller = caller.caller return False + def generic_jump_check(self) -> None: + top_value = get_value_stack_from_top(self.frame, 0) + if torch.is_tensor(top_value): + raise ValueError("generic_jump TensorVariable() by tensor") + if dyn.contains(top_value): + raise ValueError("generic_jump TensorVariable() by dyn scalar") + def binary_operation(self, func: Callable[..., Any]) -> None: obj1 = get_value_stack_from_top(self.frame, 1) obj2 = get_value_stack_from_top(self.frame, 0) @@ -1920,6 
+1927,11 @@ def BINARY_OR(self, _inst: Instruction) -> None: def BINARY_SUBSCR(self, inst: Instruction) -> None: obj1 = get_value_stack_from_top(self.frame, 1) obj2 = get_value_stack_from_top(self.frame, 0) + if torch.is_tensor(obj1): + if torch.is_tensor(obj2): + raise ValueError("dynamic shape in tensor") + if dyn.contains(obj2): + raise ValueError("dynamic shape in dyn scalar") self.call_function(operator.getitem, [obj1, obj2], {}) def unary_operation(self, func: Callable[..., Any]) -> None: @@ -2453,16 +2465,16 @@ def DUP_TOP_TWO(self, _inst: Instruction) -> None: pass def POP_JUMP_IF_FALSE(self, _inst: Instruction) -> None: - pass + self.generic_jump_check() def POP_JUMP_IF_TRUE(self, _inst: Instruction) -> None: - pass + self.generic_jump_check() def JUMP_IF_TRUE_OR_POP(self, _inst: Instruction) -> None: - pass + self.generic_jump_check() def JUMP_IF_FALSE_OR_POP(self, _inst: Instruction) -> None: - pass + self.generic_jump_check() def JUMP_FORWARD(self, inst: Instruction) -> None: pass From a8ed0caea407b558f8e122d6cef221171e37b016 Mon Sep 17 00:00:00 2001 From: rongchaodong <16302010007@fudan.edu.cn> Date: Tue, 19 Mar 2024 15:33:04 +0800 Subject: [PATCH 3/9] fix str and ndarray size --- frontend/guard_tracker.py | 3 ++- frontend/utils.py | 4 ++++ frontend/variables/__init__.py | 4 ++++ frontend/variables/list_.py | 2 +- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index e3d8c23439ef..b7c5b4666ead 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1550,7 +1550,8 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool: def is_builtin_func(self, func: Callable[..., Any]) -> bool: return func in (dict, tuple, set, list, hasattr, slice, range, len, type, all, str.join, reversed, zip, iter, id, next, - collections.OrderedDict, str.format, any) + collections.OrderedDict, str.format, any, str, + str.split) def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool: print(dir(func)) diff --git a/frontend/utils.py b/frontend/utils.py index 476bad81cac4..6ed1d6124874 100644 --- a/frontend/utils.py +++ b/frontend/utils.py @@ -156,6 +156,10 @@ def get_root_module(func: Callable[..., Any]) -> str: if module is None or 'torch.distributions' in module_str: return "" root_module = module_str.split('.')[0] + #NOTE: special cases in torchvision module, need to check whether this module is safe to record in graph + if hasattr(func, '__name__') and func.__name__ in ( + 'pad', 'resize') and root_module == 'torchvision': + return 'torch' return root_module diff --git a/frontend/variables/__init__.py b/frontend/variables/__init__.py index d0f48804c463..03fbecdc4d86 100644 --- a/frontend/variables/__init__.py +++ b/frontend/variables/__init__.py @@ -54,6 +54,10 @@ def make_var_from_value( extract_code_at_start: Optional[list[StorePos]] = None) -> Variable: if extract_code_at_start is None: extract_code_at_start = [] + if type(value) == np.ndarray and value.size == 1: + return NumpyScalarVar.from_value(np.int64(value.tolist()), + need_guard_check, helper_functions, + fx_graph, extract_code_at_start) if type(value) in ty2var: return ty2var[type(value)].from_value(value, need_guard_check, helper_functions, fx_graph, diff --git a/frontend/variables/list_.py b/frontend/variables/list_.py index 9e7cc021881e..33907340843b 100644 --- a/frontend/variables/list_.py +++ b/frontend/variables/list_.py @@ -114,7 +114,7 @@ def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool, 
extract_code_at_start: list[StorePos]) -> None: super().__init__(need_guard_check, value, extract_code_at_start) self.value = value - self.length = len(value) + self.length = value.size self.vars = [] self.obj_ids = [] for i, obj in enumerate(value): From c4ff9a73206f23f85d31c7f0a2ed180100151e36 Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 19 Mar 2024 23:11:20 +0800 Subject: [PATCH 4/9] dynamic shape handle --- frontend/guard_tracker.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 11586b4edc73..2b35af67a27e 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1744,7 +1744,7 @@ def set_if_inplace_return() -> None: "flatten_parameters", "numel", "children", "named_parameters", "_weights_have_changed", "check_forward_args", "permute_hidden", "_check_input_dim", - "parameters"): + "parameters", "_has_torch_function_unary"): return if hasattr(func, "__module__" ) and func.__module__ == 'torch.autograd.profiler': @@ -1945,10 +1945,8 @@ def BINARY_SUBSCR(self, inst: Instruction) -> None: obj1 = get_value_stack_from_top(self.frame, 1) obj2 = get_value_stack_from_top(self.frame, 0) if torch.is_tensor(obj1): - if torch.is_tensor(obj2): + if torch.is_tensor(obj2) and obj2.dtype == torch.bool: raise ValueError("dynamic shape in tensor") - if dyn.contains(obj2): - raise ValueError("dynamic shape in dyn scalar") self.call_function(operator.getitem, [obj1, obj2], {}) def unary_operation(self, func: Callable[..., Any]) -> None: From 5c46c5a60f8de4bb3dd99143aaadfa8e9d36deda Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Thu, 21 Mar 2024 16:44:03 +0800 Subject: [PATCH 5/9] remove some functions from graph --- frontend/guard_tracker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index ebe1fdc7663e..773d569407ae 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1747,7 +1747,9 @@ def set_if_inplace_return() -> None: "flatten_parameters", "numel", "children", "named_parameters", "_weights_have_changed", "check_forward_args", "permute_hidden", "_check_input_dim", - "parameters", "_has_torch_function_unary"): + "parameters", "_has_torch_function_unary", "_is_tracing", + "is_tracing", "is_scripting", "get_autocast_gpu_dtype", + "is_autocast_enabled", "ndimension"): return if hasattr(func, "__module__" ) and func.__module__ == 'torch.autograd.profiler': From 051d32f40d9f3ac84c7a73473e906e6c35016d6a Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Thu, 28 Mar 2024 21:30:36 +0800 Subject: [PATCH 6/9] throw error for rest of nobugs and few new cases --- frontend/guard_tracker.py | 31 +++++++++++++++++++++++++++++-- frontend/pycode_generator.py | 2 ++ frontend/variables/tensor.py | 8 ++++++++ test/test_model_bart.py | 5 ++++- 4 files changed, 43 insertions(+), 3 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 773d569407ae..c2afb165af13 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -311,7 +311,9 @@ def record_function(self, func = math2torch[func] if func == torch.from_numpy: func = torch.tensor - + if hasattr(func, '__name__') and func.__name__ == 'numpy': + if torch.is_tensor(args[0]) or dyn.contains(args[0]): + raise ValueError("numpy can't have dynamic args") self.written = True scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = { float: torch.Tensor.float, @@ 
-350,6 +352,9 @@ def record_function(self, func = torch.Tensor.new_empty elif func == torch.Tensor.item: assert args[0].numel() == 1 + if args[0].dtype == torch.bool: + raise ValueError( + "The .item() method was applied to a boolean tensor.") func = torch.Tensor.clone fx_node = self.fx_graph.create_node("call_method", func.__name__, @@ -840,6 +845,7 @@ class GuardTracker: caller: Optional['GuardTracker'] cf_info: Optional[ControlFlowInfo] num_breaks: int + layout_sensitive: bool def __init__(self, frame: FrameType, @@ -877,6 +883,7 @@ def __init__(self, read_stack=read_stack, frame_cf_info=cf_info ) # stack pointer is not initialized at the creation of a stack frame self.num_breaks = 0 + self.layout_sensitive = False def init_state(self, read_stack: bool = True, @@ -905,6 +912,9 @@ def record( restart_caller=False) if self.code.get_inst(self.frame.f_lasti).opname == 'RETURN_VALUE': if trackers[-1] == self: + if self.layout_sensitive == True: + if self.caller is not None: + self.caller.layout_sensitive = True pop_tracker(self.frame_id) set_eval_frame(None) return @@ -957,6 +967,8 @@ def record( def commit_loop_subgraph(self) -> None: key = new_random_key() guard_codegen = GuardFnCodegen(key=key) + if self.layout_sensitive == True: + guard_codegen.layout_sensitive = True for var in self.state.objects.get_all(): while var.prev is not None: var = var.prev @@ -1177,6 +1189,8 @@ def commit(self) -> None: if self.state.can_guard: key = new_random_key() guard_codegen = GuardFnCodegen(key=key) + if self.layout_sensitive == True: + guard_codegen.layout_sensitive = True for var in self.state.objects.get_all(): while var.prev is not None: var = var.prev @@ -1609,11 +1623,22 @@ def call_function( self.state.fx_graph, [pos]) self.state.add_object(var, obj) return + if hasattr(func, + '__name__') and func.__name__ == 'format' and isinstance( + func, type(str.format)): + for arg in args: + if torch.is_tensor(arg) or dyn.contains(arg): + raise ValueError("format can't have dynamic args") + if hasattr(func, '__name__') and (func.__name__ == 'is_contiguous' or + func.__name__ == 'stride'): + self.layout_sensitive = True if hasattr(func, '__name__') and func.__name__ == '__init__': return # a series of classes and functions defined by warnings if get_root_module(func) in ('_warnings', 'warnings'): return + if get_root_module(func) == 'random': + raise ValueError("random scalar") is_high_order_udf = is_high_order_func_with_udf(func, args, kwargs) if is_user_defined_func(func) or isinstance( func, nn.Sequential) or is_high_order_udf: @@ -1749,7 +1774,9 @@ def set_if_inplace_return() -> None: "check_forward_args", "permute_hidden", "_check_input_dim", "parameters", "_has_torch_function_unary", "_is_tracing", "is_tracing", "is_scripting", "get_autocast_gpu_dtype", - "is_autocast_enabled", "ndimension"): + "is_autocast_enabled", "ndimension", "get_enum", + "is_tensor", "is_complex", "is_contiguous", "stride", + "get_device"): return if hasattr(func, "__module__" ) and func.__module__ == 'torch.autograd.profiler': diff --git a/frontend/pycode_generator.py b/frontend/pycode_generator.py index e2ab53ba7ea1..961ba6fee5d3 100644 --- a/frontend/pycode_generator.py +++ b/frontend/pycode_generator.py @@ -141,12 +141,14 @@ class GuardFnCodegen(FnCodegen): checks: set[tuple[str, StorePos]] imports: set[str] object_refs: list[Any] # the reference to objects for id check + layout_sensitive: bool def __init__(self, key: int) -> None: super().__init__(key) self.checks = set() self.imports = set() self.object_refs = [] + 
self.layout_sensitive = False def add_check(self, check: tuple[str, StorePos]) -> None: self.checks.add(check) diff --git a/frontend/variables/tensor.py b/frontend/variables/tensor.py index 72c8da6253d8..4a149e9f7527 100644 --- a/frontend/variables/tensor.py +++ b/frontend/variables/tensor.py @@ -115,6 +115,10 @@ def tensor_guard_check(self, value: torch.Tensor) -> bool: # hasattr(value, 'stride') and self.stride == value.stride() and \ # hasattr(value, 'is_contiguous') and self.is_contiguous == value.is_contiguous() + def tensor_strict_guard_check(self, value: torch.Tensor) -> bool: + return hasattr(value, 'stride') and self.stride == value.stride() and \ + hasattr(value, 'is_contiguous') and self.is_contiguous == value.is_contiguous() + def make_guard_inner(self, codegen: "GuardFnCodegen", pos: StorePos) -> None: name_in_codegen = codegen.add_obj(self) @@ -124,6 +128,10 @@ def make_guard_inner(self, codegen: "GuardFnCodegen", else: codegen.add_check( (f"{name_in_codegen}.tensor_guard_check({pos})", pos)) + if codegen.layout_sensitive == True: + codegen.add_check( + (f"{name_in_codegen}.tensor_strict_guard_check({pos})", + pos)) def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos, codegen: "GraphFnCodegen", in_return: bool, diff --git a/test/test_model_bart.py b/test/test_model_bart.py index 6a8754819edf..d096eee2524d 100644 --- a/test/test_model_bart.py +++ b/test/test_model_bart.py @@ -1,3 +1,4 @@ +import os import pytest from frontend.compile import compile, reset from common.checker import assert_equal, run_and_check_cache, run_and_check, HIT, MISS, ALL_MISS @@ -1391,7 +1392,9 @@ def get_input(batch_size): return (input_ids, attention_mask), {} -@pytest.mark.model +# @pytest.mark.model +@pytest.mark.skipif(os.getenv('FORCE_RUN_SKIPPED_TEST') != '1', + reason="can't pass due to the handling of module random") def test_model_bart(caplog): reset() with torch.no_grad(): From 904977bb8bf2f4fad66f4644e154329920cc5f21 Mon Sep 17 00:00:00 2001 From: rongchaodong <16302010007@fudan.edu.cn> Date: Thu, 18 Apr 2024 15:23:07 +0800 Subject: [PATCH 7/9] enable fallback from exception during compiling process --- frontend/c_api.pyi | 6 ++++++ frontend/compile.py | 2 ++ frontend/config.py | 1 + frontend/csrc/frame_evaluation.cpp | 16 +++++++++++++--- frontend/guard_tracker.py | 13 +++++-------- frontend/tracer.py | 29 +++++++++++++++++++++++++---- 6 files changed, 52 insertions(+), 15 deletions(-) diff --git a/frontend/c_api.pyi b/frontend/c_api.pyi index f23ab8c63ab1..e2c928c7847c 100644 --- a/frontend/c_api.pyi +++ b/frontend/c_api.pyi @@ -11,6 +11,12 @@ def set_eval_frame( pass +def set_fallback( + new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]] +) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]: + pass + + def set_skip_files(skip_file: set[str], end_file: set[str]) -> None: pass diff --git a/frontend/compile.py b/frontend/compile.py index 6b7e38a6f365..3ce024e75a29 100644 --- a/frontend/compile.py +++ b/frontend/compile.py @@ -93,3 +93,5 @@ def reset() -> None: fx_graph.reset() from . import dynamic dynamic.reset() + from . 
import tracer + tracer.reset() diff --git a/frontend/config.py b/frontend/config.py index 390ed060de8b..fb259a03ed9d 100644 --- a/frontend/config.py +++ b/frontend/config.py @@ -5,6 +5,7 @@ "debug": True, "miss_threshold": 3, "dynshape": False, + "enable_fallback": False, } diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp index 9d3d7e93e3e5..268bdd1c8d4e 100644 --- a/frontend/csrc/frame_evaluation.cpp +++ b/frontend/csrc/frame_evaluation.cpp @@ -242,9 +242,11 @@ inline static void enable_eval_frame_shim(PyThreadState *tstate) { inline static void enable_eval_frame_default(PyThreadState *tstate) { if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != previous_eval_frame) { - _PyInterpreterState_SetEvalFrameFunc(tstate->interp, - previous_eval_frame); - previous_eval_frame = NULL; + if (previous_eval_frame != NULL) { + _PyInterpreterState_SetEvalFrameFunc(tstate->interp, + previous_eval_frame); + previous_eval_frame = NULL; + } } } @@ -290,6 +292,13 @@ static PyObject *set_eval_frame(PyObject *self, PyObject *args) { return old_callback; } +static PyObject *set_fallback(PyObject *self, PyObject *args) { + PyThreadState *tstate = PyThreadState_GET(); + fprintf(stderr, "Falling back\n"); + decrese_working_threads(tstate); + Py_RETURN_NONE; +} + // TODO: in a more elegant way static PyObject *set_skip_files(PyObject *self, PyObject *args) { if (skip_files != Py_None) { @@ -659,6 +668,7 @@ static PyObject *mark_need_postprocess(PyObject *self, PyObject *args) { static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame, METH_VARARGS, NULL}, + {"set_fallback", set_fallback, METH_VARARGS, NULL}, {"set_skip_files", set_skip_files, METH_VARARGS, NULL}, {"set_null_object", set_null_object, METH_VARARGS, NULL}, {"set_miss_threshold", set_miss_threshold, METH_VARARGS, NULL}, diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index b7c5b4666ead..85a21caff799 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -31,6 +31,7 @@ from .variables.const import ClsByNamedTupleVar from .variables.base import Variable from .control_flow import ControlFlowInfo, LoopModule, ForLoopInfo, LoopPosMap, if_stmt, IfStmtInfo +from .config import get_config MAKE_VAR_FN_TYPE = Callable[[ Any, bool, vs.HelperFunctions, Optional[FxGraph], Optional[list[StorePos]] @@ -1551,7 +1552,7 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool: return func in (dict, tuple, set, list, hasattr, slice, range, len, type, all, str.join, reversed, zip, iter, id, next, collections.OrderedDict, str.format, any, str, - str.split) + str.split, sorted) def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool: print(dir(func)) @@ -2415,11 +2416,6 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None: # ] # }) # pass - print("check data", seq, type(seq)) - if self.state.objects.contains(seq): - print("jjjjjj") - for i in seq: - print(i) raise NotImplementedError def UNPACK_EX(self, inst: Instruction) -> None: @@ -2693,8 +2689,9 @@ def pop_tracker(frame_id: int) -> None: print("before pop_tracker", [t.frame_id for t in trackers], "frame_id", frame_id) to_pop = trackers.pop() - assert to_pop.frame_id == frame_id - assert to_pop.state.is_empty + if not get_config("enable_fallback"): + assert to_pop.frame_id == frame_id + assert to_pop.state.is_empty def record(frame: FrameType, frame_id: int) -> None: diff --git a/frontend/tracer.py b/frontend/tracer.py index 4ad43340db74..e552f6d13551 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py 
@@ -4,18 +4,24 @@ from types import FrameType, CodeType from typing import Any, Callable, Tuple import inspect -from .guard_tracker import push_tracker, pop_tracker, record +from .guard_tracker import push_tracker, pop_tracker, record, trackers from .cache import enable_cache, check_cache_updated, get_frame_cache from .fx_graph import set_frame_root -from .c_api import set_eval_frame, mark_need_postprocess +from .c_api import set_eval_frame, mark_need_postprocess, set_fallback from .code import ProcessedCode from .instruction import format_insts from .config import get_config +run_trace_func: bool = True +fall_back_frames: list[int] = [] + def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]: def trace_func(frame: FrameType, event: str, arg: Any) -> None: + global run_trace_func + if not run_trace_func and frame_id in fall_back_frames: + return None try: if event == "opcode": opcode = frame.f_code.co_code[frame.f_lasti] @@ -33,7 +39,17 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None: except Exception as e: print("exception in trace_func:", e, type(e)) print(traceback.format_exc()) - raise e + if get_config("enable_fallback"): + run_trace_func = False + for i in trackers: + fall_back_frames.append(i.frame_id) + # if len(trackers) > 1: + # disable_trace(frame_id) + print("fallback frames", fall_back_frames) + set_fallback(None) + return None + else: + raise e return None return trace_func @@ -115,4 +131,9 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: raise e return - return (preprocess_frame, postprocess_frame) \ No newline at end of file + return (preprocess_frame, postprocess_frame) + + +def reset() -> None: + run_trace_func = True + fall_back_frames.clear() From f083467719d888b63265f52882aad7ee8c2a6370 Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 23 Apr 2024 15:09:21 +0800 Subject: [PATCH 8/9] add random control --- frontend/guard_tracker.py | 37 ++++++++++++++++++++++++++++++++++--- frontend/store_pos.py | 19 ++++++++++++++++++- test/test_model_bart.py | 4 +--- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index c2afb165af13..842bf8c96ae2 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -19,7 +19,7 @@ from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method, get_from_freevars, set_value_stack_from_top, parse_cell, set_local from .instruction import Instruction, ci from .cache import CachedGraph, get_frame_cache -from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller +from .store_pos import StoreConstant, StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller from . import variables as vs from . 
import dynamic as dyn from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, fx_graph_inplace_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack, is_graph_func, get_root_module, torch_inplace_funcs, print_bytecode, get_method_defined_class, is_math_func, is_high_order_func_with_udf, is_high_order_func, math2torch @@ -471,6 +471,8 @@ def store_pos_in_caller(self, pos: StorePos, raise ValueError("cannot store in stack in callee") elif isinstance(pos, (StoreInGlobal, StoreInBuiltin, StoreInFreeVar)): return pos + elif isinstance(pos, StoreConstant): + return pos elif isinstance(pos, StoreInAttr): # print("in callee", pos, self.frame_id) parent_pos = self.store_pos_in_caller(pos.self_pos, pos.self_id) @@ -492,7 +494,12 @@ def store_pos_in_caller(self, pos: StorePos, for p, i in zip(pos.var_pos, pos.var_id): new_pos = self.store_pos_in_caller(p, i) if new_pos is None: - return None + if isinstance( + p, + StoreConstant): # allow constant function parameter + new_pos = p + else: + return None parent_poses.append(new_pos) return ExtractFromFunction(parent_poses, pos.var_id, pos.func_name, pos.func_obj, pos.need_add_to_fn) @@ -1638,7 +1645,31 @@ def call_function( if get_root_module(func) in ('_warnings', 'warnings'): return if get_root_module(func) == 'random': - raise ValueError("random scalar") + for arg in args: + if torch.is_tensor(arg) or dyn.contains(arg): + raise ValueError("random func can't have dynamic args") + if func.__name__ not in { + 'random', 'randint', 'randrange', 'uniform' + }: + raise ValueError("Not implement random func") + + name = new_name('random') + fx_node = self.state.fx_graph.create_input(torch.tensor([0]), name, + (), {}, name) + self.state.set_partial_var({ + -1: [ + PartialVar( + node=fx_node, + need_guard_check=False, + extract_code_at_start=[ + ExtractFromFunction( + [StoreConstant(arg, id(arg)) for arg in args], + [id(arg) for arg in args], func.__name__, func, + True) + ]) + ] + }) + return is_high_order_udf = is_high_order_func_with_udf(func, args, kwargs) if is_user_defined_func(func) or isinstance( func, nn.Sequential) or is_high_order_udf: diff --git a/frontend/store_pos.py b/frontend/store_pos.py index 6d46caa1d655..1c15ba2d1ce1 100644 --- a/frontend/store_pos.py +++ b/frontend/store_pos.py @@ -1,6 +1,8 @@ -from typing import Any, Optional, TYPE_CHECKING, Callable +from typing import Any, Optional, TYPE_CHECKING, Callable, Union from types import FrameType +from torch import Tensor + from .c_api import get_value_stack_from_top if TYPE_CHECKING: from .pycode_generator import FnCodegen @@ -41,6 +43,21 @@ def get_value_from_frame(self, frame: FrameType) -> Any: return frame.f_locals[self.name] +class StoreConstant(StorePos): + value: Union[int, float] + self_id: int + + def __init__(self, value: Union[int, float], self_id: int) -> None: + self.value = value + self.self_id = self_id + + def __repr__(self) -> str: + return str(self.value) + + def get_value_from_frame(self, frame: FrameType) -> Any: + return self.value + + class StoreInGlobal(StorePos): name: str diff --git a/test/test_model_bart.py b/test/test_model_bart.py index d096eee2524d..3fc4163e6c54 100644 --- a/test/test_model_bart.py +++ b/test/test_model_bart.py @@ -1392,9 +1392,7 @@ def get_input(batch_size): return (input_ids, attention_mask), {} -# @pytest.mark.model -@pytest.mark.skipif(os.getenv('FORCE_RUN_SKIPPED_TEST') != '1', - reason="can't pass due to the handling of module random") +@pytest.mark.model 
def test_model_bart(caplog): reset() with torch.no_grad(): From 88074bcb51d2bd6d66311d7b13e627a30c03285f Mon Sep 17 00:00:00 2001 From: rongchaodong <16302010007@fudan.edu.cn> Date: Wed, 8 May 2024 17:17:57 +0800 Subject: [PATCH 9/9] fix torch.Parameter vals if no pos --- frontend/guard_tracker.py | 12 ++++++++++++ frontend/variables/base.py | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index b11856727d54..64e4d4ea5ec6 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1896,6 +1896,18 @@ def set_if_inplace_return() -> None: # TODO: add map and set correct partial var return elif is_graph_func(func): + if func is operator.getitem: + obj_var = self.state.objects.get(args[0]) + assert obj_var.extract_code_at_start[0] + obj_pos = obj_var.extract_code_at_start[0] + item_pos = StoreInIndex(obj_pos, id(obj_pos), args[1]) + self.state.set_partial_var({ + -1: [ + PartialVar(node=None, + need_guard_check=False, + extract_code_at_start=[item_pos]) + ] + }) return elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList): return diff --git a/frontend/variables/base.py b/frontend/variables/base.py index 2278e0d3cc86..e4e5ba1b6beb 100644 --- a/frontend/variables/base.py +++ b/frontend/variables/base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing import Any, TYPE_CHECKING, Optional, Tuple, Iterable, Callable from copy import copy - +import torch from frontend.utils import add_force_graph_break from ..c_api import get_miss_locals @@ -90,6 +90,9 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos, self.make_output_inner(name_in_graph_fn, store_pos, codegen, in_return, idx) for attr, var in self.modified_attrs.items(): + if isinstance(var.obj, torch.nn.Parameter) and len( + var.extract_code_at_start) == 0: + continue var.make_output(f'{name_in_graph_fn}_dot_{attr}', StoreInAttr(store_pos, id(self.obj), attr), codegen, False, id(getattr(self.obj, attr)))
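
A minimal usage sketch of the StoreConstant position introduced in PATCH 8/9 (frontend/store_pos.py). This is illustrative only and assumes the frontend package is importable in the environment; in the patch itself, constant positions are created inside call_function so that literal arguments of random.* calls can be replayed through ExtractFromFunction.

    from frontend.store_pos import StoreConstant

    # A constant position wraps a literal value together with its object id.
    # get_value_from_frame ignores the frame entirely and returns the stored
    # value, which is why store_pos_in_caller (PATCH 8/9) can pass it through
    # from a callee frame unchanged.
    low = StoreConstant(0, id(0))
    high = StoreConstant(10, id(10))

    assert repr(low) == "0"                      # __repr__ prints the literal
    assert low.get_value_from_frame(None) == 0   # frame argument is unused
    assert high.get_value_from_frame(None) == 10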