Skip to content

Commit

Permalink
use func watchers API for funcs modified and destroyed
Browse files Browse the repository at this point in the history
Summary:
Register a function watcher for the JIT, and use its callback to handle func-modified and func-destroyed events.

I'm leaving func-initialization to a separate PR; it will require more involved handling, because we'll also want to visit GC objects to find functions created before the JIT was initialized.

I eliminated the call to `PyEntry_init` that used to follow a change to function defaults; that call only happened if the function was not JIT-compiled. I think this dates back to Cinder 3.8, where we had many more (non-JIT) function entrypoint variants, depending on number of arguments, argument defaults, etc., and needed to ensure we set the right one. But in 3.10 we eliminated all those custom entry points, so I don't think this is needed anymore.

I had to fix the func watchers tests so that they work correctly when another func watcher is already active. I also submitted this fix upstream: python/cpython#106286

And I had to delete a C++ test that was passing only due to a series of accidents. The func-modified callbacks (both before and after this diff) are global and dispatch only to the global singleton `jit_ctx` in `pyjit.cpp`, so they can't be tested correctly by a unit test that never globally enables the JIT and only constructs its own private JIT context. The function-modified callback in this test was doing nothing, but the entrypoint of the function was getting re-set to `_PyFunction_Vectorcall` anyway due to `PyEntry_init` seeing the JIT as not enabled; this seems unlikely to be a realistic scenario the test was intended to check.

There is already a Python-level test (`test_funcattrs.FunctionPropertiesTest.test_copying___code__`) that verifies that re-assigning `__code__` changes the behavior of the function; we run this test under the JIT, and it fails if we fail to deopt the function on `__code__` reassignment. So the behavior we care about is already tested.

Reviewed By: alexmalyshev

Differential Revision: D47156535

fbshipit-source-id: ba15f93800e23b33eb12262a201d24360df39a67
  • Loading branch information
Carl Meyer authored and facebook-github-bot committed Jul 6, 2023
1 parent 3767da7 commit 89dedf5
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 76 deletions.
96 changes: 81 additions & 15 deletions Jit/pyjit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ struct JitConfig {
int hir_inliner_enabled{0};
unsigned int auto_jit_threshold{0};
int dict_watcher_id{-1};
int func_watcher_id{-1};
};
static JitConfig jit_config;

Expand Down Expand Up @@ -1712,6 +1713,24 @@ static int install_jit_audit_hook() {
return -1;
}

// Registers _PyJIT_FuncWatcher as a function watcher and remembers the
// watcher ID in jit_config.func_watcher_id.
//
// Returns 0 on success. Returns -1 if PyFunction_AddWatcher reports failure
// (a negative ID); jit_config is left untouched in that case.
static int install_jit_func_watcher() {
  const int id = PyFunction_AddWatcher(_PyJIT_FuncWatcher);
  if (id >= 0) {
    jit_config.func_watcher_id = id;
    return 0;
  }
  return -1;
}

// Unregisters the JIT's function watcher, if one is currently recorded in
// jit_config.func_watcher_id, and resets the stored ID to -1.
//
// A failure from PyFunction_ClearWatcher is reported through
// PyErr_WriteUnraisable rather than returned to the caller.
static void clear_jit_func_watcher() {
  if (jit_config.func_watcher_id < 0) {
    // No watcher recorded; nothing to clear.
    return;
  }
  if (PyFunction_ClearWatcher(jit_config.func_watcher_id) < 0) {
    PyErr_WriteUnraisable(Py_None);
  }
  jit_config.func_watcher_id = -1;
}

static int install_jit_dict_watcher() {
int watcher_id = PyDict_AddWatcher(_PyJIT_DictWatcher);
if (watcher_id < 0) {
Expand All @@ -1721,6 +1740,15 @@ static int install_jit_dict_watcher() {
return 0;
}

// Unregisters the JIT's dict watcher, if one is currently recorded in
// jit_config.dict_watcher_id, and resets the stored ID to -1.
//
// A failure from PyDict_ClearWatcher is reported through
// PyErr_WriteUnraisable rather than returned to the caller.
static void clear_jit_dict_watcher() {
  if (jit_config.dict_watcher_id < 0) {
    // No watcher recorded; nothing to clear.
    return;
  }
  if (PyDict_ClearWatcher(jit_config.dict_watcher_id) < 0) {
    PyErr_WriteUnraisable(Py_None);
  }
  jit_config.dict_watcher_id = -1;
}

void _PyJIT_WatchDict(PyObject* dict) {
if (PyDict_Watch(jit_config.dict_watcher_id, dict) < 0) {
PyErr_Print();
Expand All @@ -1736,26 +1764,37 @@ void _PyJIT_UnwatchDict(PyObject* dict) {
}

int _PyJIT_InitializeSubInterp() {
// HACK: for now we assume we are the only dict watcher out there, so that we
// can just keep track of a single dict watcher ID rather than one per
// interpreter.
int prev_watcher_id = jit_config.dict_watcher_id;
// HACK: for now assume we are the only watcher out there, so that we can just
// keep track of a single watcher ID rather than one per interpreter.
int prev_dict_watcher_id = jit_config.dict_watcher_id;
JIT_CHECK(
prev_watcher_id >= 0,
prev_dict_watcher_id >= 0,
"Initializing sub-interpreter without main interpreter?");
if (install_jit_dict_watcher() < 0) {
return -1;
}
JIT_CHECK(
jit_config.dict_watcher_id == prev_watcher_id,
jit_config.dict_watcher_id == prev_dict_watcher_id,
"Somebody else watching dicts?");

// dict watcher is always enabled; func watcher only if JIT is
int prev_func_watcher_id = jit_config.func_watcher_id;
if (prev_func_watcher_id >= 0) {
if (install_jit_func_watcher() < 0) {
return -1;
}
JIT_CHECK(
jit_config.func_watcher_id == prev_func_watcher_id,
"Somebody else watching functions?");
}

return 0;
}

int _PyJIT_Initialize() {
// If we have data symbols which are public but not used within CPython code,
// we need to ensure the linker doesn't GC the .data section containing them.
// We can do this by referencing at least symbol from that sourfe module.
// We can do this by referencing at least one symbol from that source module.
// In future versions of clang/gcc we may be able to eliminate this with
// 'keep' and/or 'used' attributes.
//
Expand Down Expand Up @@ -1857,7 +1896,8 @@ int _PyJIT_Initialize() {
return -1;
}

if (install_jit_audit_hook() < 0 || register_fork_callback(mod) < 0) {
if (install_jit_audit_hook() < 0 || register_fork_callback(mod) < 0 ||
install_jit_func_watcher() < 0) {
return -1;
}

Expand Down Expand Up @@ -2072,6 +2112,34 @@ void _PyJIT_InstanceTypeAssigned(PyTypeObject* old_ty, PyTypeObject* new_ty) {
}
}

// Function watcher callback installed by the JIT.
//
// Reacts to code reassignment and function destruction; every other event is
// currently ignored. Always returns 0.
int _PyJIT_FuncWatcher(
    PyFunction_WatchEvent event,
    PyFunctionObject* func,
    PyObject* new_value) {
  switch (event) {
    case PyFunction_EVENT_MODIFY_CODE: {
      // Deopt the function first.
      _PyJIT_FuncModified(func);
      // Now that the func is deopted we want to consider recompiling it right
      // away. func_set_code will assign the code object again later, but
      // installing it here lets PyEntry_init see the new code object now.
      PyObject* replacement = new_value;
      Py_INCREF(replacement);
      Py_XSETREF(func->func_code, replacement);
      PyEntry_init(func);
      break;
    }
    case PyFunction_EVENT_DESTROY:
      _PyJIT_FuncDestroyed(func);
      break;
    case PyFunction_EVENT_CREATE:
      // TODO move PyEntry_init setting out of funcobject.c
    case PyFunction_EVENT_MODIFY_DEFAULTS:
    case PyFunction_EVENT_MODIFY_KWDEFAULTS:
      // Nothing to do for these events.
      break;
  }
  return 0;
}

void _PyJIT_FuncModified(PyFunctionObject* func) {
if (jit_ctx) {
_PyJITContext_FuncModified(jit_ctx, func);
Expand Down Expand Up @@ -2175,6 +2243,10 @@ int _PyJIT_Finalize() {
CodeAllocator::freeGlobalCodeAllocator();
}

// now that we've released all references to Python funcs, it's safe to shut
// down the func watcher
clear_jit_func_watcher();

#define CLEAR_STR(s) Py_CLEAR(s_str_##s);
INTERNED_STRINGS(CLEAR_STR)
#undef CLEAR_STR
Expand All @@ -2189,13 +2261,7 @@ int _PyJIT_Finalize() {
Runtime::shutdown();

// must happen after Runtime::shutdown() so that we've cleared dict caches
if (jit_config.dict_watcher_id >= 0) {
if (PyDict_ClearWatcher(jit_config.dict_watcher_id) < 0) {
PyErr_Print();
PyErr_Clear();
}
jit_config.dict_watcher_id = -1;
}
clear_jit_dict_watcher();

return 0;
}
Expand Down
8 changes: 8 additions & 0 deletions Jit/pyjit.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ PyAPI_FUNC(int) _PyJIT_DictWatcher(
PyObject* key,
PyObject* new_value);

/*
 * Func watcher callback; called on func creation/modification/deallocation.
 *
 * The implementation in pyjit.cpp acts on two events:
 *   - PyFunction_EVENT_MODIFY_CODE: calls _PyJIT_FuncModified(func), stores
 *     new_value (the new code object) in func->func_code, then calls
 *     PyEntry_init(func).
 *   - PyFunction_EVENT_DESTROY: calls _PyJIT_FuncDestroyed(func).
 * PyFunction_EVENT_CREATE, PyFunction_EVENT_MODIFY_DEFAULTS and
 * PyFunction_EVENT_MODIFY_KWDEFAULTS are currently ignored.
 *
 * Always returns 0.
 */
PyAPI_FUNC(int) _PyJIT_FuncWatcher(
PyFunction_WatchEvent event,
PyFunctionObject* func,
PyObject* new_value);

/*
* Informs the JIT that a type, function, or code object is being created,
* modified, or destroyed.
Expand Down
23 changes: 11 additions & 12 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -6254,9 +6254,9 @@ get_dict_watcher_events(PyObject *self, PyObject *Py_UNUSED(args))

// Test function watchers

#define NUM_FUNC_WATCHERS 2
static PyObject *pyfunc_watchers[NUM_FUNC_WATCHERS];
static int func_watcher_ids[NUM_FUNC_WATCHERS] = {-1, -1};
#define NUM_TEST_FUNC_WATCHERS 2
static PyObject *pyfunc_watchers[NUM_TEST_FUNC_WATCHERS];
static int func_watcher_ids[NUM_TEST_FUNC_WATCHERS] = {-1, -1};

static PyObject *
get_id(PyObject *obj)
Expand Down Expand Up @@ -6330,7 +6330,7 @@ second_func_watcher_callback(PyFunction_WatchEvent event,
return call_pyfunc_watcher(pyfunc_watchers[1], event, func, new_value);
}

static PyFunction_WatchCallback func_watcher_callbacks[NUM_FUNC_WATCHERS] = {
static PyFunction_WatchCallback func_watcher_callbacks[NUM_TEST_FUNC_WATCHERS] = {
first_func_watcher_callback,
second_func_watcher_callback
};
Expand All @@ -6355,26 +6355,25 @@ add_func_watcher(PyObject *self, PyObject *func)
return NULL;
}
int idx = -1;
for (int i = 0; i < NUM_FUNC_WATCHERS; i++) {
for (int i = 0; i < NUM_TEST_FUNC_WATCHERS; i++) {
if (func_watcher_ids[i] == -1) {
idx = i;
break;
}
}
if (idx == -1) {
PyErr_SetString(PyExc_RuntimeError, "no free watchers");
return NULL;
}
PyObject *result = PyLong_FromLong(idx);
if (result == NULL) {
PyErr_SetString(PyExc_RuntimeError, "no free test watchers");
return NULL;
}
func_watcher_ids[idx] = PyFunction_AddWatcher(func_watcher_callbacks[idx]);
if (func_watcher_ids[idx] < 0) {
Py_DECREF(result);
return NULL;
}
pyfunc_watchers[idx] = Py_NewRef(func);
PyObject *result = PyLong_FromLong(func_watcher_ids[idx]);
if (result == NULL) {
return NULL;
}
return result;
}

Expand All @@ -6391,7 +6390,7 @@ clear_func_watcher(PyObject *self, PyObject *watcher_id_obj)
return NULL;
}
int idx = -1;
for (int i = 0; i < NUM_FUNC_WATCHERS; i++) {
for (int i = 0; i < NUM_TEST_FUNC_WATCHERS; i++) {
if (func_watcher_ids[i] == wid) {
idx = i;
break;
Expand Down
15 changes: 0 additions & 15 deletions Objects/funcobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,6 @@ func_set_code(PyFunctionObject *op, PyObject *value, void *Py_UNUSED(ignored))
handle_func_event(PyFunction_EVENT_MODIFY_CODE, op, value);
Py_INCREF(value);
Py_XSETREF(op->func_code, value);
#ifdef ENABLE_CINDERX
_PyJIT_FuncModified(op);
PyEntry_init(op);
#endif
return 0;
}

Expand Down Expand Up @@ -554,14 +550,6 @@ func_set_defaults(PyFunctionObject *op, PyObject *value, void *Py_UNUSED(ignored
handle_func_event(PyFunction_EVENT_MODIFY_DEFAULTS, op, value);
Py_XINCREF(value);
Py_XSETREF(op->func_defaults, value);
#ifdef ENABLE_CINDERX
// JIT-compiled functions load their defaults at runtime if needed. Others
// need their entrypoint recomputed.
// TODO(T126790232): Don't load defaults at runtime and recompile as needed.
if (!_PyJIT_IsCompiled((PyObject *)op)) {
PyEntry_init(op);
}
#endif
return 0;
}

Expand Down Expand Up @@ -783,9 +771,6 @@ func_clear(PyFunctionObject *op)
static void
func_dealloc(PyFunctionObject *op)
{
#ifdef ENABLE_CINDERX
_PyJIT_FuncDestroyed(op);
#endif
assert(Py_REFCNT(op) == 0);
Py_SET_REFCNT(op, 1);
handle_func_event(PyFunction_EVENT_DESTROY, op, NULL);
Expand Down
34 changes: 0 additions & 34 deletions RuntimeTests/jit_context_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,6 @@ class PyJITContextTest : public RuntimeTest {
_PyJITContext* jit_ctx_;
};

TEST_F(PyJITContextTest, CompiledFunctionsAreDeoptimizedWhenCodeChanges) {
const char* src = R"(
def func():
return 12345
)";
Ref<PyFunctionObject> func(compileAndGet(src, "func"));
ASSERT_NE(func.get(), nullptr) << "Failed creating func";

vectorcallfunc old_entrypoint = func->vectorcall;
_PyJIT_Result st = _PyJITContext_CompileFunction(jit_ctx_, func);
ASSERT_EQ(st, PYJIT_RESULT_OK) << "Failed compiling";

// Create a new function object so that we can grab its code object and
// assign it to the original function, at which point func should be
// de-optimized
const char* src2 = R"(
def func2():
return 2
func.__code__ = func2.__code__
)";
auto globals = Ref<>::steal(PyDict_New());
ASSERT_NE(globals.get(), nullptr) << "Failed creating globals";
ASSERT_EQ(PyDict_SetItemString(globals, "func", func), 0)
<< "Failed updating globals";

auto result =
Ref<>::steal(PyRun_String(src2, Py_file_input, globals, globals));
ASSERT_NE(result.get(), nullptr) << "Failed executing code";

// After de-optimization, the entrypoint should have been restored to the
// original value
ASSERT_EQ(func->vectorcall, old_entrypoint) << "entrypoint wasn't restored";
}

TEST_F(PyJITContextTest, UnwatchableBuiltins) {
// This is a C++ test rather than in test_cinderjit so we can guarantee a
Expand Down

0 comments on commit 89dedf5

Please sign in to comment.