Skip to content

Commit

Permalink
better support of control flow operators (apache#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored May 11, 2024
2 parents 88074bc + a943b90 commit dddc10f
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 39 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
build
__pycache__
*.so
test/simple.py
test/simple.py
tmp
1 change: 1 addition & 0 deletions frontend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"debug": True,
"miss_threshold": 3,
"dynshape": False,
"model_name": "",
"enable_fallback": False,
}

Expand Down
6 changes: 1 addition & 5 deletions frontend/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def forward(self, *values: Any) -> Any:
loop_carry = values[self.num_read_only_param:]
while iter_num < self.num_iter:
# and cond.item():
loop_carry = self.body(torch.tensor(iter_num), *read_only,
*loop_carry)
loop_carry = self.body(iter_num, *read_only, *loop_carry)
# cond, *loop_carry = self.body(iter_num, cond, *read_only,
# *loop_carry)
iter_num += 1
Expand Down Expand Up @@ -296,6 +295,3 @@ def if_stmt(cond: bool, if_true: Callable[..., Any],
break_at_callsite()
recover()
return if_run_branch()


torch.Tensor.__iter__
67 changes: 37 additions & 30 deletions frontend/fx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,41 @@
NodeArgs = Union[BaseArgumentTypes, torch.fx.Node]


def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
target_atoms = target.split('.')
attr_itr = gm
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr


def generate_real_tensors(
fake_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
real_tensors = []
for x in fake_tensors:
if x.dtype == torch.float32:
real_tensors.append(
torch.rand(*x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
elif x.dtype == torch.int64:
real_tensors.append(
torch.randint(0,
2,
size=x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
else:
raise NotImplementedError
return real_tensors


def backend_compile(gm: torch.fx.GraphModule,
example_inputs: list[torch.Tensor]) -> Any:
backend = config.get_config('backend')
Expand All @@ -43,17 +78,6 @@ def backend_compile(gm: torch.fx.GraphModule,
return gm
elif backend == 'inductor':

def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
target_atoms = target.split('.')
attr_itr = gm
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr

def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:

if node.op == 'call_module':
Expand All @@ -78,26 +102,9 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:

module = importlib.import_module('tmp.fx_module_' + random_number)
model = module.FxModule().cuda().eval()
real_inputs = []
for x in example_inputs:
if x.dtype == torch.float32:
real_inputs.append(
torch.rand(*x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
elif x.dtype == torch.int64:
real_inputs.append(
torch.randint(0,
2,
size=x.shape,
dtype=x.dtype,
layout=x.layout,
device=x.device))
else:
raise NotImplementedError
real_inputs = generate_real_tensors(example_inputs)
with torch.no_grad():
script_model = torch.jit.trace(model, real_inputs)
script_model = torch.jit.script(model, real_inputs)
return script_model
else:
raise RuntimeError(f"Unknown backend: {backend}")
Expand Down
15 changes: 13 additions & 2 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_name(prefix: str, name: str) -> str:
self.subparam_paths[param] = get_name(prefix, name)

def add_submodule(self, module: torch.nn.Module) -> None:
new_module_name = "__external_module__" + str(len(self.submodule_paths))
new_module_name = "external_module__" + str(len(self.submodule_paths))
self.root.add_module(new_module_name, module)
self.update_subpath(module, new_module_name)
# self.written = True # not mark as written as graph break may happen
Expand Down Expand Up @@ -1261,6 +1261,15 @@ def commit(self) -> None:
(name, x) for x, name in self.state.fx_graph.example_inputs
])
print("graph", self.state.fx_graph.result_graph)
from .control_flow import CondModule
for node in self.state.fx_graph.result_graph.nodes:
if node.op == 'call_module' and '.' not in node.target:
mod = getattr(self.state.root, node.target)
if isinstance(mod, CondModule):
print("CondModule:", node.target)
print("true_body:", mod.true_body.graph)
print("false_body:", mod.false_body.graph)

graph_code = graph_codegen.get_code()
compiled_graph = self.state.fx_graph.compile()

Expand Down Expand Up @@ -1824,7 +1833,9 @@ def set_if_inplace_return() -> None:
inplace_ref=inplace_ref,
force_new_value=(func in (float, int, min, max) or
(hasattr(func, '__name__') and
func.__name__ == 'contiguous')))
func.__name__ == 'contiguous') or
(isinstance(func, torch.nn.Module) and
hasattr(func, 'inplace') and func.inplace)))
return
elif self.all_scalar_arg(args, kwargs) and self.all_static_arg(
args, kwargs):
Expand Down
2 changes: 1 addition & 1 deletion test/test_model_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_lstm_loop(caplog):
hidden_size,
device='cuda')
expect_result = model(inputs)
for_iter_pc = 193
for_iter_pc = 32
mark_dynamic_pc(get_next_frame_id(), for_iter_pc,
DynamicControlFlow(for_iter_pc, "FOR_ITER"))
compiled = compile(model)
Expand Down
25 changes: 25 additions & 0 deletions test/test_nnmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,31 @@ def test_map_module(caplog):
run_and_check(compiled, [HIT], 1, caplog, expect_result, x)


class InplaceRelu(torch.nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU(inplace=True)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x + 1.0


def test_inplace_relu(caplog):
reset()
model = InplaceRelu().eval()
compiled = compile(model)
x = torch.randn(1, 3, 3, 3)
expect_result = model(x)
run_and_check(compiled, [MISS], 1, caplog, expect_result, x)
run_and_check(compiled, [HIT], 1, caplog, expect_result, x)


if __name__ == "__main__":
caplog = logging.getLogger(__name__)
test_call_method(caplog)
Expand Down

0 comments on commit dddc10f

Please sign in to comment.