support blockdrop.forward_full
heheda12345 committed Oct 6, 2023
1 parent c315165 commit cd34953
Showing 5 changed files with 330 additions and 40 deletions.
90 changes: 55 additions & 35 deletions frontend/guard_tracker.py
@@ -10,6 +10,7 @@
import traceback
import copy
import dataclasses
import torch.fx.immutable_collections as fx_immutable
from .code import ProcessedCode
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method
from .instruction import Instruction, ci
@@ -103,20 +104,18 @@ def as_node_args_kwargs(
self, args: list[Any], kwargs: dict[str, Any]
) -> tuple[tuple[torch.fx.Node, ...], dict[str, torch.fx.Node]]:

def as_fx_node(var: vs.Variable) -> NodeArgs:
def as_fx_node(arg: Any) -> NodeArgs:
if isinstance(arg, (tuple, list)):
return fx_immutable.immutable_list([as_fx_node(x) for x in arg])
var = self.objects.get(arg, allow_unexist_const=True)
if isinstance(var, vs.TorchParamVar):
return self.fx_graph.create_node("get_attr",
self.subparam_paths[var.obj],
(), {})
return var.as_fx_node()

node_args = tuple(
as_fx_node(self.objects.get(arg, allow_unexist_const=True))
for arg in args)
node_kwargs = {
key: as_fx_node(self.objects.get(arg, allow_unexist_const=True))
for key, arg in kwargs.items()
}
node_args = tuple(as_fx_node(arg) for arg in args)
node_kwargs = {key: as_fx_node(arg) for key, arg in kwargs.items()}

return node_args, node_kwargs
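The reworked as_fx_node recurses into tuples and lists, so nested container arguments are converted element by element into an fx immutable_list instead of being looked up as a single constant. Below is a minimal stand-alone sketch of that recursion; to_node is an illustrative stand-in for the tracker's objects.get(...).as_fx_node() lookup and is not part of the real code.

import torch.fx.immutable_collections as fx_immutable

def to_node(arg):
    # Recurse first so every element of a nested container is converted.
    if isinstance(arg, (tuple, list)):
        return fx_immutable.immutable_list([to_node(x) for x in arg])
    # Stand-in for self.objects.get(arg, ...).as_fx_node(): leaves pass through.
    return arg

print(to_node([1, (2, 3)]))  # a nested immutable_list mirroring the input structure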

@@ -195,9 +194,12 @@ def add_stored_globals(self, name: str) -> None:
self.written = True
self.stored_globals.add(name)

def store_pos_in_callee(self, pos: StorePos, idx: int) -> StorePos:
def store_pos_in_callee(self, pos: StorePos,
idx: int) -> Optional[StorePos]:
if idx in self.objects.objs:
var = self.objects.objs[idx]
if len(var.extract_code_at_start) == 0:
return None
return var.extract_code_at_start[0]
if isinstance(pos, StoreInLocal):
raise ValueError("unknown local in callee", pos)
@@ -206,21 +208,25 @@ def store_pos_in_callee(self, pos: StorePos, idx: int) -> StorePos:
elif isinstance(pos, (StoreInGlobal, StoreInBuiltin)):
return pos
elif isinstance(pos, StoreInAttr):
return StoreInAttr(
self.store_pos_in_callee(pos.self_pos, pos.self_id),
pos.self_id, pos.attr_name)
parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id)
if parent_pos is None:
return None
return StoreInAttr(parent_pos, pos.self_id, pos.attr_name)
elif isinstance(pos, StoreInIndex):
return StoreInIndex(
self.store_pos_in_callee(pos.self_pos, pos.self_id),
pos.self_id, pos.self_index)
parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id)
if parent_pos is None:
return None
return StoreInIndex(parent_pos, pos.self_id, pos.self_index)
elif isinstance(pos, ExtractFromMethod):
return ExtractFromMethod(
self.store_pos_in_callee(pos.self_pos, pos.self_id),
pos.self_id, pos.method_name)
parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id)
if parent_pos is None:
return None
return ExtractFromMethod(parent_pos, pos.self_id, pos.method_name)
elif isinstance(pos, ExtractFromFunction):
return ExtractFromFunction(
self.store_pos_in_callee(pos.var_pos, pos.var_id), pos.var_id,
pos.func_name)
parent_pos = self.store_pos_in_callee(pos.var_pos, pos.var_id)
if parent_pos is None:
return None
return ExtractFromFunction(parent_pos, pos.var_id, pos.func_name)
else:
raise NotImplementedError
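store_pos_in_callee can now return None when a variable has no position that is still reachable from the callee frame, and every recursive branch above propagates that None instead of wrapping it. A hedged sketch of the propagation pattern with simplified stand-in classes (Attr plays the role of StoreInAttr; lookup replaces the object-table query):

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class Attr:
    parent: Any
    name: str

def translate(pos: Any, lookup: Callable[[Any], Optional[Any]]) -> Optional[Any]:
    if isinstance(pos, Attr):
        parent = translate(pos.parent, lookup)
        if parent is None:
            return None  # drop the whole chain rather than build Attr(None, ...)
        return Attr(parent, pos.name)
    return lookup(pos)

print(translate(Attr("self", "weight"), lambda p: None))  # None: parent not recoverable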

@@ -246,11 +252,13 @@ def merge_call_guard() -> None:
elif isinstance(pos, StoreInGlobal):
new_var.extract_code_at_start.append(pos)
elif isinstance(pos, StoreInAttr):
new_var.extract_code_at_start.append(
self.store_pos_in_callee(pos, idx))
self_pos = self.store_pos_in_callee(pos, idx)
assert self_pos is not None
new_var.extract_code_at_start.append(self_pos)
elif isinstance(pos, StoreInIndex):
new_var.extract_code_at_start.append(
self.store_pos_in_callee(pos, idx))
self_pos = self.store_pos_in_callee(pos, idx)
assert self_pos is not None
new_var.extract_code_at_start.append(self_pos)
else:
raise NotImplementedError(pos)

@@ -387,25 +395,29 @@ def get_or_make_var(
merged_ids.add(id(obj))
return new_var

def get_new_store_pos(old: list[StorePos],
idx: int) -> list[StorePos]:
new: list[StorePos] = []
for pos in old:
new_pos = self.store_pos_in_callee(pos, idx)
if new_pos is not None:
new.append(new_pos)
return new

for idx, var in state.objects.get_all_with_id():
if var.prev is not None:
oldest = var.get_oldest_var()
if len(oldest.extract_code_at_start) == 0:
continue
new_extract = [
self.store_pos_in_callee(pos, idx)
for pos in oldest.extract_code_at_start
]
new_extract: list[StorePos] = get_new_store_pos(
oldest.extract_code_at_start, idx)
get_or_make_var(var.obj, var.need_guard_check,
self.fx_graph, new_extract)

var = state.objects.get(return_value, allow_unexist_const=True)
new_extract = [
self.store_pos_in_callee(pos, id(var.obj))
for pos in var.extract_code_at_start
]
get_or_make_var(return_value, var.need_guard_check, self.fx_graph,
new_extract)
new_extract = get_new_store_pos(var.extract_code_at_start,
id(var.obj))
get_or_make_var(return_value, False, self.fx_graph, new_extract)

if len(state.objects.objs_no_id) > 0:
raise NotImplementedError
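The new get_new_store_pos helper is essentially a filter around store_pos_in_callee: positions that cannot be translated are dropped instead of appearing as None entries in extract_code_at_start. A tiny stand-alone sketch of the same filtering, with translate as a placeholder for store_pos_in_callee:

from typing import Any, Callable, Optional

def keep_translatable(old: list, translate: Callable[[Any], Optional[Any]]) -> list:
    # Mirror of get_new_store_pos: translate each position, keep only non-None results.
    return [p for p in (translate(pos) for pos in old) if p is not None]

print(keep_translatable(["a", "b", "c"], lambda p: p.upper() if p != "b" else None))  # ['A', 'C']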
@@ -745,6 +757,7 @@ def call_function(
]
})
if is_user_defined_func(func) or isinstance(func, torch.nn.Sequential):
print("run into user defined function")
stack_objs = get_all_objects_in_stack(self.frame)
self.state.mark_defer_restart(f"call_function", stack_objs)
from .tracer import get_process_frame
@@ -860,6 +873,13 @@ def BINARY_SUBSCR(self, inst: Instruction) -> None:
obj2 = get_value_stack_from_top(self.frame, 0)
self.call_function(operator.getitem, [obj1, obj2], {})

def COMPARE_OP(self, inst: Instruction) -> None:
obj1 = get_value_stack_from_top(self.frame, 1)
obj2 = get_value_stack_from_top(self.frame, 0)
cmp_op = ('lt', 'le', 'eq', 'ne', 'gt', 'ge')
self.call_function(getattr(operator, cmp_op[inst.arg]), [obj1, obj2],
{})

def INPLACE_POWER(self, _inst: Instruction) -> None:
self.binary_operation(operator.ipow)
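The new COMPARE_OP handler relies on the comparison opcode's oparg being an index into CPython's cmp_op table, which on Python 3.9/3.10 is ('<', '<=', '==', '!=', '>', '>='), so the same index selects the matching operator.lt/le/eq/ne/gt/ge function. A small illustration of that index-to-function mapping (assumes an interpreter where the first six entries of dis.cmp_op are these comparisons):

import dis
import operator

cmp_op = ('lt', 'le', 'eq', 'ne', 'gt', 'ge')
for i, sym in enumerate(dis.cmp_op[:6]):
    # e.g. index 0 -> '<' -> operator.lt
    print(i, sym, getattr(operator, cmp_op[i]))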

2 changes: 1 addition & 1 deletion frontend/object_table.py
@@ -24,7 +24,7 @@ def add(self, var: Variable, value: Any) -> None:
old_var.extract_code_at_start.extend(var.extract_code_at_start)
old_var.need_guard_check |= var.need_guard_check
else:
self.objs[id(value)] = var
self.add_by_id(var, id(value))
var.add_subvars_to_table(self)

def add_by_id(self, var: Variable, idx: int) -> None:
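The object_table change routes add() through add_by_id() rather than assigning self.objs[id(value)] directly, so both entry points share one registration path. A generic sketch of that pattern (the real add_by_id may do additional bookkeeping not shown here):

class Table:
    def __init__(self) -> None:
        self.objs: dict[int, object] = {}

    def add_by_id(self, var: object, idx: int) -> None:
        # Single registration path: extra bookkeeping added here applies everywhere.
        self.objs[idx] = var

    def add(self, var: object, value: object) -> None:
        # Instead of writing self.objs[id(value)] = var directly.
        self.add_by_id(var, id(value))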
4 changes: 1 addition & 3 deletions frontend/utils.py
@@ -107,16 +107,14 @@ def get_root_module(func: Callable[..., Any]) -> str:


def is_user_defined_func(func: Callable[..., Any]) -> bool:
if func in fx_graph_functions:
return False
if hasattr(func,
'__objclass__') and func.__objclass__ == torch._C._TensorBase:
return False

root_module = get_root_module(func)
if root_module == '':
return True
if root_module in ('math', 'builtins', 'torch', 'numpy'):
if root_module in ('math', 'builtins', 'torch', 'numpy', '_operator'):
return False
return True
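Adding '_operator' to the allow-list matters because the new COMPARE_OP handler dispatches through operator.lt, operator.eq, and friends, which are implemented in the C module _operator; without the entry they would be classified as user-defined functions and deferred. Assuming get_root_module derives its answer from the function's __module__, a quick check on CPython shows why:

import operator

print(operator.lt.__module__)  # '_operator' on CPython
print(type(operator.lt))       # <class 'builtin_function_or_method'>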
