Skip to content

Commit

Permalink
Support unrolling of all static control flow (apache#10)
Browse files Browse the repository at this point in the history
Dynamic control flow is a work in progress.
  • Loading branch information
heheda12345 authored Oct 6, 2023
2 parents 4a4ad50 + cd34953 commit ac6409b
Show file tree
Hide file tree
Showing 33 changed files with 1,972 additions and 254 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
spack load [email protected]%gcc@=11.3.0
source ~/venv/frontend-env/bin/activate
pip install --upgrade -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
LD_PRELOAD=~/frontend/ldlong.v3.9.12.so pytest -vs test
srun --exclusive ./scripts/pytest_with_preload.sh -vs test
26 changes: 14 additions & 12 deletions frontend/bytecode_writter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import copy
from .bytecode_analysis import stacksize_analysis
from .instruction import Instruction, convert_instruction, ci, format_insts
from .code import save_code
from .code import generate_code_map, ProcessedCode
from .cache import get_frame_cache, CachedGraph
from .store_pos import StorePos, StoreInStack, StoreInLocal

Expand Down Expand Up @@ -308,10 +308,8 @@ def add_callsite(
in_trace_insts.extend(disable_trace_insts[:-1])

start_stack_size = cached_graphs[0].start_stack_size if cached_graphs else 0
end_stack_size = cached_graphs[0].end_stack_size if cached_graphs else 0
for graph in cached_graphs:
assert graph.start_stack_size == start_stack_size
assert graph.end_stack_size == end_stack_size

prepare_stack_insts = [
ci("STORE_FAST", f"__stack__{i}") for i in range(start_stack_size)
Expand Down Expand Up @@ -407,10 +405,12 @@ def add_callsite(
*call_guard_insts,
*match_and_run_insts,
]
max_end_stack_size = max([g.start_stack_size for g in cached_graphs],
default=0)
new_names = {
"varnames": ["__graph_fn", "__case_idx"] + [
f"__stack__{i}"
for i in range(max(start_stack_size, end_stack_size))
for i in range(max(start_stack_size, max_end_stack_size))
],
"names": [
"guard_match", "enable_trace", "disable_trace", "locals",
Expand All @@ -429,15 +429,14 @@ def add_name(code_options: Dict[str, Any], varnames: List[str],


def rewrite_bytecode(code: types.CodeType, frame_id: int,
is_callee: bool) -> types.CodeType:
is_callee: bool) -> tuple[types.CodeType, ProcessedCode]:
original_instructions = get_instructions(code)
instructions = copy.deepcopy(original_instructions)
virtualize_jumps(instructions)
for original_inst, inst in zip(original_instructions, instructions):
inst.original_inst = original_inst
instructions[0].is_start = True
print(format_insts(instructions))
strip_extended_args(instructions)
frame_cache = get_frame_cache(frame_id)
# list of (start_pc, traced_instructions)
run_traced_insts: list[tuple[int, list[Instruction]]] = []
Expand Down Expand Up @@ -468,9 +467,11 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int,
next_original_pc: list[tuple[Instruction, Instruction]] = []
for i, inst in enumerate(instructions):
if inst.opname == "RETURN_VALUE":
original = inst.original_inst
assert original is not None
instructions[i] = ci("JUMP_ABSOLUTE", target=final_insts[0])
instructions[i].is_end = True
next_original_pc.append((original_instructions[i], instructions[i]))
next_original_pc.append((original, instructions[i]))
in_trace_insts.append(instructions[i])
run_traced_insts.sort(key=lambda x: x[0], reverse=True)
for start_pc, traced_code in run_traced_insts:
Expand All @@ -491,17 +492,18 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int,
in_trace_insts.extend(disable_trace_at_start[:-1])
instructions = disable_trace_at_start + instructions
instructions.extend(final_insts)
print("guarded code")
print(format_insts(instructions))
keys = get_code_keys()
code_options = {k: getattr(code, k) for k in keys}
add_name(code_options, list(new_names_all["varnames"]),
list(new_names_all["names"]))
strip_extended_args(instructions)
fix_instructions_for_assemble(instructions, code_options)
save_code(original_instructions, instructions, frame_id, in_trace_insts,
next_original_pc)
print("guarded code")
print(format_insts(instructions))
code_map = generate_code_map(original_instructions, instructions,
in_trace_insts, next_original_pc)
new_code = assemble_instructions(instructions, code_options)[1]
return new_code
return new_code, code_map


# test code
Expand Down
13 changes: 12 additions & 1 deletion frontend/c_api.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
from types import FrameType

if TYPE_CHECKING:
from .code import ProcessedCode


def set_eval_frame(
new_callback: Optional[Tuple[Callable[..., Any], Callable[..., Any]]]
Expand Down Expand Up @@ -62,3 +65,11 @@ def set_null_object(obj: Any) -> None:

def get_next_frame_id() -> int:
pass


def get_code_map(frame: FrameType) -> 'ProcessedCode':
    """Return the ProcessedCode registered for *frame* by the C extension.

    NOTE(review): the extension appears to register the map per shadow frame
    around eval_custom_code — confirm which frame object callers must pass.
    """
    pass


def is_bound_method(obj: Any, name: str) -> bool:
    """Return True if attribute *name* of *obj* resolves to a method.

    NOTE(review): implemented in C on top of CPython's method-lookup fast
    path (_PyObject_GetMethod) — confirm exact semantics for descriptors.
    """
    pass
32 changes: 13 additions & 19 deletions frontend/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,14 @@ def get_next_orig_pc(self, lasti: int) -> int:

def get_inst(self, lasti: int) -> Instruction:
    """Return the guarded instruction executing at byte offset *lasti*.

    The pc is ``lasti // 2`` (one instruction unit per 2 bytes); any
    EXTENDED_ARG prefixes are skipped so the instruction that actually
    executes is returned.
    """
    pc = lasti // 2
    total = len(self.guard_insts)
    while pc < total and self.guard_insts[pc].opname == "EXTENDED_ARG":
        pc += 1
    return self.guard_insts[pc]

def get_pc_by_inst(self, inst: Instruction) -> int:
    """Return the pc of *inst* in the guarded code (lookup in self.guarded_pc)."""
    return self.guarded_pc[inst]

def get_dependence_of_stack_var(self, original_inst: Instruction,
stack_depth: int) -> list[Instruction]:
raise NotImplementedError
Expand All @@ -148,22 +154,10 @@ def get_dependence_of_local_var(self, original_inst: Instruction,
raise NotImplementedError


processed_codes: dict[int, ProcessedCode] = {} # frame_id -> ProcessedCode


def save_code(original_insts: list[Instruction],
generated_insts: list[Instruction], frame_id: int,
inside_trace_opcodes: list[Instruction],
next_original_pc: list[tuple[Instruction, Instruction]]) -> None:
processed_codes[frame_id] = ProcessedCode(original_insts, generated_insts,
inside_trace_opcodes,
next_original_pc)


def load_code(frame_id: int) -> ProcessedCode:
return processed_codes[frame_id]


def reset() -> None:
global processed_codes
processed_codes.clear()
def generate_code_map(
        original_insts: list[Instruction], generated_insts: list[Instruction],
        inside_trace_opcodes: list[Instruction],
        next_original_pc: list[tuple[Instruction,
                                     Instruction]]) -> ProcessedCode:
    """Bundle the original and rewritten bytecode into a ProcessedCode.

    Thin factory: all mapping logic lives in ProcessedCode itself.
    """
    code_map = ProcessedCode(
        original_insts,
        generated_insts,
        inside_trace_opcodes,
        next_original_pc,
    )
    return code_map
3 changes: 0 additions & 3 deletions frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
from . import tracer, utils
from .c_api import set_eval_frame, set_skip_files, guard_match, c_reset, set_null_object
from .bytecode_writter import rewrite_bytecode
from .tracer import enable_trace, disable_trace, get_trace_func, get_process_frame
from .cache import enable_cache
from .utils import null_object
Expand Down Expand Up @@ -75,8 +74,6 @@ def _fn(*args: Any, **kwargs: Any) -> Any:

def reset() -> None:
c_reset()
from . import code
code.reset()
from . import cache
cache.reset()
from . import guard_tracker
Expand Down
48 changes: 43 additions & 5 deletions frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "csrc.h"
#include <Python.h>
#include <frameobject.h>
#include <map>
#include <object.h>
#include <pythread.h>
#include <sstream>
Expand Down Expand Up @@ -47,6 +48,7 @@ frontend_csrc::ProgramCache program_cache;
static int frame_count = 0;

bool need_postprocess = false;
static std::map<size_t, PyObject *> frame_id_to_code_map;

static void pylog(std::string message, const char *level = "info") {
static PyObject *pModule;
Expand Down Expand Up @@ -103,8 +105,8 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate,

inline static PyObject *eval_custom_code(PyThreadState *tstate,
PyFrameObject *frame,
PyCodeObject *code, int throw_flag,
bool trace_bytecode,
PyCodeObject *code, PyObject *code_map,
int throw_flag, bool trace_bytecode,
PyObject *trace_func) {
Py_ssize_t ncells = 0;
Py_ssize_t nfrees = 0;
Expand Down Expand Up @@ -143,8 +145,11 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
}

frame_id_to_code_map[(size_t)shadow] = code_map;
Py_INCREF(code_map);
PyObject *result = eval_frame_default(tstate, shadow, throw_flag);

frame_id_to_code_map.erase((size_t)shadow);
Py_DECREF(code_map);
Py_DECREF(shadow);
return result;
}
Expand All @@ -170,11 +175,13 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
PyObject_CallFunction(preprocess, "Oi", _frame, frame_id);
PyObject *new_code = PyTuple_GetItem(result_preprocess, 0);
PyObject *trace_func = PyTuple_GetItem(result_preprocess, 1);
PyObject *code_map = PyTuple_GetItem(result_preprocess, 2);
Py_INCREF(new_code);
Py_INCREF(trace_func);
need_postprocess = false;
PyObject *result = eval_custom_code(
tstate, _frame, (PyCodeObject *)new_code, false, true, trace_func);
PyObject *result =
eval_custom_code(tstate, _frame, (PyCodeObject *)new_code, code_map,
false, true, trace_func);
// _frame->
// PyObject *result = _PyEval_EvalFrameDefault(tstate, _frame, throw_flag);
/*
Expand Down Expand Up @@ -430,6 +437,35 @@ static PyObject *stack_effect_py(PyObject *self, PyObject *args) {
TO_PyBool(effect.global_effect));
}

static PyObject *get_code_map(PyObject *self, PyObject *args) {
    // Return the ProcessedCode object registered for the given shadow frame
    // by eval_custom_code.  Returns a new reference.
    PyFrameObject *frame = NULL;
    if (!PyArg_ParseTuple(args, "O", &frame)) {
        PRINT_PYERR;
        PyErr_SetString(PyExc_TypeError, "invalid parameter in get_code_map");
        return NULL;
    }
    // BUG FIX: operator[] on std::map inserts a NULL value for an unknown
    // frame, and the subsequent Py_INCREF(NULL) would crash the interpreter.
    // Look the key up explicitly and raise a Python error instead.
    auto it = frame_id_to_code_map.find((size_t)frame);
    if (it == frame_id_to_code_map.end()) {
        PyErr_SetString(PyExc_KeyError, "no code map registered for frame");
        return NULL;
    }
    Py_INCREF(it->second);
    return it->second;
}

static PyObject *is_bound_method(PyObject *self, PyObject *args) {
    // Return True iff looking up attribute `name` on `obj` finds a method,
    // using the same fast path CPython uses for method calls.
    PyObject *obj;
    PyObject *name;
    if (!PyArg_ParseTuple(args, "OO", &obj, &name)) {
        PRINT_PYERR;
        // BUG FIX: the message previously named the wrong function
        // ("get_method").
        PyErr_SetString(PyExc_TypeError,
                        "invalid parameter in is_bound_method");
        return NULL;
    }
    PyObject *meth = NULL;
    int meth_found = _PyObject_GetMethod(obj, name, &meth);
    // BUG FIX: _PyObject_GetMethod hands back a new reference in `meth`
    // (or NULL); the original leaked it.
    Py_XDECREF(meth);
    if (meth_found) {
        // BUG FIX: Py_True/Py_False were returned without incref, slowly
        // corrupting the singletons' refcounts.
        Py_RETURN_TRUE;
    }
    // BUG FIX: a failed lookup may leave an exception set (e.g.
    // AttributeError); returning a value with a pending exception is
    // invalid, so clear it and report "not a method".
    if (PyErr_Occurred()) {
        PyErr_Clear();
    }
    Py_RETURN_FALSE;
}

static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame, METH_VARARGS, NULL},
{"set_skip_files", set_skip_files, METH_VARARGS, NULL},
Expand All @@ -453,6 +489,8 @@ static PyMethodDef _methods[] = {
return PyLong_FromLong(frame_count);
},
METH_VARARGS, NULL},
{"get_code_map", get_code_map, METH_VARARGS, NULL},
{"is_bound_method", is_bound_method, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};

static struct PyModuleDef _module = {
Expand Down
28 changes: 28 additions & 0 deletions frontend/dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any


class Dynamic:
    """Base class for markers describing values treated as dynamic (not
    fixed at trace time)."""
    pass


class ScalarWithUnknownValue(Dynamic):
    """Marker for a scalar whose runtime value is not known statically."""
    pass


# id(obj) -> Dynamic marker describing how that object is dynamic.
dynamics: dict[int, Dynamic] = {}
# id(obj) -> the object itself; presumably kept so the object stays alive
# and its id() is not reused while marked — TODO confirm.
dynamic_refs: dict[int, Any] = {}


def mark_dynamic(obj: Any, dyn: Dynamic) -> None:
    """Register *obj* (keyed by identity) as dynamic, described by *dyn*."""
    obj_key = id(obj)
    dynamics[obj_key] = dyn
    # Hold a reference so the identity key stays valid.
    dynamic_refs[obj_key] = obj


def contains(obj: Any) -> bool:
    """Return True iff *obj* was previously registered via ``mark_dynamic``."""
    return id(obj) in dynamics


def contains_by_id(idx: int) -> bool:
    """Return True iff an object whose id() is *idx* has been marked dynamic."""
    return idx in dynamics
8 changes: 4 additions & 4 deletions frontend/fx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class FxGraph:
result_graph: torch.fx.Graph
mark_written_fn: Callable[[], None]
fake_mode: torch._subclasses.FakeTensorMode
example_inputs: list[tuple[str, torch.Tensor]]
example_inputs: list[tuple[torch.Tensor, str]]

def __init__(self, root: torch.nn.Module,
mark_written_fn: Callable[[], None]) -> None:
Expand Down Expand Up @@ -64,7 +64,7 @@ def create_input(
target: torch.fx.node.Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
name: str,
type_expr: Optional[Any] = None,
) -> torch.fx.Node:
fake_tensor = self.fake_mode.from_tensor(value, static_shapes=True)
Expand All @@ -84,8 +84,8 @@ def compile(
model = torch.fx.GraphModule(self.root, self.result_graph)
model.recompile()
with NO_LD_PRELOAD_CTX():
compiled_fn = backend_compile(model,
[x[0] for x in self.example_inputs])
compiled_fn = backend_compile(
model, [x[0].contiguous() for x in self.example_inputs])
assert callable(compiled_fn)
# TODO: add backend compiler
return compiled_fn
Expand Down
Loading

0 comments on commit ac6409b

Please sign in to comment.