Commit 7a4e62e: Parity 1 (apache#31)

Authored by superDong1998 on Mar 1, 2024
2 parents: 38df973 + 559e810

Showing 13 changed files with 641 additions and 55 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -35,5 +35,5 @@ jobs:
           spack load [email protected] /jb4mlxg
           spack load [email protected]%gcc@=11.3.0
           source ~/venv/frontend-env/bin/activate
-          srun --gres=gpu:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
-          FORCE_RUN_SKIPPED_TEST=1 srun --gres=gpu:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
+          srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test
+          FORCE_RUN_SKIPPED_TEST=1 srun --gres=gpu:v100:1 --exclusive ./scripts/pytest_with_preload.sh -vs test/test_model_blockdrop.py -k test_blockdrop_dyn
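Note: in Slurm's --gres syntax, gpu:1 requests any single GPU, while gpu:v100:1 requests one GPU of type v100, so this change pins both test invocations to the same GPU model.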
351 changes: 310 additions & 41 deletions frontend/guard_tracker.py

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions frontend/object_table.py
@@ -8,6 +8,7 @@
 from .store_pos import StorePos
 from .fx_graph import FxGraph
 import numpy as np
+import torch


 class ObjectTable:
@@ -65,10 +66,14 @@ def get(self,
             fx_graph: Optional[FxGraph] = None) -> Variable:
         if id(value) in self.objs:
             return self.objs[id(value)]
+        elif value is None:
+            return make_var_from_value(value, False, self.helper_functions,
+                                       fx_graph)
         elif allow_unexist_const:
             if isinstance(value, get_args(CONST_TYPES)) or isinstance(
-                    value, (list, tuple, set, dict, range, CodeType,
-                            type(Ellipsis), np.ndarray)):
+                    value,
+                    (list, tuple, set, dict, range, CodeType, type(Ellipsis),
+                     np.ndarray, frozenset, torch.nn.Parameter)):
                 return make_var_from_value(value, False, self.helper_functions,
                                            fx_graph)
         raise RuntimeError(
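For readers skimming the diff, here is a minimal standalone sketch (not the repository's code; CONST_TYPES, FALLBACK_TYPES, and the return values are simplified stand-ins) of the lookup order the patched ObjectTable.get implements: identity hit first, then the new fast path for None, then the widened constant fallback that now also accepts frozenset and torch.nn.Parameter.

from typing import Any

CONST_TYPES = (int, float, bool, str)  # stand-in for the real CONST_TYPES alias
FALLBACK_TYPES = (list, tuple, set, dict, range, frozenset)  # trimmed for the sketch

class MiniObjectTable:
    def __init__(self) -> None:
        self.objs: dict[int, Any] = {}  # id(value) -> tracked variable

    def get(self, value: Any, allow_unexist_const: bool = False) -> Any:
        if id(value) in self.objs:          # 1. already tracked: reuse it
            return self.objs[id(value)]
        if value is None:                   # 2. new fast path: None never needs an entry
            return ("const_var", None)
        if allow_unexist_const and isinstance(value, CONST_TYPES + FALLBACK_TYPES):
            return ("const_var", value)     # 3. build a throwaway const variable
        raise RuntimeError(f"unknown object: {value!r}")

table = MiniObjectTable()
print(table.get(None))                                         # ('const_var', None)
print(table.get(frozenset({1, 2}), allow_unexist_const=True))  # ('const_var', frozenset({1, 2}))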
9 changes: 9 additions & 0 deletions frontend/pycode_generator.py
@@ -18,13 +18,15 @@ class FnCodegen:
     imports: set[str]
     key: int
     objs: dict[str, Any]  # name -> obj
+    statements: set[str]

     def __init__(self, key: int) -> None:
         self.key = key
         self.prepare_var_writer = PyCodeWriter()
         self.writer = PyCodeWriter()
         self.imports = set()
         self.objs = {}
+        self.statements = set()

     def add_obj(self, obj: Any, name: str = "", force: bool = False) -> str:
         if force:
@@ -56,6 +58,9 @@ def add_stmt(self, stmt: str, is_prepare: bool = False) -> None:
         else:
             self.writer.wl(stmt)

+    def add_statements(self, stmt: str) -> None:
+        self.statements.add(stmt)
+

 class GraphFnCodegen(FnCodegen):
     returns: list[Tuple[str, StorePos]]
@@ -88,6 +93,8 @@ def get_code(self) -> str:
         gen_imports(writer, self.imports)
         writer.wl(f"def fn(locals):")
         writer.block_start()
+        for stmt in self.statements:
+            writer.wl(stmt)
         if get_config('debug'):
             writer.wl(
                 f"print('running graph_fn (key = {self.key})', locals.keys())")
@@ -157,6 +164,8 @@ def get_code(self) -> str:
         writer.block_start()
         writer.write(f"try:")
         writer.block_start()
+        for stmt in self.statements:
+            writer.write(stmt)
         if get_config('debug'):
             writer.wl(
                 f"print('running guard_fn (key = {self.key})', locals.keys())")
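As a hedged illustration of the new mechanism (the class and method names below are invented for the sketch; only add_statements and the statements set mirror the diff): statements registered through add_statements land in a set, so the same setup line can be requested many times but is emitted once, at the top of the generated function, ahead of the normal body. A design consequence of using a set is that emission order is arbitrary, so registered statements must be independent of one another.

class MiniFnCodegen:
    """Sketch of FnCodegen's statement handling; not the real class."""

    def __init__(self) -> None:
        self.statements: set[str] = set()  # deduplicated setup statements
        self.body: list[str] = []

    def add_statements(self, stmt: str) -> None:
        self.statements.add(stmt)          # set semantics: duplicates collapse

    def add_stmt(self, stmt: str) -> None:
        self.body.append(stmt)

    def get_code(self) -> str:
        lines = ["def fn(locals):"]
        for stmt in self.statements:       # emitted before the main body
            lines.append(f"    {stmt}")
        for stmt in self.body:
            lines.append(f"    {stmt}")
        return "\n".join(lines)

gen = MiniFnCodegen()
gen.add_statements("funcs = namedtuple('Pair', ['x', 'y'])")
gen.add_statements("funcs = namedtuple('Pair', ['x', 'y'])")  # collapses into one line
gen.add_stmt("return funcs(locals['a'], locals['b'])")
print(gen.get_code())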
24 changes: 17 additions & 7 deletions frontend/utils.py
@@ -134,25 +134,26 @@ def is_call_bytecode(inst: 'Instruction') -> bool:


 def get_root_module(func: Callable[..., Any]) -> str:
+    import numpy as np
     if hasattr(func, '__objclass__'):
         if func.__objclass__ == torch._C._TensorBase:
             return 'torch'
         elif func.__objclass__ in (list, tuple, set, dict, str):
             return 'builtins'
+        elif func.__objclass__ == np.ndarray:
+            return 'numpy'

     if hasattr(func, '__self__') and isinstance(func.__self__, torch.Tensor):
         return 'torch'

-    import numpy as np
     if hasattr(func, '__class__') and func.__class__ == np.ufunc:
         return 'numpy'

     module = inspect.getmodule(func)
     module_str = ""
     if module is not None:
         module_str = str(module).split('\'')[1]
-    if module is None or module_str in ('torch.distributions.bernoulli',
-                                        'torch.distributions.distribution'):
+    if module is None or 'torch.distributions' in module_str:
         return ""
     root_module = module_str.split('.')[0]
     return root_module
@@ -178,11 +179,14 @@ def get_method_defined_class(cls: type[Any],

 def is_user_defined_func(func: Callable[..., Any]) -> bool:
     # print([(x, getattr(func, x)) for x in dir(func)])
+    import numpy
     if hasattr(func, '__objclass__') and func.__objclass__ in (
-            torch._C._TensorBase, dict, list, str, collections.OrderedDict):
+            torch._C._TensorBase, dict, list, str, collections.OrderedDict,
+            numpy.ndarray):
         return False
     if hasattr(func, '__class__') and func.__class__ in (
-            torch._C._TensorBase, dict, list, str, collections.OrderedDict):
+            torch._C._TensorBase, dict, list, str, collections.OrderedDict,
+            numpy.ndarray):
         return False

     # NOTE: random should be called as a UDF, not handled
@@ -194,11 +198,13 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
     elif isinstance(func.__self__, torch.nn.Sequential):
         return True

-    if hasattr(func, '__name__') and func.__name__ == '<genexpr>':
+    if hasattr(func, '__name__') and func.__name__ in ('<genexpr>', 'numel'):
         return False
     if hasattr(func, '__name__') and func.__name__ == '_conv_forward':
         return True

+    if hasattr(func, '__name__') and func.__name__ == 'forward':
+        return True
     if hasattr(func, '__name__') and func.__name__ == 'apply':
         assert hasattr(func, '__self__')
         return is_user_defined_func(func.__self__)
@@ -213,8 +219,12 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
         return False

     root_module = get_root_module(func)
+    if root_module == 'torch' and hasattr(
+            func, '__name__') and func.__name__ == '_call_impl':
+        return True
     if root_module in ('math', 'builtins', 'torch', 'numpy', '_operator',
-                       'inspect', 'collections'):
+                       'inspect', 'collections', 'itertools', 'functools',
+                       'copy'):
         #NOTE:self.function should be recursive-checked to find out where it's defined, but not implemented
         if hasattr(func, '__self__'
                    ) and func.__self__ is not None and is_user_defined_func(
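A condensed sketch of the module-resolution idea behind get_root_module (this version uses module.__name__ instead of parsing str(module), and startswith instead of the patch's substring test; the behavior is equivalent for ordinary module names):

import inspect
from typing import Any, Callable

def root_module_of(func: Callable[..., Any]) -> str:
    module = inspect.getmodule(func)
    if module is None:
        return ""
    name = module.__name__                   # e.g. "torch.nn.functional"
    if name.startswith("torch.distributions"):
        return ""                            # treat all of torch.distributions as opaque
    return name.split(".")[0]                # root package, e.g. "torch"

import math
print(root_module_of(math.sqrt))             # "math"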
6 changes: 4 additions & 2 deletions frontend/variables/__init__.py
@@ -5,13 +5,13 @@
 from collections import OrderedDict
 from .base import Variable, HelperFunctions
 from .scalar import ScalarVar, NumpyScalarVar
-from .tensor import TensorVar, TorchParamVar, TorchSizeVar, TorchDtypeVar, TorchDeviceVar
+from .tensor import TensorVar, TorchParamVar, TorchSizeVar, TorchDtypeVar, TorchDeviceVar, TorchLayoutVar
 from .torch_module import TorchModuleVar, TorchSequentialVar, TorchModuleListVar
 from .any_ import AnyVar
 from .const import NullVar, NoneVar, SliceVar, ModuleVar, FunctionVar, RangeVar, CodeVar, EllipsisVar
 from .iterator import IteratorVar, RangeIterVar
 from .tuple_ import TupleVar
-from .set_ import SetVar
+from .set_ import SetVar, FrozensetVar
 from .list_ import ListVar, NdarrayVar
 from .dict_ import DictVar, OrderedDictVar
 from .builtin_types import CellVar, MappingProxyVar
@@ -32,9 +32,11 @@
     tuple: TupleVar,
     list: ListVar,
     set: SetVar,
+    frozenset: FrozensetVar,
     torch.Size: TorchSizeVar,
     torch.dtype: TorchDtypeVar,
     torch.device: TorchDeviceVar,
+    torch.layout: TorchLayoutVar,
     dict: DictVar,
     CodeType: CodeVar,
     OrderedDict: OrderedDictVar,
63 changes: 63 additions & 0 deletions frontend/variables/const.py
@@ -283,3 +283,66 @@ def from_value(cls, value: range, need_guard_check: bool,

     def as_fx_node(self) -> NodeArgs:
         return range(self.start, self.stop, self.step)
+
+
+class ClsByNamedTupleVar(Variable):
+    cls_name: str
+    cls_attr: list[str]
+    obj: Any
+    obj_class: Any
+    attr_value: list[Any]
+    attr_vars: list[Variable]
+    helper_functions: HelperFunctions
+
+    def __init__(self, name: str, attrs: list[str], need_guard_check: bool,
+                 obj: Any, extract_code_at_start: list[StorePos],
+                 helper_functions: HelperFunctions) -> None:
+        super().__init__(need_guard_check, obj, extract_code_at_start)
+        self.cls_name = name
+        self.cls_attr = []
+        for attr in attrs:
+            self.cls_attr.append(attr)
+        self.obj = None
+        self.obj_class = None
+        self.attr_value = []
+        self.helper_functions = helper_functions
+
+    @classmethod
+    def from_value(
+            cls, value: Any, need_guard_check: bool,
+            _helper_functions: HelperFunctions, _fx_graph: Optional[FxGraph],
+            extract_code_at_start: list[StorePos]) -> 'ClsByNamedTupleVar':
+        return cls(value.name, value.attrs, need_guard_check, value,
+                   extract_code_at_start, _helper_functions)
+
+    def make_guard_inner(self, codegen: "GuardFnCodegen",
+                         pos: StorePos) -> None:
+        pass
+
+    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
+                          codegen: "GraphFnCodegen", in_return: bool,
+                          idx: int) -> None:
+        assert self.attr_value is not None and self.obj is not None and len(
+            self.attr_value) == len(self.cls_attr)
+        for i, value in enumerate(self.attr_value):
+            var = self.helper_functions.get_or_make_var(value,
+                                                        self.need_guard_check,
+                                                        None, [])
+            var.make_output(f"{name_in_graph_fn}_{i}", store_pos, codegen,
+                            False, id(value))
+        codegen.add_import_from("collections", "namedtuple")
+        name = f"'{self.cls_name}'"
+        temps = []
+        for j in self.cls_attr:
+            temp = f"'{j}'"
+            temps.append(temp)
+        attrs = f"[{','.join(f'{j}' for j in temps)},]" if len(
+            self.cls_attr) > 0 else "[]"
+        codegen.add_statements(f'funcs = namedtuple({name}, {attrs})')
+        paras = f"({','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.cls_attr)))},)" if len(
+            self.cls_attr) > 0 else "()"
+        codegen.output(name_in_graph_fn, store_pos, f"funcs{paras}", in_return,
+                       idx)
+
+    def as_fx_node(self) -> NodeArgs:
+        return ValueError("cannot covert a user defined class to node")
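Concretely, for a ClsByNamedTupleVar with cls_name 'Pair' and cls_attr ['x', 'y'], the code produced by the two codegen calls above is equivalent to this hand-written snippet (x_0 and x_1 stand in for the already-reconstructed attribute values):

from collections import namedtuple

x_0, x_1 = 1, 2                         # stand-ins for recovered attribute values
funcs = namedtuple('Pair', ['x', 'y'])  # emitted once via codegen.add_statements
result = funcs(x_0, x_1)                # the expression passed to codegen.output
print(result)                           # Pair(x=1, y=2)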
17 changes: 17 additions & 0 deletions frontend/variables/list_.py
@@ -13,6 +13,8 @@ class ListVar(Variable):
     vars: list[Variable]
     obj_ids: list[int]
     length: int
+    helper_functions: HelperFunctions
+    graph: Optional[FxGraph]

     def __init__(self, value: list[Any], need_guard_check: bool,
                  helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
@@ -22,6 +24,8 @@ def __init__(self, value: list[Any], need_guard_check: bool,
         self.length = len(value)
         self.vars = []
         self.obj_ids = []
+        self.helper_functions = helper_functions
+        self.graph = fx_graph
         for i, obj in enumerate(value):
             new_extract: list[StorePos] = [
                 StoreInIndex(pos, id(obj), i)
@@ -43,6 +47,19 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                           codegen: "GraphFnCodegen", in_return: bool,
                           idx: int) -> None:
         oldest = self.get_oldest_var()
+        if len(self.obj) != len(self.vars):
+            # updated list
+            self.vars.clear()
+            self.obj_ids.clear()
+            for i, obj in enumerate(self.obj):
+                new_extract: list[StorePos] = [
+                    StoreInIndex(pos, id(obj), i)
+                    for pos in self.extract_code_at_start
+                ]
+                var = self.helper_functions.get_or_make_var(
+                    obj, self.need_guard_check, self.graph, new_extract)
+                self.vars.append(var)
+                self.obj_ids.append(id(obj))
         for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
             var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
                             False, idx_j)
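The new length check handles lists that were mutated after the ListVar snapshot was taken; the idea in isolation (illustration only, not the repository's code):

# The Variable snapshot is captured at trace time, but the traced list may
# grow or shrink before outputs are generated.
traced = [1, 2]
snapshot_ids = [id(x) for x in traced]      # captured in __init__
traced.append(3)                            # mutation after tracing
if len(traced) != len(snapshot_ids):        # the check added to make_output_inner
    snapshot_ids = [id(x) for x in traced]  # rebuild from the live object
print(len(snapshot_ids))                    # 3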
68 changes: 68 additions & 0 deletions frontend/variables/set_.py
@@ -61,6 +61,74 @@ def from_value(cls, value: set[Any], need_guard_check: bool,
     def as_fx_node(self) -> NodeArgs:
         return self.value

+    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
+        for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
+            old_var = table.get_or_none_by_id(idx)
+            if old_var is not None:
+                new_extract: list[StorePos] = [
+                    StoreInIndex(pos, idx, i, False)
+                    for pos in self.extract_code_at_start
+                ]
+                old_var.extract_code_at_start.extend(new_extract)
+                old_var.need_guard_check |= self.need_guard_check
+            else:
+                table.add_by_id(var, idx)
+                var.add_subvars_to_table(table)
+
+
+class FrozensetVar(Variable):
+    vars: list[Variable]
+    obj_ids: list[int]
+    length: int
+
+    def __init__(self, value: frozenset[Any], need_guard_check: bool,
+                 helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
+                 extract_code_at_start: list[StorePos]) -> None:
+        super().__init__(need_guard_check, value, extract_code_at_start)
+        self.value = value
+        self.length = len(value)
+        self.vars = []
+        self.obj_ids = []
+        for i, obj in enumerate(value):
+            new_extract: list[StorePos] = [
+                StoreInIndex(pos, id(obj), i, False)
+                for pos in self.extract_code_at_start
+            ]
+            var = helper_functions.get_or_make_var(obj, need_guard_check,
+                                                   fx_graph, new_extract)
+            self.vars.append(var)
+            self.obj_ids.append(id(obj))
+
+    def make_guard_inner(self, codegen: "GuardFnCodegen",
+                         pos: StorePos) -> None:
+        codegen.add_check((f'isinstance({pos}, frozenset)', pos))
+        codegen.add_check((f"len({pos}) == {self.length}", pos))
+        for i, (var, obj) in enumerate(zip(self.vars, self.obj_ids)):
+            var.make_guard_inner(codegen, StoreInIndex(pos, obj, i, False))
+
+    def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
+                          codegen: "GraphFnCodegen", in_return: bool,
+                          idx: int) -> None:
+        for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
+            var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
+                            False, idx_j)
+
+        codegen.output(
+            name_in_graph_fn, store_pos,
+            f"{{{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},}}"
+            if len(self.vars) > 0 else "frozenset()", in_return, idx)
+
+    @classmethod
+    def from_value(cls, value: frozenset[Any], need_guard_check: bool,
+                   helper_functions: HelperFunctions,
+                   fx_graph: Optional[FxGraph],
+                   extract_code_at_start: list[StorePos]) -> "FrozensetVar":
+        return cls(value, need_guard_check, helper_functions, fx_graph,
+                   extract_code_at_start)
+
+    def as_fx_node(self) -> NodeArgs:
+        return self.value
+
+    def add_subvars_to_table(self, table: 'ObjectTable') -> None:
+        for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
+            old_var = table.get_or_none_by_id(idx)
(remainder of this hunk is hidden in the rendered diff)
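Hand-expanding the guard that FrozensetVar.make_guard_inner emits for a two-element frozenset stored at locals['s'] gives roughly the following (illustrative only; the per-element guards recurse through StoreInIndex positions):

s = frozenset({1, 2})            # the guarded value
assert isinstance(s, frozenset)  # type check emitted via add_check
assert len(s) == 2               # length check emitted via add_check
# ...followed by one recursive guard per element, addressed via StoreInIndex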
(The remaining 4 changed files are not rendered in this view.)
