
Commit

with code check
superDong1998 committed Dec 6, 2023
1 parent 164500e commit 6c24c8d
Showing 13 changed files with 128 additions and 140 deletions.
106 changes: 48 additions & 58 deletions frontend/guard_tracker.py
@@ -134,8 +134,7 @@ def get_name(prefix: str, name: str) -> str:
self.subparam_paths[param] = get_name(prefix, name)

def add_submodule(self, module: torch.nn.Module) -> None:
new_module_name = "__external_module__" + str(len(
self.submodule_paths))
new_module_name = "__external_module__" + str(len(self.submodule_paths))
self.root.add_module(new_module_name, module)
self.update_subpath(module, new_module_name)
# self.written = True # not mark as written as graph break may happen
@@ -274,8 +273,8 @@ def record_function(self,
for k, v in self.submodule_paths.items():
print(id(k), v, k)
raise NotImplementedError(func)
elif (hasattr(func, '__self__')
and isinstance(func.__self__, torch.Tensor)) or (
elif (hasattr(func, '__self__') and
isinstance(func.__self__, torch.Tensor)) or (
hasattr(func, '__objclass__') and func.__objclass__
== torch._C._TensorBase) or func in scalar2tensor:
if func in scalar2tensor:
@@ -493,7 +492,8 @@ def get_original_node(node: torch.fx.Node) -> torch.fx.Node:
new.append(new_pos)
if len(new) == 0:
# the inputs of callee come from generated outputs in caller, should not add to graph as input
new.append(UnknownPosInCaller)
new_pos = UnknownPosInCaller()
new.append(new_pos)
new_var = vs.make_var_from_value(
var.obj, var.need_guard_check,
self.objects.helper_functions, self.fx_graph, new)
@@ -531,8 +531,7 @@ def get_original_node(node: torch.fx.Node) -> torch.fx.Node:
break
if module_obj is None:
raise ValueError(
f"cannot find module {node.target} in {state.root}"
)
f"cannot find module {node.target} in {state.root}")
if module_obj not in self.submodule_paths:
self.add_submodule(module_obj)
name_in_caller = self.submodule_paths[module_obj]
@@ -761,8 +760,7 @@ def record(
self.frame.f_lasti).opname in ('SETUP_WITH', 'FOR_ITER',
'JUMP_IF_TRUE_OR_POP',
'JUMP_IF_FALSE_OR_POP',
'SETUP_FINALLY',
'RAISE_VARARGS',
'SETUP_FINALLY', 'RAISE_VARARGS',
'SETUP_ASYNC_WITH'):
self.state.num_new_refs = -1
else:
@@ -932,8 +930,9 @@ def rewrite_loop_graph(self) -> None:
elif node.op == "output":
fx_graph.result_graph.inserting_before(node)
input_args = [
input_nodes[p] for p, _ in itertools.chain(pos_map.input_only_pos,
pos_map.joint_pos)
input_nodes[p]
for p, _ in itertools.chain(pos_map.input_only_pos,
pos_map.joint_pos)
if p != iter_value_str
]
output_args = []
@@ -1076,10 +1075,9 @@ def commit(self) -> None:
print("RUNNING PY CODE")
print(py_code)
exec(py_code, self.frame.f_globals, out)
guard_fn = out["___make_guard_fn"](
*guard_codegen.objs.values())
graph_fn = out["___make_graph_fn"](
compiled_graph, *graph_codegen.objs.values())
guard_fn = out["___make_guard_fn"](*guard_codegen.objs.values())
graph_fn = out["___make_graph_fn"](compiled_graph,
*graph_codegen.objs.values())

print("guard_fn:", guard_fn)
print("pc:", self.state.start_pc, end_pc)
@@ -1222,13 +1220,12 @@ def process_last_inst(self) -> None:
dyn.ScalarWithUnknownValue())
else:
print("tuple inner unknown node", sub_value,
type(sub_value))
type(sub_value))
raise NotImplementedError(type(sub_value))
elif inspect.isclass(type(value)):
pass
else:
print("partial node with unknown value", value,
type(value))
print("partial node with unknown value", value, type(value))
raise NotImplementedError
var = make_var_fn(value, partial.need_guard_check,
self.state.objects.helper_functions,
@@ -1260,7 +1257,7 @@ def process_last_inst(self) -> None:
self.state.callee_returns)
new_node = self.state.fx_graph.create_node(
"call_function", torch.clone,
(returns_var.as_fx_node(), ), {})
(returns_var.as_fx_node(),), {})
stack_top_var = vs.TensorVar.from_tensor_and_node(
stack_top, new_node, False, [])
self.state.add_object(stack_top_var, stack_top)
@@ -1293,13 +1290,12 @@ def has_tensor_arg(cls, args: List[Any], kwargs: Dict[str, Any]) -> bool:

@classmethod
def all_scalar_arg(cls, args: List[Any], kwargs: Dict[str, Any]) -> bool:
return all(
is_scalar(i) for i in itertools.chain(args, kwargs.values()))
return all(is_scalar(i) for i in itertools.chain(args, kwargs.values()))

@classmethod
def all_static_arg(cls, args: List[Any], kwargs: Dict[str, Any]) -> bool:
return all(not dyn.contains(i)
for i in itertools.chain(args, kwargs.values()))
return all(
not dyn.contains(i) for i in itertools.chain(args, kwargs.values()))

@classmethod
def has_arg_of_type(
@@ -1399,8 +1395,7 @@ def call_function(
DeferRestartState(stack_objs, self.get_live_objs(),
self.frame.f_lasti, f"call_function"))
from .tracer import get_process_frame
preprocess_frame, post_process_frame = get_process_frame(
func, True)
preprocess_frame, post_process_frame = get_process_frame(func, True)
prior = set_eval_frame((preprocess_frame, post_process_frame))
assert prior is None
assert self.state.written == False
@@ -1435,12 +1430,12 @@ def set_if_inplace_return() -> None:
(is_graph_func(func) or is_math_func(func) or
(func in (float, int, min, max, len, list, abs, sum)))):
if hasattr(func, "__name__") and (
func.__name__ in ("named_children",
"_are_functorch_transforms_active",
"finfo", "dim", "save_for_backward") or
(func.__name__ == "type" and inst.argval == 0)) or
(func.__name__ in ("size",) and
not config.get_config("dynshape")):
func.__name__
in ("named_children", "_are_functorch_transforms_active",
"finfo", "dim", "save_for_backward") or
(func.__name__ == "type" and inst is not None and
inst.argval == 0) or (func.__name__ in ("size",) and
not config.get_config("dynshape"))):
self.state.set_partial_var({
-1: [
PartialVar(node=None,
@@ -1757,14 +1752,14 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
if inst.argval in obj_var.modified_attrs:
return
if isinstance(obj, torch.Tensor) and inst.argval == 'data':
node = obj_var.as_fx_node()
node: Optional[torch.fx.Node] = obj_var.as_fx_node()
else:
node = None
need_guard_check = obj_var.need_guard_check
if config.get_config('dynshape') and isinstance(
obj, torch.Tensor) and inst.argval == 'shape':
node: Optional[torch.fx.Node] = self.state.fx_graph.create_node(
"call_method", "size", (obj_var.as_fx_node(),), {})
node = self.state.fx_graph.create_node("call_method", "size",
(obj_var.as_fx_node(),), {})
need_guard_check = False
else:
node = None
@@ -1833,10 +1828,9 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
for arg, kw_name in zip(args[-len(kw_names):], kw_names):
kwargs[kw_name] = arg
args = args[:-len(kw_names)]
if hasattr(
func,
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
if hasattr(func,
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
# print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
self.call_function(func, args, kwargs)
@@ -1850,10 +1844,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
else:
kwargs = {}
# print(f"function ex: {func}, type: {type(func)},args:{(func.__self__,) + args}, kwargs:{kwargs}")
if hasattr(
func,
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
if hasattr(func,
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
self.call_function(func, args, kwargs)

@@ -1869,8 +1862,8 @@ def STORE_SUBSCR(self, inst: Instruction) -> None:
value = get_value_stack_from_top(self.frame, 2)
if isinstance(target, torch.Tensor):
# still use the original node, so no need to update object table
self.state.record_function(operator.setitem,
[target, index, value], {},
self.state.record_function(operator.setitem, [target, index, value],
{},
add_partial_var=False)
else:
self.state.add_inplace_update_obj(target)
@@ -2045,8 +2038,7 @@ def make_iterable_fn(
value: Any, need_guard_check: bool,
_helper_functions: vs.HelperFunctions,
fx_graph: Optional[FxGraph],
extract_code_at_start: Optional[list[StorePos]]
) -> vs.Variable:
extract_code_at_start: Optional[list[StorePos]]) -> vs.Variable:
if extract_code_at_start is None:
extract_code_at_start = []
return vs.IteratorVar.from_parent_var(value, obj_var, id(obj), 0,
@@ -2067,8 +2059,7 @@ def make_iterable_fn(
})

def FOR_ITER(self, _original_inst: Instruction) -> None:
original_pc, original_inst = self.code.get_orig_inst(
self.frame.f_lasti)
original_pc, original_inst = self.code.get_orig_inst(self.frame.f_lasti)
guard_pc = self.frame.f_lasti // 2
while self.code.guard_insts[guard_pc].opname == "EXTENDED_ARG":
guard_pc += 1
@@ -2127,26 +2118,25 @@ def FOR_ITER(self, _original_inst: Instruction) -> None:
end_loop_pc = self.code.get_pc_by_inst(guard_target)

def make_iterable_fn(
value: Any, need_guard_check: bool,
_helper_functions: vs.HelperFunctions,
_fx_graph: Optional[FxGraph],
extract_code_at_start: Optional[list[StorePos]]
value: Any, need_guard_check: bool,
_helper_functions: vs.HelperFunctions,
_fx_graph: Optional[FxGraph],
extract_code_at_start: Optional[list[StorePos]]
) -> vs.Variable:
assert isinstance(obj_var, vs.IteratorVar)
if extract_code_at_start is None:
extract_code_at_start = []
return vs.IteratorVar.from_parent_var(value,
obj_var.parent_var,
return vs.IteratorVar.from_parent_var(value, obj_var.parent_var,
obj_var.parent_idx,
obj_var.num_iters + 1,
need_guard_check,
extract_code_at_start)

def make_dynamic_input_fn(
value: Any, need_guard_check: bool,
helper_functions: vs.HelperFunctions,
fx_graph: Optional[FxGraph],
extract_code_at_start: Optional[list[StorePos]]
value: Any, need_guard_check: bool,
helper_functions: vs.HelperFunctions,
fx_graph: Optional[FxGraph],
extract_code_at_start: Optional[list[StorePos]]
) -> vs.Variable:
assert is_scalar(value)
dyn.mark_dynamic(value, dyn.ScalarWithUnknownValue())
6 changes: 4 additions & 2 deletions frontend/object_table.py
@@ -67,11 +67,13 @@ def get(self,
return self.objs[id(value)]
elif allow_unexist_const:
if isinstance(value, get_args(CONST_TYPES)) or isinstance(
value, (list, tuple, set, dict, range, CodeType, type(Ellipsis), np.ndarray)):
value, (list, tuple, set, dict, range, CodeType,
type(Ellipsis), np.ndarray)):
return make_var_from_value(value, False, self.helper_functions,
fx_graph)
raise RuntimeError(
f"Object({id(value)}) {value} {type(value)} not found in object table")
f"Object({id(value)}) {value} {type(value)} not found in object table"
)

def get_or_none(self, value: Any) -> Optional[Variable]:
if id(value) in self.objs:
32 changes: 17 additions & 15 deletions frontend/utils.py
@@ -140,10 +140,13 @@ def get_root_module(func: Callable[..., Any]) -> str:
return 'numpy'

module = inspect.getmodule(func)
module = str(module).split('\'')[1]
if module is None or module in ('torch.distributions.bernoulli', 'torch.distributions.distribution'):
module_str = ""
if module is not None:
module_str = str(module).split('\'')[1]
if module is None or module_str in ('torch.distributions.bernoulli',
'torch.distributions.distribution'):
return ""
root_module = module.split('.')[0]
root_module = module_str.split('.')[0]
return root_module


@@ -167,9 +170,8 @@ def get_method_defined_class(cls: type[Any],

def is_user_defined_func(func: Callable[..., Any]) -> bool:
# print([(x, getattr(func, x)) for x in dir(func)])
if hasattr(func,
'__objclass__') and func.__objclass__ in (torch._C._TensorBase,
dict, str, collections.OrderedDict):
if hasattr(func, '__objclass__') and func.__objclass__ in (
torch._C._TensorBase, dict, str, collections.OrderedDict):
return False

# NOTE: random should be called as a UDF, not handled
@@ -198,8 +200,8 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
'inspect', 'collections'):
#NOTE:self.function should be recursive-checked to find out where it's defined, but not implemented
if hasattr(func, '__self__'
) and func.__self__ is not None and is_user_defined_func(
func.__self__):
) and func.__self__ is not None and is_user_defined_func(
func.__self__):
if is_own_method(func.__name__, func.__self__):
return True
else:
@@ -344,18 +346,18 @@ def __exit__(self, *args: Any) -> None:
@no_type_check
def is_namedtuple(obj: Any) -> bool:
    cls: type[Any] = obj if inspect.isclass(obj) else type(obj)
return (issubclass(cls, tuple)
and isinstance(getattr(cls, '_fields', None), tuple)
and all(isinstance(field, str) for field in cls._fields))
return (issubclass(cls, tuple) and
isinstance(getattr(cls, '_fields', None), tuple) and
all(isinstance(field, str) for field in cls._fields))


@no_type_check
def is_structseq(obj: Any) -> bool:
cls: type[Any] = obj if inspect.isclass(obj) else type(obj)
if (cls.__base__ is tuple
and isinstance(getattr(cls, 'n_sequence_fields', None), int)
and isinstance(getattr(cls, 'n_fields', None), int)
and isinstance(getattr(cls, 'n_unnamed_fields', None), int)):
if (cls.__base__ is tuple and
isinstance(getattr(cls, 'n_sequence_fields', None), int) and
isinstance(getattr(cls, 'n_fields', None), int) and
isinstance(getattr(cls, 'n_unnamed_fields', None), int)):
try:

class subcls(cls): # type: ignore[misc]
12 changes: 5 additions & 7 deletions frontend/variables/__init__.py
@@ -41,7 +41,7 @@
np.ndarray: NdarrayVar,
}

CONST_TYPES = Union[int, float, bool, str, NullObject, None, slice, type(Ellipsis)]
CONST_TYPES = Union[int, float, bool, str, NullObject, None, slice]


def make_var_from_value(
@@ -64,9 +64,8 @@ def make_var_from_value(
return ModuleVar.from_value(value, need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
elif callable(value):
return FunctionVar.from_value(value, need_guard_check,
helper_functions, fx_graph,
extract_code_at_start)
return FunctionVar.from_value(value, need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
elif isinstance(value, range):
return RangeVar.from_value(value, need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
@@ -89,9 +88,8 @@ def make_var_from_value(
helper_functions, fx_graph,
extract_code_at_start)
elif isinstance(value, type(Ellipsis)):
return EllipsisVar.from_value(value, need_guard_check,
helper_functions, fx_graph,
extract_code_at_start)
return EllipsisVar.from_value(value, need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
else:
        # NOTE: use any instead of iterator_var to represent iterator with unknown source due to the hardness of getting iterable and num_iters
print("generate any for", value, type(value), extract_code_at_start)