Skip to content

Commit

Permalink
Added as_super methods to PyRef and PyRefMut. (#4219)
Browse files Browse the repository at this point in the history
* Added `PyRef::as_super` and `PyRefMut::as_super` methods, including docstrings and tests. The implementation of these methods also required adding `#[repr(transparent)]` to the `PyRef` and `PyRefMut` structs.

* Added newsfragment entry.

* Changed the `AsRef<U>`/`AsMut<U>` impls for `PyRef` and `PyRefMut` to use the new `as_super` methods. Added the `PyRefMut::downgrade` associated function for converting `&PyRefMut` to `&PyRef`. Updated tests and docstrings to better demonstrate the new functionality.

* Fixed newsfragment filename.

* Removed unnecessary re-borrows flagged by clippy.

* retrigger checks

* Updated `PyRef::as_super`, `PyRefMut::as_super`, and `PyRefMut::downgrade` to use `.cast()` instead of `as _` pointer casts. Fixed typo.

* Updated `PyRef::as_super` and `PyRefMut::downgrade` to use `ptr_from_ref` for the initial cast to `*const _` instead of `as _` casts.

* Added `pyo3::internal_tricks::ptr_from_mut` function alongside the `ptr_from_ref` added in PR #4240. Updated `PyRefMut::as_super` to use this method instead of `as *mut _`.

* Updated the user guide to recommend `as_super` for accessing the base class instead of `as_ref`, and updated the subsequent example/doctest to demonstrate this functionality.

* Improved tests for the `as_super` methods.

* Updated newsfragment to include additional changes.

* Fixed formatting.

---------

Co-authored-by: jrudolph <[email protected]>
  • Loading branch information
JRRudy1 and JRRudy1 authored Jun 9, 2024
1 parent 9c67057 commit d2dca21
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 10 deletions.
39 changes: 33 additions & 6 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,12 @@ explicitly.

To get a parent class from a child, use [`PyRef`] instead of `&self` for methods,
or [`PyRefMut`] instead of `&mut self`.
Then you can access a parent class by `self_.as_ref()` as `&Self::BaseClass`,
or by `self_.into_super()` as `PyRef<Self::BaseClass>`.
Then you can access a parent class by `self_.as_super()` as `&PyRef<Self::BaseClass>`,
or by `self_.into_super()` as `PyRef<Self::BaseClass>` (and similar for the `PyRefMut`
case). For convenience, `self_.as_ref()` can also be used to get `&Self::BaseClass`
directly; however, this approach does not let you access base clases higher in the
inheritance hierarchy, for which you would need to chain multiple `as_super` or
`into_super` calls.

```rust
# use pyo3::prelude::*;
Expand All @@ -345,7 +349,7 @@ impl BaseClass {
BaseClass { val1: 10 }
}

pub fn method(&self) -> PyResult<usize> {
pub fn method1(&self) -> PyResult<usize> {
Ok(self.val1)
}
}
Expand All @@ -363,8 +367,8 @@ impl SubClass {
}

fn method2(self_: PyRef<'_, Self>) -> PyResult<usize> {
let super_ = self_.as_ref(); // Get &BaseClass
super_.method().map(|x| x * self_.val2)
let super_ = self_.as_super(); // Get &PyRef<BaseClass>
super_.method1().map(|x| x * self_.val2)
}
}

Expand All @@ -381,11 +385,28 @@ impl SubSubClass {
}

fn method3(self_: PyRef<'_, Self>) -> PyResult<usize> {
let base = self_.as_super().as_super(); // Get &PyRef<'_, BaseClass>
base.method1().map(|x| x * self_.val3)
}

fn method4(self_: PyRef<'_, Self>) -> PyResult<usize> {
let v = self_.val3;
let super_ = self_.into_super(); // Get PyRef<'_, SubClass>
SubClass::method2(super_).map(|x| x * v)
}

fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) {
let val1 = self_.as_super().as_super().val1;
let val2 = self_.as_super().val2;
(val1, val2, self_.val3)
}

fn double_values(mut self_: PyRefMut<'_, Self>) {
self_.as_super().as_super().val1 *= 2;
self_.as_super().val2 *= 2;
self_.val3 *= 2;
}

#[staticmethod]
fn factory_method(py: Python<'_>, val: usize) -> PyResult<PyObject> {
let base = PyClassInitializer::from(BaseClass::new());
Expand All @@ -400,7 +421,13 @@ impl SubSubClass {
}
# Python::with_gil(|py| {
# let subsub = pyo3::Py::new(py, SubSubClass::new()).unwrap();
# pyo3::py_run!(py, subsub, "assert subsub.method3() == 3000");
# pyo3::py_run!(py, subsub, "assert subsub.method1() == 10");
# pyo3::py_run!(py, subsub, "assert subsub.method2() == 150");
# pyo3::py_run!(py, subsub, "assert subsub.method3() == 200");
# pyo3::py_run!(py, subsub, "assert subsub.method4() == 3000");
# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (10, 15, 20)");
# pyo3::py_run!(py, subsub, "assert subsub.double_values() == None");
# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (20, 30, 40)");
# let subsub = SubSubClass::factory_method(py, 2).unwrap();
# let subsubsub = SubSubClass::factory_method(py, 3).unwrap();
# let cls = py.get_type_bound::<SubSubClass>();
Expand Down
3 changes: 3 additions & 0 deletions newsfragments/4219.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Added `as_super` methods to `PyRef` and `PyRefMut` for accesing the base class by reference
- Updated user guide to recommend `as_super` for referencing the base class instead of `as_ref`
- Added `pyo3::internal_tricks::ptr_from_mut` function for casting `&mut T` to `*mut T`
6 changes: 6 additions & 0 deletions src/internal_tricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,9 @@ pub(crate) fn extract_c_string(
pub(crate) const fn ptr_from_ref<T>(t: &T) -> *const T {
t as *const T
}

// TODO: use ptr::from_mut on MSRV 1.76
#[inline]
pub(crate) fn ptr_from_mut<T>(t: &mut T) -> *mut T {
t as *mut T
}
168 changes: 164 additions & 4 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@
use crate::conversion::AsPyPointer;
use crate::exceptions::PyRuntimeError;
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::internal_tricks::{ptr_from_mut, ptr_from_ref};
use crate::pyclass::{boolean_struct::False, PyClass};
use crate::types::any::PyAnyMethods;
#[cfg(feature = "gil-refs")]
use crate::{
conversion::ToPyObject,
impl_::pyclass::PyClassImpl,
internal_tricks::ptr_from_ref,
pyclass::boolean_struct::True,
pyclass_init::PyClassInitializer,
type_object::{PyLayout, PySizedLayout},
Expand Down Expand Up @@ -612,6 +612,7 @@ impl<T: PyClass + fmt::Debug> fmt::Debug for PyCell<T> {
/// ```
///
/// See the [module-level documentation](self) for more information.
#[repr(transparent)]
pub struct PyRef<'p, T: PyClass> {
// TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to
// store `Borrowed` here instead, avoiding reference counting overhead.
Expand All @@ -631,7 +632,7 @@ where
U: PyClass,
{
fn as_ref(&self) -> &T::BaseType {
unsafe { &*self.inner.get_class_object().ob_base.get_ptr() }
self.as_super()
}
}

Expand Down Expand Up @@ -743,6 +744,58 @@ where
},
}
}

/// Borrows a shared reference to `PyRef<T::BaseType>`.
///
/// With the help of this method, you can access attributes and call methods
/// on the superclass without consuming the `PyRef<T>`. This method can also
/// be chained to access the super-superclass (and so on).
///
/// # Examples
/// ```
/// # use pyo3::prelude::*;
/// #[pyclass(subclass)]
/// struct Base {
/// base_name: &'static str,
/// }
/// #[pymethods]
/// impl Base {
/// fn base_name_len(&self) -> usize {
/// self.base_name.len()
/// }
/// }
///
/// #[pyclass(extends=Base)]
/// struct Sub {
/// sub_name: &'static str,
/// }
///
/// #[pymethods]
/// impl Sub {
/// #[new]
/// fn new() -> (Self, Base) {
/// (Self { sub_name: "sub_name" }, Base { base_name: "base_name" })
/// }
/// fn sub_name_len(&self) -> usize {
/// self.sub_name.len()
/// }
/// fn format_name_lengths(slf: PyRef<'_, Self>) -> String {
/// format!("{} {}", slf.as_super().base_name_len(), slf.sub_name_len())
/// }
/// }
/// # Python::with_gil(|py| {
/// # let sub = Py::new(py, Sub::new()).unwrap();
/// # pyo3::py_run!(py, sub, "assert sub.format_name_lengths() == '9 8'")
/// # });
/// ```
pub fn as_super(&self) -> &PyRef<'p, U> {
let ptr = ptr_from_ref::<Bound<'p, T>>(&self.inner)
// `Bound<T>` has the same layout as `Bound<T::BaseType>`
.cast::<Bound<'p, T::BaseType>>()
// `Bound<T::BaseType>` has the same layout as `PyRef<T::BaseType>`
.cast::<PyRef<'p, T::BaseType>>();
unsafe { &*ptr }
}
}

impl<'p, T: PyClass> Deref for PyRef<'p, T> {
Expand Down Expand Up @@ -799,6 +852,7 @@ impl<T: PyClass + fmt::Debug> fmt::Debug for PyRef<'_, T> {
/// A wrapper type for a mutably borrowed value from a [`Bound<'py, T>`].
///
/// See the [module-level documentation](self) for more information.
#[repr(transparent)]
pub struct PyRefMut<'p, T: PyClass<Frozen = False>> {
// TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to
// store `Borrowed` here instead, avoiding reference counting overhead.
Expand All @@ -818,7 +872,7 @@ where
U: PyClass<Frozen = False>,
{
fn as_ref(&self) -> &T::BaseType {
unsafe { &*self.inner.get_class_object().ob_base.get_ptr() }
PyRefMut::downgrade(self).as_super()
}
}

Expand All @@ -828,7 +882,7 @@ where
U: PyClass<Frozen = False>,
{
fn as_mut(&mut self) -> &mut T::BaseType {
unsafe { &mut *self.inner.get_class_object().ob_base.get_ptr() }
self.as_super()
}
}

Expand Down Expand Up @@ -870,6 +924,11 @@ impl<'py, T: PyClass<Frozen = False>> PyRefMut<'py, T> {
.try_borrow_mut()
.map(|_| Self { inner: obj.clone() })
}

pub(crate) fn downgrade(slf: &Self) -> &PyRef<'py, T> {
// `PyRefMut<T>` and `PyRef<T>` have the same layout
unsafe { &*ptr_from_ref(slf).cast() }
}
}

impl<'p, T, U> PyRefMut<'p, T>
Expand All @@ -891,6 +950,23 @@ where
},
}
}

/// Borrows a mutable reference to `PyRefMut<T::BaseType>`.
///
/// With the help of this method, you can mutate attributes and call mutating
/// methods on the superclass without consuming the `PyRefMut<T>`. This method
/// can also be chained to access the super-superclass (and so on).
///
/// See [`PyRef::as_super`] for more.
pub fn as_super(&mut self) -> &mut PyRefMut<'p, U> {
let ptr = ptr_from_mut::<Bound<'p, T>>(&mut self.inner)
// `Bound<T>` has the same layout as `Bound<T::BaseType>`
.cast::<Bound<'p, T::BaseType>>()
// `Bound<T::BaseType>` has the same layout as `PyRefMut<T::BaseType>`,
// and the mutable borrow on `self` prevents aliasing
.cast::<PyRefMut<'p, T::BaseType>>();
unsafe { &mut *ptr }
}
}

impl<'p, T: PyClass<Frozen = False>> Deref for PyRefMut<'p, T> {
Expand Down Expand Up @@ -1140,4 +1216,88 @@ mod tests {
unsafe { ffi::Py_DECREF(ptr) };
})
}

#[crate::pyclass]
#[pyo3(crate = "crate", subclass)]
struct BaseClass {
val1: usize,
}

#[crate::pyclass]
#[pyo3(crate = "crate", extends=BaseClass, subclass)]
struct SubClass {
val2: usize,
}

#[crate::pyclass]
#[pyo3(crate = "crate", extends=SubClass)]
struct SubSubClass {
val3: usize,
}

#[crate::pymethods]
#[pyo3(crate = "crate")]
impl SubSubClass {
#[new]
fn new(py: Python<'_>) -> crate::Py<SubSubClass> {
let init = crate::PyClassInitializer::from(BaseClass { val1: 10 })
.add_subclass(SubClass { val2: 15 })
.add_subclass(SubSubClass { val3: 20 });
crate::Py::new(py, init).expect("allocation error")
}

fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) {
let val1 = self_.as_super().as_super().val1;
let val2 = self_.as_super().val2;
(val1, val2, self_.val3)
}

fn double_values(mut self_: PyRefMut<'_, Self>) {
self_.as_super().as_super().val1 *= 2;
self_.as_super().val2 *= 2;
self_.val3 *= 2;
}
}

#[test]
fn test_pyref_as_super() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py).into_bound(py);
let pyref = obj.borrow();
assert_eq!(pyref.as_super().as_super().val1, 10);
assert_eq!(pyref.as_super().val2, 15);
assert_eq!(pyref.as_ref().val2, 15); // `as_ref` also works
assert_eq!(pyref.val3, 20);
assert_eq!(SubSubClass::get_values(pyref), (10, 15, 20));
});
}

#[test]
fn test_pyrefmut_as_super() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py).into_bound(py);
assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 15, 20));
{
let mut pyrefmut = obj.borrow_mut();
assert_eq!(pyrefmut.as_super().as_ref().val1, 10);
pyrefmut.as_super().as_super().val1 -= 5;
pyrefmut.as_super().val2 -= 3;
pyrefmut.as_mut().val2 -= 2; // `as_mut` also works
pyrefmut.val3 -= 5;
}
assert_eq!(SubSubClass::get_values(obj.borrow()), (5, 10, 15));
SubSubClass::double_values(obj.borrow_mut());
assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 20, 30));
});
}

#[test]
fn test_pyrefs_in_python() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py);
crate::py_run!(py, obj, "assert obj.get_values() == (10, 15, 20)");
crate::py_run!(py, obj, "assert obj.double_values() is None");
crate::py_run!(py, obj, "assert obj.get_values() == (20, 30, 40)");
});
}
}

0 comments on commit d2dca21

Please sign in to comment.