Skip to content

Commit

Permalink
Bytecode generate time (apache#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Nov 1, 2023
2 parents 6de5086 + 0a4f149 commit 0771dfe
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
23 changes: 22 additions & 1 deletion frontend/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Callable, Any
from types import CodeType
from typing import Callable, Any, Optional
from dataclasses import dataclass

from frontend.code import ProcessedCode
from .instruction import Instruction
from .c_api import add_to_cache
from .store_pos import StorePos
Expand All @@ -26,11 +29,18 @@ class FrameCache:
cached_graphs: dict[int,
list[CachedGraph]] # start_pc -> list of cached graph
callsite_id: dict[int, int] # start_pc -> callsite_id
pre_cache_size: int
new_code: Optional[CodeType]
code_map: Optional[ProcessedCode]
updated: bool

def __init__(self, frame_id: int) -> None:
self.frame_id = frame_id
self.cached_graphs = {0: []}
self.callsite_id = {0: 0}
self.new_code = None
self.code_map = None
self.updated = False

def add(self, traced_code: CachedGraph) -> None:
start_pc = traced_code.start_pc
Expand All @@ -46,6 +56,12 @@ def add(self, traced_code: CachedGraph) -> None:
traced_code.guard_fn, traced_code.graph_fn)
global TOTAL_SIZE
TOTAL_SIZE += 1
self.updated = True

def set_new_code(self, new_code: CodeType, code_map: ProcessedCode) -> None:
self.new_code = new_code
self.code_map = code_map
self.updated = False


frame_caches: dict[int, FrameCache] = {}
Expand All @@ -60,6 +76,11 @@ def enable_cache(frame_id: int) -> None:
frame_caches[frame_id] = FrameCache(frame_id)


def check_cache_updated(frame_id: int) -> bool:
assert frame_id in frame_caches
return frame_caches[frame_id].updated


def reset() -> None:
global TOTAL_SIZE
TOTAL_SIZE = 0
Expand Down
6 changes: 3 additions & 3 deletions frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
PyObject *code_map = PyTuple_GetItem(result_preprocess, 2);
Py_INCREF(new_code);
Py_INCREF(trace_func);
need_postprocess = false;
PyObject *result =
eval_custom_code(tstate, _frame, (PyCodeObject *)new_code, code_map,
false, true, trace_func);
Expand All @@ -188,8 +187,9 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
_frame->f_trace = NULL;
*/
if (need_postprocess) {
PyObject *result_postprocess =
PyObject_CallFunction(postprocess, "O", (PyObject *)_frame);
PyObject *result_postprocess = PyObject_CallFunction(
postprocess, "Oi", (PyObject *)_frame, frame_id);
// need_postprocess = false;
}
Py_DECREF(_frame);
Py_DECREF(preprocess);
Expand Down
41 changes: 33 additions & 8 deletions frontend/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from typing import Any, Callable, Tuple
import inspect
from .guard_tracker import push_tracker, pop_tracker, record
from .cache import enable_cache
from .cache import enable_cache, check_cache_updated, get_frame_cache
from .fx_graph import set_frame_root
from .c_api import set_eval_frame
from .c_api import set_eval_frame, mark_need_postprocess
from .bytecode_writter import rewrite_bytecode
from .code import ProcessedCode
from .instruction import format_insts


def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]:
Expand Down Expand Up @@ -79,23 +80,47 @@ def preprocess_frame(
print(f"preprocess frame {frame.f_code.co_filename}", frame_id,
hex(id(frame)), frame.f_code.co_name)
enable_cache(frame_id)
set_frame_root(frame_id, f)
new_code, code_map = rewrite_bytecode(frame.f_code, frame_id,
is_callee)
trace_func = get_trace_func(frame_id)
if not get_frame_cache(frame_id).updated:
print("new bytecode: \n")
set_frame_root(frame_id, f)
new_code, code_map = rewrite_bytecode(frame.f_code, frame_id,
is_callee)
get_frame_cache(frame_id).set_new_code(new_code, code_map)
trace_func = get_trace_func(frame_id)

else:
print("old bytecode: \n")
old_frame = get_frame_cache(frame_id)
assert old_frame.code_map is not None, "Code map doesn't exist for frame id {}".format(
frame_id)
assert old_frame.new_code is not None, "New code doesn't exist for frame id {}".format(
frame_id)
print(format_insts(old_frame.code_map.guard_insts))
new_code = old_frame.new_code
code_map = old_frame.code_map
trace_func = get_trace_func(frame_id)
mark_need_postprocess()

except Exception as e:
print("exception in preprocess:", e, type(e))
print(traceback.format_exc())
raise e
return (new_code, trace_func, code_map)

def postprocess_frame(frame: FrameType) -> None:
def postprocess_frame(frame: FrameType, frame_id: int) -> None:
try:
print(f"postprocess frame {frame.f_code.co_filename}")
if check_cache_updated(frame_id):
print("new bytecode: \n")
set_frame_root(frame_id, f)
new_code, code_map = rewrite_bytecode(frame.f_code, frame_id,
is_callee)
get_frame_cache(frame_id).set_new_code(new_code, code_map)

except Exception as e:
print("exception in postprocess:", e, type(e))
print(traceback.format_exc())
raise e
return None
return

return (preprocess_frame, postprocess_frame)

0 comments on commit 0771dfe

Please sign in to comment.