Parity 1 (apache#24)
heheda12345 authored Dec 6, 2023
2 parents b8e6464 + 6c24c8d commit b3c000f
Showing 15 changed files with 489 additions and 80 deletions.
208 changes: 174 additions & 34 deletions frontend/guard_tracker.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions frontend/object_table.py
@@ -7,7 +7,7 @@
from .utils import NullObject, ReadOnlyObject
from .store_pos import StorePos
from .fx_graph import FxGraph
import torch
import numpy as np


class ObjectTable:
@@ -35,7 +35,6 @@ def add(self, var: Variable, value: Any) -> None:
old_var.need_guard_check |= var.need_guard_check
else:
self.add_by_id(var, id(value))
var.add_subvars_to_table(self)

def add_by_id(self, var: Variable, idx: int) -> None:
assert idx not in self.objs
@@ -68,11 +67,13 @@ def get(self,
return self.objs[id(value)]
elif allow_unexist_const:
if isinstance(value, get_args(CONST_TYPES)) or isinstance(
value, (list, tuple, set, dict, CodeType)):
value, (list, tuple, set, dict, range, CodeType,
type(Ellipsis), np.ndarray)):
return make_var_from_value(value, False, self.helper_functions,
fx_graph)
raise RuntimeError(
f"Object({id(value)}) {value} not found in object table")
f"Object({id(value)}) {value} {type(value)} not found in object table"
)

def get_or_none(self, value: Any) -> Optional[Variable]:
if id(value) in self.objs:
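Note on the object_table.py change above: the table is keyed by object identity (`id(value)`), not equality, so the widened fallback only builds throwaway variables for value-like types that may never have been registered. A tiny standalone illustration of the identity keying (hypothetical snippet, not from the repo):

    # Object identity, not equality, decides table membership: two equal
    # lists have different id()s, so only the registered one is found.
    a = [1, 2, 3]
    b = [1, 2, 3]
    table = {id(a): "var_for_a"}   # stand-in for ObjectTable.objs
    assert id(b) not in table      # b == a, but b is not a
    assert table[id(a)] == "var_for_a"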
35 changes: 26 additions & 9 deletions frontend/utils.py
@@ -8,6 +8,7 @@
import contextlib
import torch
import torch._C
import collections
from .config import get_config, set_config

if TYPE_CHECKING:
@@ -93,11 +94,15 @@ def is_call_bytecode(inst: 'Instruction') -> bool:
operator.rshift,
operator.and_,
operator.or_,
operator.is_,
operator.xor,
operator.eq,
operator.lt,
operator.ne,
operator.le,
operator.gt,
operator.ge,
operator.contains,
}
fx_graph_functions = fx_graph_functions.union(fx_graph_inplace_functions)

@@ -124,7 +129,7 @@ def get_root_module(func: Callable[..., Any]) -> str:
if hasattr(func, '__objclass__'):
if func.__objclass__ == torch._C._TensorBase:
return 'torch'
elif func.__objclass__ in (list, tuple, set, dict):
elif func.__objclass__ in (list, tuple, set, dict, str):
return 'builtins'

if hasattr(func, '__self__') and isinstance(func.__self__, torch.Tensor):
@@ -135,9 +140,13 @@ def get_root_module(func: Callable[..., Any]) -> str:
return 'numpy'

module = inspect.getmodule(func)
if module is None:
module_str = ""
if module is not None:
module_str = str(module).split('\'')[1]
if module is None or module_str in ('torch.distributions.bernoulli',
'torch.distributions.distribution'):
return ""
root_module = str(module).split('\'')[1].split('.')[0]
root_module = module_str.split('.')[0]
return root_module
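
Note on the string parsing above: the repr of a module object looks like `<module 'torch.nn.functional' from '/path/to/functional.py'>`, so splitting on the single quote isolates the dotted name, and the first dotted component is the root package. A minimal sketch of the same trick (assumes torch is importable):

    # Extract the root package name from a function's defining module.
    import inspect
    import torch

    module = inspect.getmodule(torch.nn.functional.relu)
    module_str = str(module).split("'")[1]   # 'torch.nn.functional'
    root = module_str.split('.')[0]          # 'torch'
    assert root == 'torch'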


@@ -161,15 +170,18 @@ def get_method_defined_class(cls: type[Any],

def is_user_defined_func(func: Callable[..., Any]) -> bool:
# print([(x, getattr(func, x)) for x in dir(func)])
if hasattr(func,
'__objclass__') and func.__objclass__ in (torch._C._TensorBase,
dict):
if hasattr(func, '__objclass__') and func.__objclass__ in (
torch._C._TensorBase, dict, str, collections.OrderedDict):
return False

# NOTE: random should be called as a UDF, not handled
if hasattr(func, '__self__') and isinstance(func.__self__,
(torch.Tensor, random.Random)):
return False
if hasattr(func, '__self__'):
if isinstance(func.__self__, (torch.Tensor, random.Random)):
return False
elif isinstance(func.__self__, (list, tuple, set, dict, str)):
return False
elif isinstance(func.__self__, torch.nn.Sequential):
return True

if hasattr(func, '__name__') and func.__name__ == '<genexpr>':
return False
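For context on the `__self__`/`__objclass__` checks above: a bound method of a built-in type exposes its receiver as `__self__`, and the unbound descriptor exposes its defining class as `__objclass__`. A small standalone example (not from the repo):

    # Bound-method introspection: __self__ is the receiver, and
    # __objclass__ on the unbound descriptor names the defining class.
    lst = [1, 2, 3]
    m = lst.append
    assert m.__self__ is lst
    assert list.append.__objclass__ is list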
@@ -213,6 +225,11 @@ def is_graph_func(func: Callable[..., Any]) -> bool:
return root_module == 'torch'


def is_math_func(func: Callable[..., Any]) -> bool:
root_module = get_root_module(func)
return root_module == 'math'


random_state = None


10 changes: 7 additions & 3 deletions frontend/variables/__init__.py
@@ -8,11 +8,11 @@
from .tensor import TensorVar, TorchParamVar, TorchSizeVar, TorchDtypeVar, TorchDeviceVar
from .torch_module import TorchModuleVar, TorchSequentialVar, TorchModuleListVar
from .any_ import AnyVar
from .const import NullVar, NoneVar, SliceVar, ModuleVar, FunctionVar, RangeVar, CodeVar
from .const import NullVar, NoneVar, SliceVar, ModuleVar, FunctionVar, RangeVar, CodeVar, EllipsisVar
from .iterator import IteratorVar, RangeIterVar
from .tuple_ import TupleVar
from .set_ import SetVar
from .list_ import ListVar
from .list_ import ListVar, NdarrayVar
from .dict_ import DictVar, OrderedDictVar
from .builtin_types import CellVar, MappingProxyVar
from ..fx_graph import FxGraph
@@ -37,7 +37,8 @@
torch.device: TorchDeviceVar,
dict: DictVar,
CodeType: CodeVar,
OrderedDict: OrderedDictVar
OrderedDict: OrderedDictVar,
np.ndarray: NdarrayVar,
}

CONST_TYPES = Union[int, float, bool, str, NullObject, None, slice]
@@ -86,6 +87,9 @@ def make_var_from_value(
return MappingProxyVar.from_value(value, need_guard_check,
helper_functions, fx_graph,
extract_code_at_start)
elif isinstance(value, type(Ellipsis)):
return EllipsisVar.from_value(value, need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
else:
# NOTE: use any instead of iterator_var to represent an iterator with unknown source, due to the difficulty of getting the iterable and num_iters
print("generate any for", value, type(value), extract_code_at_start)
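The surrounding `make_var_from_value` dispatches on type: an exact-type lookup in the `ty2var` table shown above, then isinstance-based special cases (such as the new EllipsisVar branch), then AnyVar as the catch-all. A simplified sketch of that pattern (hypothetical names, not the real signatures):

    # Dispatch sketch: exact type first, then isinstance fallbacks,
    # then a catch-all for values with unknown source.
    def make_var_sketch(value, ty2var, isinstance_cases, make_any):
        if type(value) in ty2var:
            return ty2var[type(value)](value)
        for ty, factory in isinstance_cases:
            if isinstance(value, ty):
                return factory(value)
        return make_any(value)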
75 changes: 67 additions & 8 deletions frontend/variables/const.py
@@ -8,9 +8,10 @@
from ..pycode_writer import get_float_string
from ..fx_graph import NodeArgs, FxGraph
from ..utils import NullObject, null_object
from ..store_pos import StorePos
from ..store_pos import StorePos, StoreInFreeVar, StoreInAttr
if TYPE_CHECKING:
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
from ..object_table import ObjectTable


class NoneVar(Variable):
@@ -132,6 +133,33 @@ def as_fx_node(self) -> NodeArgs:
return slice(self.start, self.stop, self.step)


class EllipsisVar(Variable):

def __init__(self, need_guard_check: bool, obj: Any,
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, obj, extract_code_at_start)

def make_guard_inner(self, codegen: "GuardFnCodegen",
pos: StorePos) -> None:
codegen.add_id_check(f"id({pos}) == {id(self.obj)}", self.obj)

def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
codegen: "GraphFnCodegen", in_return: bool,
idx: int) -> None:
name = codegen.add_obj(self.obj, "Ellipsis_VAR")
codegen.output(name_in_graph_fn, store_pos, name, in_return, idx)

@classmethod
def from_value(cls, value: Any, need_guard_check: bool,
_helper_functions: HelperFunctions,
_fx_graph: Optional[FxGraph],
extract_code_at_start: list[StorePos]) -> "EllipsisVar":
return cls(need_guard_check, value, extract_code_at_start)

def as_fx_node(self) -> NodeArgs:
return Ellipsis
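
An id() guard (as in make_guard_inner above) is sufficient here because `Ellipsis` is a singleton; every literal `...` is the same object:

    # Ellipsis is a singleton, so identity checks are stable.
    assert ... is Ellipsis
    assert id(...) == id(Ellipsis)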


torch_modules = set([torch])


@@ -162,10 +190,24 @@ def from_value(cls, value: ModuleType, need_guard_check: bool,


class FunctionVar(Variable):
closure_vars: list[Variable]
obj_ids: list[int]

def __init__(self, func: Callable[..., Any], need_guard_check: bool,
helper_functions: HelperFunctions,
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, func, extract_code_at_start)
self.closure_vars = []
self.obj_ids = []
if hasattr(func, "__code__") and hasattr(func, "__closure__"):
if func.__closure__ is not None:
assert len(func.__code__.co_freevars) == len(func.__closure__)
for i, x in enumerate(func.__closure__):
if x.cell_contents != func:
cell_var = helper_functions.get_or_make_var(
x, need_guard_check, None, [StoreInFreeVar(i)])
self.closure_vars.append(cell_var)
self.obj_ids.append(id(x))

def make_guard_inner(self, codegen: "GuardFnCodegen",
pos: StorePos) -> None:
@@ -187,17 +229,31 @@ def from_value(cls, value: Callable[..., Any], need_guard_check: bool,
_helper_functions: HelperFunctions,
_fx_graph: Optional[FxGraph],
extract_code_at_start: list[StorePos]) -> "FunctionVar":
return cls(value, need_guard_check, extract_code_at_start)
return cls(value, need_guard_check, _helper_functions,
extract_code_at_start)

def add_subvars_to_table(self, table: 'ObjectTable') -> None:
for i, (var, idx) in enumerate(zip(self.closure_vars, self.obj_ids)):
old_var = table.get_or_none_by_id(idx)
if old_var is not None:
new_extract: list[StorePos] = [StoreInFreeVar(i)]
old_var.extract_code_at_start.extend(new_extract)
old_var.need_guard_check |= self.need_guard_check
else:
table.add_by_id(var, idx)
var.add_subvars_to_table(table)

# def as_fx_node(self) -> NodeArgs:
# return self.obj
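
For context on the closure handling in `__init__` above: `__code__.co_freevars` and `__closure__` are parallel sequences, each cell exposes its value via `cell_contents`, and the `x.cell_contents != func` test skips the self-referential cell of a recursive function. A standalone illustration (not from the repo):

    # co_freevars and __closure__ line up index-by-index.
    def make_adder(n):
        def add(x):
            return x + n
        return add

    f = make_adder(10)
    assert f.__code__.co_freevars == ('n',)
    assert len(f.__closure__) == len(f.__code__.co_freevars)
    assert f.__closure__[0].cell_contents == 10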


class RangeVar(Variable):
start: Optional[int]
stop: Optional[int]
step: Optional[int]
start: int
stop: int
step: int

def __init__(self, start: Optional[int], stop: Optional[int],
step: Optional[int], need_guard_check: bool, obj: range,
extract_code_at_start: list[StorePos]) -> None:
def __init__(self, start: int, stop: int, step: int, need_guard_check: bool,
obj: range, extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, obj, extract_code_at_start)
self.start = start
self.stop = stop
Expand All @@ -222,3 +278,6 @@ def from_value(cls, value: range, need_guard_check: bool,
extract_code_at_start: list[StorePos]) -> "RangeVar":
return cls(value.start, value.stop, value.step, need_guard_check, value,
extract_code_at_start)

def as_fx_node(self) -> NodeArgs:
return range(self.start, self.stop, self.step)
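
Dropping the Optional annotations on RangeVar is sound because `range` normalizes its arguments: `start`, `stop`, and `step` are always concrete ints on a constructed range object.

    # range() fills in defaults, so the attributes are never None.
    assert (range(5).start, range(5).stop, range(5).step) == (0, 5, 1)
    assert range(1, 10, 3).step == 3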
16 changes: 12 additions & 4 deletions frontend/variables/dict_.py
@@ -73,10 +73,18 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
codegen.output(name_in_graph_fn, store_pos, str(old_store_pos),
in_return, idx)
else:
codegen.output(
name_in_graph_fn, store_pos,
f"{{{','.join(f'{key}: {name_in_graph_fn}_{j}' for key, j in zip(self.value.keys(), range(len(self.vars))))}}}"
if len(self.vars) > 0 else "{}", in_return, idx)
items = []
for key, j in zip(self.value.keys(), range(len(self.vars))):
if isinstance(key, str):
key_part = f"'{key}'"
else:
key_part = key
item = f'{key_part}: {name_in_graph_fn}_{j}'
items.append(item)
target = f"{{{', '.join(i for i in items)}}}"
codegen.output(name_in_graph_fn, store_pos,
target if len(self.vars) > 0 else "{}", in_return,
idx)

@classmethod
def from_value(cls,
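The quoting fix above matters because the generated dict literal is later evaluated as source code: an unquoted string key would be read as a variable name. A minimal illustration (hypothetical, not the repo's codegen):

    # Without quotes, "{a: 1}" looks up a name `a`; with quotes it is a
    # string key.
    key = 'a'
    src = f"{{'{key}': 1}}"      # renders as "{'a': 1}"
    assert eval(src) == {'a': 1}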
72 changes: 71 additions & 1 deletion frontend/variables/list_.py
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Optional, Any, Callable
from copy import copy
import numpy as np
from .base import Variable, HelperFunctions
from ..fx_graph import NodeArgs, FxGraph
from ..store_pos import StorePos, StoreInIndex
@@ -83,4 +84,73 @@ def add_subvars_to_table(self, table: 'ObjectTable') -> None:
old_var.need_guard_check |= self.need_guard_check
else:
table.add_by_id(var, idx)
var.add_subvars_to_table(table)
var.add_subvars_to_table(table)


class NdarrayVar(Variable):
vars: list[Variable]
obj_ids: list[int]
length: int

def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
helper_functions: HelperFunctions, fx_graph: Optional[FxGraph],
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, value, extract_code_at_start)
self.value = value
self.length = len(value)
self.vars = []
self.obj_ids = []
for i, obj in enumerate(value):
new_extract: list[StorePos] = [
StoreInIndex(pos, id(obj), i)
for pos in self.extract_code_at_start
]
var = helper_functions.get_or_make_var(obj, need_guard_check,
fx_graph, new_extract)
self.vars.append(var)
self.obj_ids.append(id(obj))

def make_guard_inner(self, codegen: "GuardFnCodegen",
pos: StorePos) -> None:
codegen.add_import("numpy")
codegen.add_check(f"isinstance({pos}, numpy.ndarray)")
codegen.add_check(f"len({pos}) == {self.length}")
for i, obj in enumerate(self.vars):
obj.make_guard_inner(codegen, StoreInIndex(pos, id(obj), i))

def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
codegen: "GraphFnCodegen", in_return: bool,
idx: int) -> None:
for j, (idx_j, var) in enumerate(zip(self.obj_ids, self.vars)):
var.make_output(f"{name_in_graph_fn}_{j}", store_pos, codegen,
False, idx_j)
list_str = f"[{','.join(f'{name_in_graph_fn}_{j}' for j in range(len(self.vars)))},]" if len(
self.vars) > 0 else "[]"
codegen.add_import("numpy")
var_str = f"numpy.array({list_str})"
codegen.output(name_in_graph_fn, store_pos, var_str, in_return, idx)

@classmethod
def from_value(cls, value: np.ndarray[Any, Any], need_guard_check: bool,
helper_functions: HelperFunctions,
fx_graph: Optional[FxGraph],
extract_code_at_start: list[StorePos]) -> "NdarrayVar":
return cls(value, need_guard_check, helper_functions, fx_graph,
extract_code_at_start)

def as_fx_node(self) -> NodeArgs:
return self.value

def add_subvars_to_table(self, table: 'ObjectTable') -> None:
for i, (var, idx) in enumerate(zip(self.vars, self.obj_ids)):
old_var = table.get_or_none_by_id(idx)
if old_var is not None:
new_extract: list[StorePos] = [
StoreInIndex(pos, idx, i)
for pos in self.extract_code_at_start
]
old_var.extract_code_at_start.extend(new_extract)
old_var.need_guard_check |= self.need_guard_check
else:
table.add_by_id(var, idx)
var.add_subvars_to_table(table)
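In spirit, the output code that NdarrayVar generates rebuilds the array from the per-element variables with `numpy.array([...])`; note that the dtype is re-inferred from the elements rather than copied from the original array, and `len()`/iteration track the first axis only. A standalone sketch of the reconstruction (not the repo's codegen):

    # Rebuild an ndarray from tracked elements, as the generated code does.
    import numpy

    elem_0, elem_1, elem_2 = 1, 2, 3   # stand-ins for {name_in_graph_fn}_j
    rebuilt = numpy.array([elem_0, elem_1, elem_2])
    assert isinstance(rebuilt, numpy.ndarray)
    assert len(rebuilt) == 3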
9 changes: 5 additions & 4 deletions frontend/variables/scalar.py
@@ -22,8 +22,9 @@ def __init__(self, value: ScalarType, value_fix: bool,
need_guard_check: bool, fx_node: Optional[torch.fx.Node],
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, value, extract_code_at_start)
if isinstance(value, bool) and not value_fix:
raise NotImplementedError
# NOTE: should implement bool generated from tensor
# if isinstance(value, bool) and not value_fix:
# raise NotImplementedError
if not value_fix:
assert fx_node is not None
self.value_fix = value_fix
@@ -119,8 +120,8 @@ def __init__(self, value: np.generic, value_fix: bool,

def make_guard_inner(self, codegen: "GuardFnCodegen",
pos: StorePos) -> None:
codegen.add_check(
f"isinstance({pos}.item(), {type(self.obj).__name__})")
codegen.add_import("numpy")
codegen.add_check(f"isinstance({pos}, numpy.{type(self.obj).__name__})")
if self.value_fix:
item = self.obj.item()
if type(item) == float:
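
The guard change above is stricter on purpose: checking the numpy scalar type directly distinguishes widths that `.item()` collapses into one Python type.

    # .item() maps both float32 and float64 to Python float, so the old
    # check could not tell them apart; the isinstance check can.
    import numpy

    assert isinstance(numpy.float64(1.5), numpy.float64)
    assert isinstance(numpy.float64(1.5).item(), float)
    assert isinstance(numpy.float32(1.5).item(), float)
    assert not isinstance(numpy.float32(1.5), numpy.float64)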