From bf26a9c4a87d603109a2e5a481df741aa35f99c3 Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Wed, 25 Oct 2023 22:43:32 +0800 Subject: [PATCH 1/7] bytecode_generate_time_change --- frontend/bytecode_writter.py | 1 + frontend/cache.py | 15 +++++++++++++++ frontend/csrc/frame_evaluation.cpp | 4 ++-- frontend/tracer.py | 30 +++++++++++++++++++++++------- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/frontend/bytecode_writter.py b/frontend/bytecode_writter.py index e300a5e0b7ee..0c3dc21ee30a 100644 --- a/frontend/bytecode_writter.py +++ b/frontend/bytecode_writter.py @@ -499,6 +499,7 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int, strip_extended_args(instructions) fix_instructions_for_assemble(instructions, code_options) print("guarded code") + frame_cache.pre_instructions = instructions print(format_insts(instructions)) code_map = generate_code_map(original_instructions, instructions, in_trace_insts, next_original_pc) diff --git a/frontend/cache.py b/frontend/cache.py index 6ce3c4a10419..00ce6e4fce93 100644 --- a/frontend/cache.py +++ b/frontend/cache.py @@ -1,5 +1,8 @@ +from types import CodeType from typing import Callable, Any from dataclasses import dataclass + +from frontend.code import ProcessedCode from .instruction import Instruction from .c_api import add_to_cache from .store_pos import StorePos @@ -26,11 +29,17 @@ class FrameCache: cached_graphs: dict[int, list[CachedGraph]] # start_pc -> list of cached graph callsite_id: dict[int, int] # start_pc -> callsite_id + pre_cache_size: int + pre_instructions: list[Instruction] + new_code: CodeType + code_map: ProcessedCode def __init__(self, frame_id: int) -> None: self.frame_id = frame_id self.cached_graphs = {0: []} self.callsite_id = {0: 0} + self.pre_cache_size = -1 + self.pre_instructions = [] def add(self, traced_code: CachedGraph) -> None: start_pc = traced_code.start_pc @@ -46,6 +55,7 @@ def add(self, traced_code: CachedGraph) -> None: traced_code.guard_fn, traced_code.graph_fn) global TOTAL_SIZE TOTAL_SIZE += 1 + self.pre_cache_size = TOTAL_SIZE frame_caches: dict[int, FrameCache] = {} @@ -59,6 +69,11 @@ def enable_cache(frame_id: int) -> None: if frame_id not in frame_caches: frame_caches[frame_id] = FrameCache(frame_id) +def check_cache_updated(frame_id: int) -> bool: + if frame_caches[frame_id].pre_cache_size != len(frame_caches): + frame_caches[frame_id].pre_cache_size = len(frame_caches) + return True + return False def reset() -> None: global TOTAL_SIZE diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp index b6d51e8712d6..cb59500b2c84 100644 --- a/frontend/csrc/frame_evaluation.cpp +++ b/frontend/csrc/frame_evaluation.cpp @@ -178,7 +178,6 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyObject *code_map = PyTuple_GetItem(result_preprocess, 2); Py_INCREF(new_code); Py_INCREF(trace_func); - need_postprocess = false; PyObject *result = eval_custom_code(tstate, _frame, (PyCodeObject *)new_code, code_map, false, true, trace_func); @@ -189,7 +188,8 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, */ if (need_postprocess) { PyObject *result_postprocess = - PyObject_CallFunction(postprocess, "O", (PyObject *)_frame); + PyObject_CallFunction(postprocess, "Oi", (PyObject *)_frame, frame_id); + need_postprocess = false; } Py_DECREF(_frame); Py_DECREF(preprocess); diff --git a/frontend/tracer.py b/frontend/tracer.py index 197f9cadc3f6..be5c07e9de84 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -5,11 +5,12 @@ from typing import Any, Callable, Tuple import inspect from .guard_tracker import push_tracker, pop_tracker, record -from .cache import enable_cache +from .cache import enable_cache, check_cache_updated, get_frame_cache from .fx_graph import set_frame_root -from .c_api import set_eval_frame +from .c_api import set_eval_frame, mark_need_postprocess from .bytecode_writter import rewrite_bytecode from .code import ProcessedCode +from .instruction import format_insts def get_trace_func(frame_id: int) -> Callable[[FrameType, str, Any], None]: @@ -75,9 +76,21 @@ def preprocess_frame( print(f"preprocess frame {frame.f_code.co_filename}", frame_id, hex(id(frame)), frame.f_code.co_name) enable_cache(frame_id) - set_frame_root(frame_id, f) - new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, - is_callee) + + print("old bytecode: \n") + print(format_insts(get_frame_cache(frame_id).pre_instructions)) + if len(get_frame_cache(frame_id).pre_instructions) != 0: + new_code = get_frame_cache(frame_id).new_code + code_map = get_frame_cache(frame_id).code_map + + if check_cache_updated(frame_id): + mark_need_postprocess() + print("new bytecode: \n") + set_frame_root(frame_id, f) + new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, + is_callee) + get_frame_cache(frame_id).new_code = new_code + get_frame_cache(frame_id).code_map = code_map trace_func = get_trace_func(frame_id) except Exception as e: print("exception in preprocess:", e, type(e)) @@ -85,13 +98,16 @@ def preprocess_frame( raise e return (new_code, trace_func, code_map) - def postprocess_frame(frame: FrameType) -> None: + def postprocess_frame(frame: FrameType, frame_id: int) -> None: try: print(f"postprocess frame {frame.f_code.co_filename}") + set_frame_root(frame_id, f) + new_code = get_frame_cache(frame_id).new_code + code_map = get_frame_cache(frame_id).code_map except Exception as e: print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) raise e - return None + return (new_code, code_map) return (preprocess_frame, postprocess_frame) \ No newline at end of file From a7d698ff5b637f1e4d416118db83d792a2a8d68f Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Thu, 26 Oct 2023 19:19:39 +0800 Subject: [PATCH 2/7] cache_size_error --- frontend/bytecode_writter.py | 1 - frontend/cache.py | 2 -- frontend/tracer.py | 30 +++++++++++++++++------------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/frontend/bytecode_writter.py b/frontend/bytecode_writter.py index 0c3dc21ee30a..e300a5e0b7ee 100644 --- a/frontend/bytecode_writter.py +++ b/frontend/bytecode_writter.py @@ -499,7 +499,6 @@ def rewrite_bytecode(code: types.CodeType, frame_id: int, strip_extended_args(instructions) fix_instructions_for_assemble(instructions, code_options) print("guarded code") - frame_cache.pre_instructions = instructions print(format_insts(instructions)) code_map = generate_code_map(original_instructions, instructions, in_trace_insts, next_original_pc) diff --git a/frontend/cache.py b/frontend/cache.py index 00ce6e4fce93..2f58bad2b520 100644 --- a/frontend/cache.py +++ b/frontend/cache.py @@ -30,7 +30,6 @@ class FrameCache: list[CachedGraph]] # start_pc -> list of cached graph callsite_id: dict[int, int] # start_pc -> callsite_id pre_cache_size: int - pre_instructions: list[Instruction] new_code: CodeType code_map: ProcessedCode @@ -39,7 +38,6 @@ def __init__(self, frame_id: int) -> None: self.cached_graphs = {0: []} self.callsite_id = {0: 0} self.pre_cache_size = -1 - self.pre_instructions = [] def add(self, traced_code: CachedGraph) -> None: start_pc = traced_code.start_pc diff --git a/frontend/tracer.py b/frontend/tracer.py index be5c07e9de84..56b4ac11e92f 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -76,22 +76,22 @@ def preprocess_frame( print(f"preprocess frame {frame.f_code.co_filename}", frame_id, hex(id(frame)), frame.f_code.co_name) enable_cache(frame_id) - - print("old bytecode: \n") - print(format_insts(get_frame_cache(frame_id).pre_instructions)) - if len(get_frame_cache(frame_id).pre_instructions) != 0: - new_code = get_frame_cache(frame_id).new_code - code_map = get_frame_cache(frame_id).code_map - if check_cache_updated(frame_id): - mark_need_postprocess() + if get_frame_cache(frame_id).pre_cache_size == -1: print("new bytecode: \n") set_frame_root(frame_id, f) new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, is_callee) get_frame_cache(frame_id).new_code = new_code get_frame_cache(frame_id).code_map = code_map - trace_func = get_trace_func(frame_id) + trace_func = get_trace_func(frame_id) + else: + print("old bytecode: \n") + print(format_insts(get_frame_cache(frame_id).code_map.guard_insts)) + new_code = get_frame_cache(frame_id).new_code + code_map = get_frame_cache(frame_id).code_map + trace_func = get_trace_func(frame_id) + mark_need_postprocess() except Exception as e: print("exception in preprocess:", e, type(e)) print(traceback.format_exc()) @@ -101,13 +101,17 @@ def preprocess_frame( def postprocess_frame(frame: FrameType, frame_id: int) -> None: try: print(f"postprocess frame {frame.f_code.co_filename}") - set_frame_root(frame_id, f) - new_code = get_frame_cache(frame_id).new_code - code_map = get_frame_cache(frame_id).code_map + if check_cache_updated(frame_id): + print("new bytecode: \n") + set_frame_root(frame_id, f) + new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, + is_callee) + trace_func = get_trace_func(frame_id) + except Exception as e: print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) raise e - return (new_code, code_map) + return (new_code, trace_func, code_map) return (preprocess_frame, postprocess_frame) \ No newline at end of file From 465e09545ee26be1c22354ea2455f7a0ba465fbb Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Sat, 28 Oct 2023 09:30:58 +0800 Subject: [PATCH 3/7] double_gragh_generate --- frontend/tracer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/frontend/tracer.py b/frontend/tracer.py index 56b4ac11e92f..ab3b17618623 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -85,13 +85,15 @@ def preprocess_frame( get_frame_cache(frame_id).new_code = new_code get_frame_cache(frame_id).code_map = code_map trace_func = get_trace_func(frame_id) + else: print("old bytecode: \n") print(format_insts(get_frame_cache(frame_id).code_map.guard_insts)) new_code = get_frame_cache(frame_id).new_code code_map = get_frame_cache(frame_id).code_map trace_func = get_trace_func(frame_id) - mark_need_postprocess() + mark_need_postprocess() + except Exception as e: print("exception in preprocess:", e, type(e)) print(traceback.format_exc()) @@ -106,12 +108,13 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: set_frame_root(frame_id, f) new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, is_callee) - trace_func = get_trace_func(frame_id) + get_frame_cache(frame_id).new_code = new_code + get_frame_cache(frame_id).code_map = code_map except Exception as e: print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) raise e - return (new_code, trace_func, code_map) + return return (preprocess_frame, postprocess_frame) \ No newline at end of file From cb54c8a5cd44ceaf9822f3e92d5f38c1e0a08bed Mon Sep 17 00:00:00 2001 From: heheda Date: Sat, 28 Oct 2023 16:21:37 +0800 Subject: [PATCH 4/7] fix not generate new code bug --- frontend/cache.py | 25 ++++++++++++++++--------- frontend/tracer.py | 22 +++++++++++----------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/frontend/cache.py b/frontend/cache.py index 2f58bad2b520..7c27f0f32e54 100644 --- a/frontend/cache.py +++ b/frontend/cache.py @@ -1,5 +1,5 @@ from types import CodeType -from typing import Callable, Any +from typing import Callable, Any, Optional from dataclasses import dataclass from frontend.code import ProcessedCode @@ -30,14 +30,16 @@ class FrameCache: list[CachedGraph]] # start_pc -> list of cached graph callsite_id: dict[int, int] # start_pc -> callsite_id pre_cache_size: int - new_code: CodeType - code_map: ProcessedCode + new_code: Optional[CodeType] + code_map: Optional[ProcessedCode] + updated: bool def __init__(self, frame_id: int) -> None: self.frame_id = frame_id self.cached_graphs = {0: []} self.callsite_id = {0: 0} - self.pre_cache_size = -1 + self.new_code = None + self.code_map = None def add(self, traced_code: CachedGraph) -> None: start_pc = traced_code.start_pc @@ -53,7 +55,12 @@ def add(self, traced_code: CachedGraph) -> None: traced_code.guard_fn, traced_code.graph_fn) global TOTAL_SIZE TOTAL_SIZE += 1 - self.pre_cache_size = TOTAL_SIZE + self.updated = True + + def set_new_code(self, new_code: CodeType, code_map: ProcessedCode) -> None: + self.new_code = new_code + self.code_map = code_map + self.updated = False frame_caches: dict[int, FrameCache] = {} @@ -67,11 +74,11 @@ def enable_cache(frame_id: int) -> None: if frame_id not in frame_caches: frame_caches[frame_id] = FrameCache(frame_id) + def check_cache_updated(frame_id: int) -> bool: - if frame_caches[frame_id].pre_cache_size != len(frame_caches): - frame_caches[frame_id].pre_cache_size = len(frame_caches) - return True - return False + assert frame_id in frame_caches + return frame_caches[frame_id].updated + def reset() -> None: global TOTAL_SIZE diff --git a/frontend/tracer.py b/frontend/tracer.py index ab3b17618623..e689e8a5c5cc 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -77,23 +77,24 @@ def preprocess_frame( hex(id(frame)), frame.f_code.co_name) enable_cache(frame_id) - if get_frame_cache(frame_id).pre_cache_size == -1: + if get_frame_cache(frame_id).new_code is None: print("new bytecode: \n") set_frame_root(frame_id, f) new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, - is_callee) - get_frame_cache(frame_id).new_code = new_code - get_frame_cache(frame_id).code_map = code_map + is_callee) + get_frame_cache(frame_id).set_new_code(new_code, code_map) trace_func = get_trace_func(frame_id) - + else: print("old bytecode: \n") - print(format_insts(get_frame_cache(frame_id).code_map.guard_insts)) + print( + format_insts( + get_frame_cache(frame_id).code_map.guard_insts)) new_code = get_frame_cache(frame_id).new_code code_map = get_frame_cache(frame_id).code_map trace_func = get_trace_func(frame_id) mark_need_postprocess() - + except Exception as e: print("exception in preprocess:", e, type(e)) print(traceback.format_exc()) @@ -107,14 +108,13 @@ def postprocess_frame(frame: FrameType, frame_id: int) -> None: print("new bytecode: \n") set_frame_root(frame_id, f) new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, - is_callee) - get_frame_cache(frame_id).new_code = new_code - get_frame_cache(frame_id).code_map = code_map + is_callee) + get_frame_cache(frame_id).set_new_code(new_code, code_map) except Exception as e: print("exception in postprocess:", e, type(e)) print(traceback.format_exc()) raise e - return + return return (preprocess_frame, postprocess_frame) \ No newline at end of file From 1171f8150148421820f45adfd53df9958329ddeb Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 31 Oct 2023 17:22:27 +0800 Subject: [PATCH 5/7] pass --- frontend/cache.py | 1 + frontend/csrc/frame_evaluation.cpp | 2 +- frontend/tracer.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/cache.py b/frontend/cache.py index 7c27f0f32e54..fd3464297b07 100644 --- a/frontend/cache.py +++ b/frontend/cache.py @@ -40,6 +40,7 @@ def __init__(self, frame_id: int) -> None: self.callsite_id = {0: 0} self.new_code = None self.code_map = None + self.updated = False def add(self, traced_code: CachedGraph) -> None: start_pc = traced_code.start_pc diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp index cb59500b2c84..59c5f1954b65 100644 --- a/frontend/csrc/frame_evaluation.cpp +++ b/frontend/csrc/frame_evaluation.cpp @@ -189,7 +189,7 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, if (need_postprocess) { PyObject *result_postprocess = PyObject_CallFunction(postprocess, "Oi", (PyObject *)_frame, frame_id); - need_postprocess = false; + //need_postprocess = false; } Py_DECREF(_frame); Py_DECREF(preprocess); diff --git a/frontend/tracer.py b/frontend/tracer.py index e689e8a5c5cc..c6491ceabf4c 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -76,8 +76,7 @@ def preprocess_frame( print(f"preprocess frame {frame.f_code.co_filename}", frame_id, hex(id(frame)), frame.f_code.co_name) enable_cache(frame_id) - - if get_frame_cache(frame_id).new_code is None: + if not get_frame_cache(frame_id).updated: print("new bytecode: \n") set_frame_root(frame_id, f) new_code, code_map = rewrite_bytecode(frame.f_code, frame_id, From f3c3929ff52130ec79b01c73a9a14343461090d9 Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 31 Oct 2023 17:40:41 +0800 Subject: [PATCH 6/7] add assert --- frontend/tracer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/tracer.py b/frontend/tracer.py index c6491ceabf4c..8d4e526431bb 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -85,6 +85,8 @@ def preprocess_frame( trace_func = get_trace_func(frame_id) else: + assert get_frame_cache(frame_id).new_code is not None + assert get_frame_cache(frame_id).code_map is not None print("old bytecode: \n") print( format_insts( From 4cb61deeea5a6bc8dcf06154ed071eebedd1928c Mon Sep 17 00:00:00 2001 From: chenjike <13552845299@163.com> Date: Tue, 31 Oct 2023 22:52:01 +0800 Subject: [PATCH 7/7] code_check_pass --- frontend/csrc/frame_evaluation.cpp | 6 +++--- frontend/tracer.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp index 59c5f1954b65..6f3aef73b2b6 100644 --- a/frontend/csrc/frame_evaluation.cpp +++ b/frontend/csrc/frame_evaluation.cpp @@ -187,9 +187,9 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, _frame->f_trace = NULL; */ if (need_postprocess) { - PyObject *result_postprocess = - PyObject_CallFunction(postprocess, "Oi", (PyObject *)_frame, frame_id); - //need_postprocess = false; + PyObject *result_postprocess = PyObject_CallFunction( + postprocess, "Oi", (PyObject *)_frame, frame_id); + // need_postprocess = false; } Py_DECREF(_frame); Py_DECREF(preprocess); diff --git a/frontend/tracer.py b/frontend/tracer.py index 8d4e526431bb..61e1fd3266ed 100644 --- a/frontend/tracer.py +++ b/frontend/tracer.py @@ -85,14 +85,15 @@ def preprocess_frame( trace_func = get_trace_func(frame_id) else: - assert get_frame_cache(frame_id).new_code is not None - assert get_frame_cache(frame_id).code_map is not None print("old bytecode: \n") - print( - format_insts( - get_frame_cache(frame_id).code_map.guard_insts)) - new_code = get_frame_cache(frame_id).new_code - code_map = get_frame_cache(frame_id).code_map + old_frame = get_frame_cache(frame_id) + assert old_frame.code_map is not None, "Code map doesn't exist for frame id {}".format( + frame_id) + assert old_frame.new_code is not None, "New code doesn't exist for frame id {}".format( + frame_id) + print(format_insts(old_frame.code_map.guard_insts)) + new_code = old_frame.new_code + code_map = old_frame.code_map trace_func = get_trace_func(frame_id) mark_need_postprocess()