Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix garbage collection in inheritance cases #4563

Merged
merged 12 commits into from
Oct 11, 2024
1 change: 1 addition & 0 deletions newsfragments/4563.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `__traverse__` functions for base classes not being called by subclasses created with `#[pyclass(extends = ...)]`.
54 changes: 52 additions & 2 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ impl PyMethodKind {
"__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
"__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
"__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
// Protocols implemented through traits
"__getattribute__" => {
PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTRIBUTE__))
Expand Down Expand Up @@ -146,6 +145,7 @@ impl PyMethodKind {
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Clear),
// Not a proto
_ => PyMethodKind::Fn,
}
Expand All @@ -156,6 +156,7 @@ enum PyMethodProtoKind {
Slot(&'static SlotDef),
Call,
Traverse,
Clear,
SlotFragment(&'static SlotFragmentDef),
}

Expand Down Expand Up @@ -217,6 +218,9 @@ pub fn gen_py_method(
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::Clear => {
GeneratedPyMethod::Proto(impl_clear_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec, ctx)?;
GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
Expand Down Expand Up @@ -462,7 +466,7 @@ fn impl_traverse_slot(
visit: #pyo3_path::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg)
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg, #cls::__pymethod_traverse__)
}
};
let slot_def = quote! {
Expand All @@ -477,6 +481,52 @@ fn impl_traverse_slot(
})
}

fn impl_clear_slot(cls: &syn::Type, spec: &FnSpec<'_>, ctx: &Ctx) -> syn::Result<MethodAndSlotDef> {
let Ctx { pyo3_path, .. } = ctx;
let (py_arg, args) = split_off_python_arg(&spec.signature.arguments);
let self_type = match &spec.tp {
FnType::Fn(self_type) => self_type,
_ => bail_spanned!(spec.name.span() => "expected instance method for `__clear__` function"),
};
let mut holders = Holders::new();
let slf = self_type.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);

if let [arg, ..] = args {
bail_spanned!(arg.ty().span() => "`__clear__` function expected to have no arguments");
}

let name = &spec.name;
let holders = holders.init_holders(ctx);
let fncall = if py_arg.is_some() {
quote!(#cls::#name(#slf, py))
} else {
quote!(#cls::#name(#slf))
};

let associated_method = quote! {
pub unsafe extern "C" fn __pymethod_clear__(
_slf: *mut #pyo3_path::ffi::PyObject,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_clear(_slf, |py, _slf| {
#holders
let result = #fncall;
let result = #pyo3_path::impl_::wrap::converter(&result).wrap(result)?;
Ok(result)
}, #cls::__pymethod_clear__)
}
};
let slot_def = quote! {
#pyo3_path::ffi::PyType_Slot {
slot: #pyo3_path::ffi::Py_tp_clear,
pfunc: #cls::__pymethod_clear__ as #pyo3_path::ffi::inquiry as _
}
};
Ok(MethodAndSlotDef {
associated_method,
slot_def,
})
}

fn impl_py_class_attribute(
cls: &syn::Type,
spec: &FnSpec<'_>,
Expand Down
131 changes: 131 additions & 0 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use crate::gil::LockGIL;
use crate::impl_::callback::IntoPyCallbackOutput;
use crate::impl_::panic::PanicTrap;
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
use crate::internal::get_slot::{get_slot, TP_BASE, TP_CLEAR, TP_TRAVERSE};
use crate::pycell::impl_::PyClassBorrowChecker as _;
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::False;
use crate::types::any::PyAnyMethods;
use crate::types::PyType;
use crate::{
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
Expand All @@ -18,6 +20,8 @@ use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;

use super::trampoline;

/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
#[repr(transparent)]
Expand Down Expand Up @@ -275,6 +279,7 @@ pub unsafe fn _call_traverse<T>(
impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int
where
T: PyClass,
Expand All @@ -289,6 +294,11 @@ where
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
let lock = LockGIL::during_traverse();

let super_retval = call_super_traverse(slf, visit, arg, current_traverse);
if super_retval != 0 {
return super_retval;
}

// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
// traversal is running so no mutations can occur.
let class_object: &PyClassObject<T> = &*slf.cast();
Expand Down Expand Up @@ -328,6 +338,127 @@ where
retval
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from <https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386>
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_traverse(
obj: *mut ffi::PyObject,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int {
// SAFETY: in this function here it's ok to work with raw type objects `ffi::Py_TYPE`
// because the GC is running and so
// - (a) we cannot do refcounting and
// - (b) the type of the object cannot change.
let mut ty = ffi::Py_TYPE(obj);
let mut traverse: Option<ffi::traverseproc>;

// First find the current type by the current_traverse function
loop {
traverse = get_slot(ty, TP_TRAVERSE);
if traverse == Some(current_traverse) {
break;
}
ty = get_slot(ty, TP_BASE);
if ty.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
}

// Get first base which has a different traverse function
while traverse == Some(current_traverse) {
ty = get_slot(ty, TP_BASE);
if ty.is_null() {
break;
}
traverse = get_slot(ty, TP_TRAVERSE);
}

// If we found a type with a different traverse function, call it
if let Some(traverse) = traverse {
return traverse(obj, visit, arg);
}

// FIXME same question as cython: what if the current type is not in the MRO?
0
}

/// Calls an implementation of __clear__ for tp_clear
pub unsafe fn _call_clear(
slf: *mut ffi::PyObject,
impl_: for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<()>,
current_clear: ffi::inquiry,
) -> c_int {
trampoline::trampoline(move |py| {
let super_retval = call_super_clear(py, slf, current_clear);
if super_retval != 0 {
return Err(PyErr::fetch(py));
}
impl_(py, slf)?;
Ok(0)
})
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from <https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386>
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_clear(
py: Python<'_>,
obj: *mut ffi::PyObject,
current_clear: ffi::inquiry,
) -> c_int {
let mut ty = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(obj));
let mut clear: Option<ffi::inquiry>;

// First find the current type by the current_clear function
loop {
clear = ty.get_slot(TP_CLEAR);
if clear == Some(current_clear) {
break;
}
let base = ty.get_slot(TP_BASE);
if base.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
ty = PyType::from_borrowed_type_ptr(py, base);
}

// Get first base which has a different clear function
while clear == Some(current_clear) {
let base = ty.get_slot(TP_BASE);
if base.is_null() {
break;
}
ty = PyType::from_borrowed_type_ptr(py, base);
clear = ty.get_slot(TP_CLEAR);
}

// If we found a type with a different clear function, call it
if let Some(clear) = clear {
return clear(obj);
}

// FIXME same question as cython: what if the current type is not in the MRO?
0
}

// Autoref-based specialization for handling `__next__` returning `Option`

pub struct IterBaseTag;
Expand Down
74 changes: 61 additions & 13 deletions src/internal/get_slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ impl Bound<'_, PyType> {
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(self.as_borrowed())
// SAFETY: `self` is a valid type object.
unsafe {
slot.get_slot(
self.as_type_ptr(),
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(self.py()),
)
}
}
}

Expand All @@ -21,13 +28,50 @@ impl Borrowed<'_, '_, PyType> {
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(self)
// SAFETY: `self` is a valid type object.
unsafe {
slot.get_slot(
self.as_type_ptr(),
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(self.py()),
)
}
}
}

/// Gets a slot from a raw FFI pointer.
///
/// Safety:
/// - `ty` must be a valid non-null pointer to a `PyTypeObject`.
/// - The Python runtime must be initialized
pub(crate) unsafe fn get_slot<const S: c_int>(
ty: *mut ffi::PyTypeObject,
slot: Slot<S>,
) -> <Slot<S> as GetSlotImpl>::Type
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(
ty,
// SAFETY: the Python runtime is initialized
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(crate::Python::assume_gil_acquired()),
)
}

pub(crate) trait GetSlotImpl {
type Type;
fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type;

/// Gets the requested slot from a type object.
///
/// Safety:
/// - `ty` must be a valid non-null pointer to a `PyTypeObject`.
/// - `is_runtime_3_10` must be `false` if the runtime is not Python 3.10 or later.
unsafe fn get_slot(
self,
ty: *mut ffi::PyTypeObject,
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool,
) -> Self::Type;
}

#[derive(Copy, Clone)]
Expand All @@ -42,12 +86,14 @@ macro_rules! impl_slots {
type Type = $tp;

#[inline]
fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type {
let ptr = tp.as_type_ptr();

unsafe fn get_slot(
self,
ty: *mut ffi::PyTypeObject,
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool
) -> Self::Type {
#[cfg(not(Py_LIMITED_API))]
unsafe {
(*ptr).$field
{
(*ty).$field
}

#[cfg(Py_LIMITED_API)]
Expand All @@ -59,27 +105,29 @@ macro_rules! impl_slots {
// (3.7, 3.8, 3.9) and then look in the type object anyway. This is only ok
// because we know that the interpreter is not going to change the size
// of the type objects for these historical versions.
if !is_runtime_3_10(tp.py())
&& unsafe { ffi::PyType_HasFeature(ptr, ffi::Py_TPFLAGS_HEAPTYPE) } == 0
if !is_runtime_3_10 && ffi::PyType_HasFeature(ty, ffi::Py_TPFLAGS_HEAPTYPE) == 0
{
return unsafe { (*ptr.cast::<PyTypeObject39Snapshot>()).$field };
return (*ty.cast::<PyTypeObject39Snapshot>()).$field;
}
}

// SAFETY: slot type is set carefully to be valid
unsafe { std::mem::transmute(ffi::PyType_GetSlot(ptr, ffi::$slot)) }
std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::$slot))
}
}
}
)*
};
}

// Slots are implemented on-demand as needed.
// Slots are implemented on-demand as needed.)
impl_slots! {
TP_ALLOC: (Py_tp_alloc, tp_alloc) -> Option<ffi::allocfunc>,
TP_BASE: (Py_tp_base, tp_base) -> *mut ffi::PyTypeObject,
TP_CLEAR: (Py_tp_clear, tp_clear) -> Option<ffi::inquiry>,
TP_DESCR_GET: (Py_tp_descr_get, tp_descr_get) -> Option<ffi::descrgetfunc>,
TP_FREE: (Py_tp_free, tp_free) -> Option<ffi::freefunc>,
TP_TRAVERSE: (Py_tp_traverse, tp_traverse) -> Option<ffi::traverseproc>,
}

#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
Expand Down
Loading
Loading