Merge remote-tracking branch 'origin/master' into cfv3
heheda12345 committed May 11, 2024
2 parents fb2d7ba + 88074bc commit b021011
Showing 15 changed files with 196 additions and 31 deletions.
6 changes: 6 additions & 0 deletions frontend/c_api.pyi
@@ -11,6 +11,12 @@ def set_eval_frame(
pass


def set_fallback(
new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]]
) -> Optional[Tuple[Callable[..., Any], Callable[..., Any]]]:
pass


def set_skip_files(skip_file: set[str], end_file: set[str]) -> None:
pass

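The new `set_fallback` stub mirrors the shape of `set_eval_frame`: it takes an optional pair of callables and is typed to return the previously installed pair (the C implementation later in this commit currently ignores its arguments, logs, and decrements the working-thread count). A pure-Python mock of the contract the stub describes, as an illustration only:

```python
from typing import Any, Callable, Optional, Tuple

FallbackPair = Optional[Tuple[Callable[..., Any], Callable[..., Any]]]
_fallback: FallbackPair = None

def set_fallback(new_callback: FallbackPair) -> FallbackPair:
    # install the new callback pair and return the old one so callers
    # can restore it later; None clears the fallback
    global _fallback
    old, _fallback = _fallback, new_callback
    return old

# usage: swap in a pair, then restore whatever was installed before
prev = set_fallback((lambda *a: True, lambda *a: None))
set_fallback(prev)
```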
2 changes: 2 additions & 0 deletions frontend/compile.py
@@ -93,3 +93,5 @@ def reset() -> None:
fx_graph.reset()
from . import dynamic
dynamic.reset()
from . import tracer
tracer.reset()
3 changes: 2 additions & 1 deletion frontend/config.py
@@ -5,7 +5,8 @@
"debug": True,
"miss_threshold": 3,
"dynshape": False,
"model_name": ""
"model_name": "",
"enable_fallback": False,
}


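The new `enable_fallback` flag is read through `get_config` elsewhere in this commit (see `pop_tracker` in `frontend/guard_tracker.py` below). A small sketch of how the flag relaxes the tracker bookkeeping checks, using the dict above and a `get_config` accessor like the one `guard_tracker.py` imports:

```python
config = {
    "debug": True,
    "miss_threshold": 3,
    "dynshape": False,
    "model_name": "",
    "enable_fallback": False,
}

def get_config(name: str):
    return config[name]

def pop_checks(frame_id_matches: bool, state_is_empty: bool) -> None:
    # mirrors the guarded asserts in pop_tracker: with fallback enabled,
    # a tracker may be popped mid-trace, so the strict checks are skipped
    if not get_config("enable_fallback"):
        assert frame_id_matches
        assert state_is_empty
```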
16 changes: 13 additions & 3 deletions frontend/csrc/frame_evaluation.cpp
@@ -242,9 +242,11 @@ inline static void enable_eval_frame_shim(PyThreadState *tstate) {
inline static void enable_eval_frame_default(PyThreadState *tstate) {
if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
previous_eval_frame) {
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
previous_eval_frame);
previous_eval_frame = NULL;
if (previous_eval_frame != NULL) {
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
previous_eval_frame);
previous_eval_frame = NULL;
}
}
}

@@ -290,6 +292,13 @@ static PyObject *set_eval_frame(PyObject *self, PyObject *args) {
return old_callback;
}

static PyObject *set_fallback(PyObject *self, PyObject *args) {
PyThreadState *tstate = PyThreadState_GET();
fprintf(stderr, "Falling back\n");
decrese_working_threads(tstate);
Py_RETURN_NONE;
}

// TODO: in a more elegant way
static PyObject *set_skip_files(PyObject *self, PyObject *args) {
if (skip_files != Py_None) {
@@ -659,6 +668,7 @@ static PyObject *mark_need_postprocess(PyObject *self, PyObject *args) {

static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame, METH_VARARGS, NULL},
{"set_fallback", set_fallback, METH_VARARGS, NULL},
{"set_skip_files", set_skip_files, METH_VARARGS, NULL},
{"set_null_object", set_null_object, METH_VARARGS, NULL},
{"set_miss_threshold", set_miss_threshold, METH_VARARGS, NULL},
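Two behaviors change in this file: restoring the default eval-frame hook is now a no-op unless a previous hook was actually saved, and the new `set_fallback` entry point drops the calling thread out of tracing via `decrese_working_threads` (spelling as in the source). A Python sketch of just the restore guard, for illustration; the authoritative logic is the C above:

```python
previous_eval_frame = None  # saved hook, or None when none is installed

def enable_eval_frame_default(current_hook):
    # mirror of the C guard above: only restore when a saved hook exists
    # and differs from the current one, so repeated calls are harmless
    global previous_eval_frame
    if current_hook is not previous_eval_frame and previous_eval_frame is not None:
        current_hook, previous_eval_frame = previous_eval_frame, None
    return current_hook
```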
116 changes: 99 additions & 17 deletions frontend/guard_tracker.py
@@ -19,7 +19,7 @@
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method, get_from_freevars, set_value_stack_from_top, parse_cell, set_local
from .instruction import Instruction, ci
from .cache import CachedGraph, get_frame_cache
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from .store_pos import StoreConstant, StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex, ExtractFromMethod, StoreInBuiltin, ExtractFromFunction, IterValue, StoreInFreeVar, ExtractFromNew, UnknownPosInCaller
from . import variables as vs
from . import dynamic as dyn
from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, fx_graph_inplace_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack, is_graph_func, get_root_module, torch_inplace_funcs, print_bytecode, get_method_defined_class, is_math_func, is_high_order_func_with_udf, is_high_order_func, math2torch
@@ -31,6 +31,7 @@
from .variables.const import ClsByNamedTupleVar
from .variables.base import Variable
from .control_flow import ControlFlowInfo, LoopModule, ForLoopInfo, LoopPosMap, if_stmt, IfStmtInfo
from .config import get_config

MAKE_VAR_FN_TYPE = Callable[[
Any, bool, vs.HelperFunctions, Optional[FxGraph], Optional[list[StorePos]]
@@ -286,7 +287,8 @@ def record_function(self,
if func in (min, max):
scalar = None
node = None
assert len(pargs) == 2
# NOTE: when pargs < 2, it should be a dynamic operation
assert len(pargs) <= 2
for i, obj in enumerate(pargs):
if isinstance(obj, (int, float)) and not dyn.contains(obj):
scalar = obj
@@ -310,7 +312,9 @@ def record_function(self,
func = math2torch[func]
if func == torch.from_numpy:
func = torch.tensor

if hasattr(func, '__name__') and func.__name__ == 'numpy':
if torch.is_tensor(args[0]) or dyn.contains(args[0]):
raise ValueError("numpy can't have dynamic args")
self.written = True
scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
float: torch.Tensor.float,
@@ -349,6 +353,9 @@ def record_function(self,
func = torch.Tensor.new_empty
elif func == torch.Tensor.item:
assert args[0].numel() == 1
if args[0].dtype == torch.bool:
raise ValueError(
"The .item() method was applied to a boolean tensor.")
func = torch.Tensor.clone

fx_node = self.fx_graph.create_node("call_method", func.__name__,
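The `.item()` rewrite above keeps scalar reads inside the FX graph by cloning the one-element tensor instead of materializing a Python scalar; boolean tensors are rejected because their value would immediately force data-dependent control flow. A standalone sketch of the rewritten behavior:

```python
import torch

def traced_item(t: torch.Tensor) -> torch.Tensor:
    # the tracer's stand-in for Tensor.item(): keep the value as a
    # one-element tensor in the graph rather than a Python scalar
    assert t.numel() == 1
    if t.dtype == torch.bool:
        raise ValueError("The .item() method was applied to a boolean tensor.")
    return torch.Tensor.clone(t)

print(traced_item(torch.tensor([3.5])))  # tensor([3.5000])
```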
@@ -465,6 +472,8 @@ def store_pos_in_caller(self, pos: StorePos,
raise ValueError("cannot store in stack in callee")
elif isinstance(pos, (StoreInGlobal, StoreInBuiltin, StoreInFreeVar)):
return pos
elif isinstance(pos, StoreConstant):
return pos
elif isinstance(pos, StoreInAttr):
# print("in callee", pos, self.frame_id)
parent_pos = self.store_pos_in_caller(pos.self_pos, pos.self_id)
@@ -486,7 +495,12 @@
for p, i in zip(pos.var_pos, pos.var_id):
new_pos = self.store_pos_in_caller(p, i)
if new_pos is None:
return None
if isinstance(
p,
StoreConstant): # allow constant function parameter
new_pos = p
else:
return None
parent_poses.append(new_pos)
return ExtractFromFunction(parent_poses, pos.var_id, pos.func_name,
pos.func_obj, pos.need_add_to_fn)
@@ -839,6 +853,7 @@ class GuardTracker:
caller: Optional['GuardTracker']
cf_info: Optional[ControlFlowInfo]
num_breaks: int
layout_sensitive: bool

def __init__(self,
frame: FrameType,
@@ -876,6 +891,7 @@ def __init__(self,
read_stack=read_stack, frame_cf_info=cf_info
) # stack pointer is not initialized at the creation of a stack frame
self.num_breaks = 0
self.layout_sensitive = False

def init_state(self,
read_stack: bool = True,
@@ -904,6 +920,9 @@ def record(
restart_caller=False)
if self.code.get_inst(self.frame.f_lasti).opname == 'RETURN_VALUE':
if trackers[-1] == self:
if self.layout_sensitive == True:
if self.caller is not None:
self.caller.layout_sensitive = True
pop_tracker(self.frame_id)
set_eval_frame(None)
return
@@ -956,6 +975,8 @@ def record(
def commit_loop_subgraph(self) -> None:
key = new_random_key()
guard_codegen = GuardFnCodegen(key=key)
if self.layout_sensitive == True:
guard_codegen.layout_sensitive = True
for var in self.state.objects.get_all():
while var.prev is not None:
var = var.prev
@@ -1176,6 +1197,8 @@ def commit(self) -> None:
if self.state.can_guard:
key = new_random_key()
guard_codegen = GuardFnCodegen(key=key)
if self.layout_sensitive == True:
guard_codegen.layout_sensitive = True
for var in self.state.objects.get_all():
while var.prev is not None:
var = var.prev
@@ -1557,7 +1580,9 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:

def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type, all, str.join, reversed, zip, iter, id, next)
type, all, str.join, reversed, zip, iter, id, next,
collections.OrderedDict, str.format, any, str,
str.split, sorted)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
@@ -1615,11 +1640,46 @@ def call_function(
self.state.fx_graph, [pos])
self.state.add_object(var, obj)
return
if hasattr(func,
'__name__') and func.__name__ == 'format' and isinstance(
func, type(str.format)):
for arg in args:
if torch.is_tensor(arg) or dyn.contains(arg):
raise ValueError("format can't have dynamic args")
if hasattr(func, '__name__') and (func.__name__ == 'is_contiguous' or
func.__name__ == 'stride'):
self.layout_sensitive = True
if hasattr(func, '__name__') and func.__name__ == '__init__':
return
# a series of classes and functions defined by warnings
if get_root_module(func) in ('_warnings', 'warnings'):
return
if get_root_module(func) == 'random':
for arg in args:
if torch.is_tensor(arg) or dyn.contains(arg):
raise ValueError("random func can't have dynamic args")
if func.__name__ not in {
'random', 'randint', 'randrange', 'uniform'
}:
raise ValueError("Not implement random func")

name = new_name('random')
fx_node = self.state.fx_graph.create_input(torch.tensor([0]), name,
(), {}, name)
self.state.set_partial_var({
-1: [
PartialVar(
node=fx_node,
need_guard_check=False,
extract_code_at_start=[
ExtractFromFunction(
[StoreConstant(arg, id(arg)) for arg in args],
[id(arg) for arg in args], func.__name__, func,
True)
])
]
})
return
is_high_order_udf = is_high_order_func_with_udf(func, args, kwargs)
if is_user_defined_func(func) or isinstance(
func, nn.Sequential) or is_high_order_udf:
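The `random` branch above is the most involved addition: instead of freezing a random value into the graph as a constant, the call becomes a graph input whose value is recomputed on every cached execution by re-invoking the function with its recorded constant arguments (the `StoreConstant` positions defined in `frontend/store_pos.py` below). A hypothetical standalone mock of that idea:

```python
import random
import torch

def extract_from_function(func, const_args):
    # what an ExtractFromFunction position conceptually evaluates to at
    # input-extraction time: call the function again with the constant
    # arguments captured during tracing
    return func(*const_args)

# tracing records random.uniform with StoreConstant positions for 0.0 and
# 1.0; each cached-graph run then feeds a fresh value in as a graph input
fresh = extract_from_function(random.uniform, (0.0, 1.0))
graph_input = torch.tensor([fresh])
print(graph_input)
```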
@@ -1753,7 +1813,11 @@ def set_if_inplace_return() -> None:
"flatten_parameters", "numel", "children",
"named_parameters", "_weights_have_changed",
"check_forward_args", "permute_hidden", "_check_input_dim",
"parameters"):
"parameters", "_has_torch_function_unary", "_is_tracing",
"is_tracing", "is_scripting", "get_autocast_gpu_dtype",
"is_autocast_enabled", "ndimension", "get_enum",
"is_tensor", "is_complex", "is_contiguous", "stride",
"get_device"):
return
if hasattr(func, "__module__"
) and func.__module__ == 'torch.autograd.profiler':
@@ -1843,6 +1907,18 @@ def set_if_inplace_return() -> None:
# TODO: add map and set correct partial var
return
elif is_graph_func(func):
if func is operator.getitem:
obj_var = self.state.objects.get(args[0])
assert obj_var.extract_code_at_start[0]
obj_pos = obj_var.extract_code_at_start[0]
item_pos = StoreInIndex(obj_pos, id(obj_pos), args[1])
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[item_pos])
]
})
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
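For `operator.getitem`, the result's position is now derived from the container's position via `StoreInIndex`, so later guards can re-extract `obj[idx]` without replaying the trace. A conceptual mock (names mirror the commit, behavior is illustrative only):

```python
class ConstPos:
    # stand-in for a recorded container position
    def __init__(self, value):
        self.value = value
    def get_value_from_frame(self, frame):
        return self.value

class StoreInIndexMock:
    def __init__(self, parent_pos, index):
        self.parent_pos = parent_pos  # position of the container object
        self.index = index            # the subscript used in getitem
    def get_value_from_frame(self, frame):
        # re-extract the item through the parent's position
        return self.parent_pos.get_value_from_frame(frame)[self.index]

pos = StoreInIndexMock(ConstPos([10, 20, 30]), 1)
print(pos.get_value_from_frame(None))  # 20
```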
@@ -1901,6 +1977,13 @@ def gen_by_caller(self,
caller = caller.caller
return False

def generic_jump_check(self) -> None:
top_value = get_value_stack_from_top(self.frame, 0)
if torch.is_tensor(top_value):
raise ValueError("generic_jump TensorVariable() by tensor")
if dyn.contains(top_value):
raise ValueError("generic_jump TensorVariable() by dyn scalar")

def binary_operation(self, func: Callable[..., Any]) -> None:
obj1 = get_value_stack_from_top(self.frame, 1)
obj2 = get_value_stack_from_top(self.frame, 0)
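`generic_jump_check` is wired into the four conditional-jump opcodes further down; it rejects branches whose condition is a tensor or a dynamic scalar, since which branch runs would then depend on runtime data that a single cached trace cannot guard. The pattern it rejects:

```python
import torch

x = torch.tensor([1.0])
# eager Python happily branches on a one-element tensor, but during
# tracing this POP_JUMP_IF on a tensor raises
# "generic_jump TensorVariable() by tensor"
if x > 0:
    print("took the data-dependent branch")
```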
@@ -1948,6 +2031,9 @@ def BINARY_OR(self, _inst: Instruction) -> None:
def BINARY_SUBSCR(self, inst: Instruction) -> None:
obj1 = get_value_stack_from_top(self.frame, 1)
obj2 = get_value_stack_from_top(self.frame, 0)
if torch.is_tensor(obj1):
if torch.is_tensor(obj2) and obj2.dtype == torch.bool:
raise ValueError("dynamic shape in tensor")
self.call_function(operator.getitem, [obj1, obj2], {})

def unary_operation(self, func: Callable[..., Any]) -> None:
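The `BINARY_SUBSCR` guard above targets boolean-mask indexing, where the result's shape depends on how many mask entries are true, i.e. a dynamic shape the static graph cannot represent:

```python
import torch

t = torch.arange(4)
mask = t > 1          # tensor([False, False,  True,  True])
# rejected while tracing: the output length (here 2) is only known at
# runtime, so a cached graph would be shape-incorrect for other data
print(t[mask])        # tensor([2, 3])
```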
@@ -2423,11 +2509,6 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
# ]
# })
# pass
print("check data", seq, type(seq))
if self.state.objects.contains(seq):
print("jjjjjj")
for i in seq:
print(i)
raise NotImplementedError

def UNPACK_EX(self, inst: Instruction) -> None:
@@ -2481,16 +2562,16 @@ def DUP_TOP_TWO(self, _inst: Instruction) -> None:
pass

def POP_JUMP_IF_FALSE(self, _inst: Instruction) -> None:
pass
self.generic_jump_check()

def POP_JUMP_IF_TRUE(self, _inst: Instruction) -> None:
pass
self.generic_jump_check()

def JUMP_IF_TRUE_OR_POP(self, _inst: Instruction) -> None:
pass
self.generic_jump_check()

def JUMP_IF_FALSE_OR_POP(self, _inst: Instruction) -> None:
pass
self.generic_jump_check()

def JUMP_FORWARD(self, inst: Instruction) -> None:
pass
@@ -2701,8 +2782,9 @@ def pop_tracker(frame_id: int) -> None:
print("before pop_tracker", [t.frame_id for t in trackers], "frame_id",
frame_id)
to_pop = trackers.pop()
assert to_pop.frame_id == frame_id
assert to_pop.state.is_empty
if not get_config("enable_fallback"):
assert to_pop.frame_id == frame_id
assert to_pop.state.is_empty


def record(frame: FrameType, frame_id: int) -> None:
2 changes: 2 additions & 0 deletions frontend/pycode_generator.py
@@ -141,12 +141,14 @@ class GuardFnCodegen(FnCodegen):
checks: set[tuple[str, StorePos]]
imports: set[str]
object_refs: list[Any] # the reference to objects for id check
layout_sensitive: bool

def __init__(self, key: int) -> None:
super().__init__(key)
self.checks = set()
self.imports = set()
self.object_refs = []
self.layout_sensitive = False

def add_check(self, check: tuple[str, StorePos]) -> None:
self.checks.add(check)
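`layout_sensitive` threads through the whole commit: `call_function` sets it when traced code calls `.is_contiguous()` or `.stride()`, `record` propagates it to the caller's tracker on RETURN_VALUE, and both commit paths copy it onto `GuardFnCodegen` here, presumably so the generated guard can also check memory layout. A compact mock of the propagation (names from the commit, logic simplified):

```python
class TrackerMock:
    def __init__(self, caller=None):
        self.caller = caller
        self.layout_sensitive = False

    def on_return(self):
        # mirrors record(): a layout-sensitive callee taints its caller
        if self.layout_sensitive and self.caller is not None:
            self.caller.layout_sensitive = True

outer = TrackerMock()
inner = TrackerMock(caller=outer)
inner.layout_sensitive = True   # e.g. traced code called x.stride()
inner.on_return()
assert outer.layout_sensitive   # the flag reaches the outer guard codegen
```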
19 changes: 18 additions & 1 deletion frontend/store_pos.py
@@ -1,6 +1,8 @@
from typing import Any, Optional, TYPE_CHECKING, Callable
from typing import Any, Optional, TYPE_CHECKING, Callable, Union
from types import FrameType

from torch import Tensor

from .c_api import get_value_stack_from_top
if TYPE_CHECKING:
from .pycode_generator import FnCodegen
@@ -41,6 +43,21 @@ def get_value_from_frame(self, frame: FrameType) -> Any:
return frame.f_locals[self.name]


class StoreConstant(StorePos):
value: Union[int, float]
self_id: int

def __init__(self, value: Union[int, float], self_id: int) -> None:
self.value = value
self.self_id = self_id

def __repr__(self) -> str:
return str(self.value)

def get_value_from_frame(self, frame: FrameType) -> Any:
return self.value


class StoreInGlobal(StorePos):
name: str

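`StoreConstant` is the piece the random-call handling relies on: a position that carries a literal value and ignores the frame entirely. A quick usage sketch, assuming the class definition shown above:

```python
arg = 3.14
pos = StoreConstant(arg, id(arg))      # as built in guard_tracker.py
print(repr(pos))                       # 3.14
print(pos.get_value_from_frame(None))  # 3.14; the frame is never consulted
```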