From 70517e919cc32ed0d49b7e4d7d0695e5b6931ce8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 22 Aug 2023 22:01:26 +0200 Subject: [PATCH] gh-108240: Add _PyCapsule_SetTraverse() interal function The _socket extension uses _PyCapsule_SetTraverse() to visit and clear the socket type in the garbage collector. So the _socket.socket type can be cleared in some corner cases when it wasn't possible before. --- Include/pycapsule.h | 4 + Modules/socketmodule.c | 37 +++++++-- Objects/capsule.c | 166 +++++++++++++++++++++++++---------------- 3 files changed, 135 insertions(+), 72 deletions(-) diff --git a/Include/pycapsule.h b/Include/pycapsule.h index 929a9a685259fb..8660e83088edd6 100644 --- a/Include/pycapsule.h +++ b/Include/pycapsule.h @@ -48,6 +48,10 @@ PyAPI_FUNC(int) PyCapsule_SetName(PyObject *capsule, const char *name); PyAPI_FUNC(int) PyCapsule_SetContext(PyObject *capsule, void *context); +#ifdef Py_BUILD_CORE +PyAPI_FUNC(int) _PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func); +#endif + PyAPI_FUNC(void *) PyCapsule_Import( const char *name, /* UTF-8 encoded string */ int no_block); diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c index bb5edc368decb3..392a56d5ab7d3e 100644 --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -7314,20 +7314,39 @@ os_init(void) } #endif +static int +sock_capi_traverse(PyObject *capsule, visitproc visit, void *arg) +{ + PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME); + assert(capi != NULL); + Py_VISIT(capi->Sock_Type); + return 0; +} + +static int +sock_capi_clear(PyObject *capsule) +{ + PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME); + assert(capi != NULL); + Py_CLEAR(capi->Sock_Type); + return 0; +} + static void -sock_free_api(PySocketModule_APIObject *capi) +sock_capi_free(PySocketModule_APIObject *capi) { - Py_DECREF(capi->Sock_Type); + Py_XDECREF(capi->Sock_Type); // sock_capi_free() can clear it Py_DECREF(capi->error); Py_DECREF(capi->timeout_error); PyMem_Free(capi); } static void -sock_destroy_api(PyObject *capsule) +sock_capi_destroy(PyObject *capsule) { void *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME); - sock_free_api(capi); + assert(capi != NULL); + sock_capi_free(capi); } static PySocketModule_APIObject * @@ -7432,11 +7451,17 @@ socket_exec(PyObject *m) } PyObject *capsule = PyCapsule_New(capi, PySocket_CAPSULE_NAME, - sock_destroy_api); + sock_capi_destroy); if (capsule == NULL) { - sock_free_api(capi); + sock_capi_free(capi); goto error; } + if (_PyCapsule_SetTraverse(capsule, + sock_capi_traverse, sock_capi_clear) < 0) { + sock_capi_free(capi); + goto error; + } + if (PyModule_Add(m, PySocket_CAPI_NAME, capsule) < 0) { goto error; } diff --git a/Objects/capsule.c b/Objects/capsule.c index baaddb3f1f0849..aa320013c53460 100644 --- a/Objects/capsule.c +++ b/Objects/capsule.c @@ -9,18 +9,28 @@ typedef struct { const char *name; void *context; PyCapsule_Destructor destructor; + traverseproc traverse_func; + inquiry clear_func; } PyCapsule; static int -_is_legal_capsule(PyCapsule *capsule, const char *invalid_capsule) +_is_legal_capsule(PyObject *op, const char *invalid_capsule) { - if (!capsule || !PyCapsule_CheckExact(capsule) || capsule->pointer == NULL) { - PyErr_SetString(PyExc_ValueError, invalid_capsule); - return 0; + if (!op || !PyCapsule_CheckExact(op)) { + goto error; + } + PyCapsule *capsule = (PyCapsule *)op; + + if (capsule->pointer == NULL) { + goto error; } return 1; + +error: + PyErr_SetString(PyExc_ValueError, invalid_capsule); + return 0; } #define is_legal_capsule(capsule, name) \ @@ -50,7 +60,7 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor) return NULL; } - capsule = PyObject_New(PyCapsule, &PyCapsule_Type); + capsule = PyObject_GC_New(PyCapsule, &PyCapsule_Type); if (capsule == NULL) { return NULL; } @@ -59,15 +69,18 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor) capsule->name = name; capsule->context = NULL; capsule->destructor = destructor; + capsule->traverse_func = NULL; + capsule->clear_func = NULL; + // Only track the capsule if _PyCapsule_SetTraverse() is called return (PyObject *)capsule; } int -PyCapsule_IsValid(PyObject *o, const char *name) +PyCapsule_IsValid(PyObject *op, const char *name) { - PyCapsule *capsule = (PyCapsule *)o; + PyCapsule *capsule = (PyCapsule *)op; return (capsule != NULL && PyCapsule_CheckExact(capsule) && @@ -77,13 +90,12 @@ PyCapsule_IsValid(PyObject *o, const char *name) void * -PyCapsule_GetPointer(PyObject *o, const char *name) +PyCapsule_GetPointer(PyObject *op, const char *name) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_GetPointer")) { + if (!is_legal_capsule(op, "PyCapsule_GetPointer")) { return NULL; } + PyCapsule *capsule = (PyCapsule *)op; if (!name_matches(name, capsule->name)) { PyErr_SetString(PyExc_ValueError, "PyCapsule_GetPointer called with incorrect name"); @@ -95,52 +107,48 @@ PyCapsule_GetPointer(PyObject *o, const char *name) const char * -PyCapsule_GetName(PyObject *o) +PyCapsule_GetName(PyObject *op) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_GetName")) { + if (!is_legal_capsule(op, "PyCapsule_GetName")) { return NULL; } + PyCapsule *capsule = (PyCapsule *)op; return capsule->name; } PyCapsule_Destructor -PyCapsule_GetDestructor(PyObject *o) +PyCapsule_GetDestructor(PyObject *op) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_GetDestructor")) { + if (!is_legal_capsule(op, "PyCapsule_GetDestructor")) { return NULL; } + PyCapsule *capsule = (PyCapsule *)op; return capsule->destructor; } void * -PyCapsule_GetContext(PyObject *o) +PyCapsule_GetContext(PyObject *op) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_GetContext")) { + if (!is_legal_capsule(op, "PyCapsule_GetContext")) { return NULL; } + PyCapsule *capsule = (PyCapsule *)op; return capsule->context; } int -PyCapsule_SetPointer(PyObject *o, void *pointer) +PyCapsule_SetPointer(PyObject *op, void *pointer) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!pointer) { - PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer"); + if (!is_legal_capsule(op, "PyCapsule_SetPointer")) { return -1; } + PyCapsule *capsule = (PyCapsule *)op; - if (!is_legal_capsule(capsule, "PyCapsule_SetPointer")) { + if (!pointer) { + PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer"); return -1; } @@ -150,13 +158,12 @@ PyCapsule_SetPointer(PyObject *o, void *pointer) int -PyCapsule_SetName(PyObject *o, const char *name) +PyCapsule_SetName(PyObject *op, const char *name) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_SetName")) { + if (!is_legal_capsule(op, "PyCapsule_SetName")) { return -1; } + PyCapsule *capsule = (PyCapsule *)op; capsule->name = name; return 0; @@ -164,13 +171,12 @@ PyCapsule_SetName(PyObject *o, const char *name) int -PyCapsule_SetDestructor(PyObject *o, PyCapsule_Destructor destructor) +PyCapsule_SetDestructor(PyObject *op, PyCapsule_Destructor destructor) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_SetDestructor")) { + if (!is_legal_capsule(op, "PyCapsule_SetDestructor")) { return -1; } + PyCapsule *capsule = (PyCapsule *)op; capsule->destructor = destructor; return 0; @@ -178,19 +184,36 @@ PyCapsule_SetDestructor(PyObject *o, PyCapsule_Destructor destructor) int -PyCapsule_SetContext(PyObject *o, void *context) +PyCapsule_SetContext(PyObject *op, void *context) { - PyCapsule *capsule = (PyCapsule *)o; - - if (!is_legal_capsule(capsule, "PyCapsule_SetContext")) { + if (!is_legal_capsule(op, "PyCapsule_SetContext")) { return -1; } + PyCapsule *capsule = (PyCapsule *)op; capsule->context = context; return 0; } +int +_PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func) +{ + if (!is_legal_capsule(op, "_PyCapsule_SetTraverse")) { + return -1; + } + PyCapsule *capsule = (PyCapsule *)op; + + if (!PyObject_GC_IsTracked(op)) { + PyObject_GC_Track(op); + } + + capsule->traverse_func = traverse_func; + capsule->clear_func = clear_func; + return 0; +} + + void * PyCapsule_Import(const char *name, int no_block) { @@ -249,13 +272,14 @@ PyCapsule_Import(const char *name, int no_block) static void -capsule_dealloc(PyObject *o) +capsule_dealloc(PyObject *op) { - PyCapsule *capsule = (PyCapsule *)o; + PyCapsule *capsule = (PyCapsule *)op; + PyObject_GC_UnTrack(op); if (capsule->destructor) { - capsule->destructor(o); + capsule->destructor(op); } - PyObject_Free(o); + PyObject_GC_Del(op); } @@ -279,6 +303,29 @@ capsule_repr(PyObject *o) } +static int +capsule_traverse(PyCapsule *capsule, visitproc visit, void *arg) +{ + if (capsule->traverse_func) { + return capsule->traverse_func((PyObject*)capsule, visit, arg); + } + else { + return 0; + } +} + + +static int +capsule_clear(PyCapsule *capsule) +{ + if (capsule->clear_func) { + return capsule->clear_func((PyObject*)capsule); + } + else { + return 0; + } +} + PyDoc_STRVAR(PyCapsule_Type__doc__, "Capsule objects let you wrap a C \"void *\" pointer in a Python\n\ @@ -293,27 +340,14 @@ Python import mechanism to link to one another.\n\ PyTypeObject PyCapsule_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "PyCapsule", /*tp_name*/ - sizeof(PyCapsule), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - /* methods */ - capsule_dealloc, /*tp_dealloc*/ - 0, /*tp_vectorcall_offset*/ - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - 0, /*tp_as_async*/ - capsule_repr, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - 0, /*tp_flags*/ - PyCapsule_Type__doc__ /*tp_doc*/ + .tp_name = "PyCapsule", + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_basicsize = sizeof(PyCapsule), + .tp_dealloc = capsule_dealloc, + .tp_repr = capsule_repr, + .tp_doc = PyCapsule_Type__doc__, + .tp_traverse = (traverseproc)capsule_traverse, + .tp_clear = (inquiry)capsule_clear, };