From 0b78bb851eb7302ca4eaffe9326331121d729a03 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sun, 2 Jul 2023 17:26:31 -0400 Subject: [PATCH] Allow `#[new]` to return existing instances fixes #2384 --- guide/src/class.md | 3 ++ newsfragments/3287.added.md | 1 + src/pyclass_init.rs | 31 ++++++++++++++----- tests/test_class_new.rs | 60 +++++++++++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 newsfragments/3287.added.md diff --git a/guide/src/class.md b/guide/src/class.md index 06fb9eb8c77..6e13d20f865 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -114,6 +114,9 @@ impl Nonzero { } ``` +If you want to return an existing object (for example, because your `new` +method caches the values it returns), `new` can return `pyo3::Py`. + As you can see, the Rust method name is not important here; this way you can still, use `new()` for a Rust-level constructor. diff --git a/newsfragments/3287.added.md b/newsfragments/3287.added.md new file mode 100644 index 00000000000..bde61a4b506 --- /dev/null +++ b/newsfragments/3287.added.md @@ -0,0 +1 @@ +`#[new]` methods may now return `Py` in order to return existing instances diff --git a/src/pyclass_init.rs b/src/pyclass_init.rs index 57f86665843..534e5fedd5c 100644 --- a/src/pyclass_init.rs +++ b/src/pyclass_init.rs @@ -1,7 +1,7 @@ //! Contains initialization utilities for `#[pyclass]`. use crate::callback::IntoPyCallbackOutput; use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef}; -use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python}; +use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python}; use crate::{ ffi::PyTypeObject, pycell::{ @@ -134,9 +134,14 @@ impl PyObjectInit for PyNativeTypeInitializer { /// ); /// }); /// ``` -pub struct PyClassInitializer { - init: T, - super_init: ::Initializer, +pub struct PyClassInitializer(PyClassInitializerImpl); + +enum PyClassInitializerImpl { + Existing(Py), + New { + init: T, + super_init: ::Initializer, + }, } impl PyClassInitializer { @@ -144,7 +149,7 @@ impl PyClassInitializer { /// /// It is recommended to use `add_subclass` instead of this method for most usage. pub fn new(init: T, super_init: ::Initializer) -> Self { - Self { init, super_init } + Self(PyClassInitializerImpl::New { init, super_init }) } /// Constructs a new initializer from an initializer for the base class. @@ -242,13 +247,18 @@ impl PyObjectInit for PyClassInitializer { contents: MaybeUninit>, } - let obj = self.super_init.into_new_object(py, subtype)?; + let (init, super_init) = match self.0 { + PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()), + PyClassInitializerImpl::New { init, super_init } => (init, super_init), + }; + + let obj = super_init.into_new_object(py, subtype)?; let cell: *mut PartiallyInitializedPyCell = obj as _; std::ptr::write( (*cell).contents.as_mut_ptr(), PyCellContents { - value: ManuallyDrop::new(UnsafeCell::new(self.init)), + value: ManuallyDrop::new(UnsafeCell::new(init)), borrow_checker: ::Storage::new(), thread_checker: T::ThreadChecker::new(), dict: T::Dict::INIT, @@ -284,6 +294,13 @@ where } } +impl From> for PyClassInitializer { + #[inline] + fn from(value: Py) -> PyClassInitializer { + PyClassInitializer(PyClassInitializerImpl::Existing(value)) + } +} + // Implementation used by proc macros to allow anything convertible to PyClassInitializer to be // the return value of pyclass #[new] method (optionally wrapped in `Result`). impl IntoPyCallbackOutput> for U diff --git a/tests/test_class_new.rs b/tests/test_class_new.rs index b9b0d152086..ff159c610f8 100644 --- a/tests/test_class_new.rs +++ b/tests/test_class_new.rs @@ -2,6 +2,7 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; use pyo3::types::IntoPyDict; #[pyclass] @@ -204,3 +205,62 @@ fn new_with_custom_error() { assert_eq!(err.to_string(), "ValueError: custom error"); }); } + +#[pyclass] +struct NewExisting { + #[pyo3(get)] + num: usize, +} + +#[pymethods] +impl NewExisting { + #[new] + fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py { + static PRE_BUILT: GILOnceCell<[pyo3::Py; 2]> = GILOnceCell::new(); + let existing = PRE_BUILT.get_or_init(py, || { + [ + pyo3::PyCell::new(py, NewExisting { num: 0 }) + .unwrap() + .into(), + pyo3::PyCell::new(py, NewExisting { num: 1 }) + .unwrap() + .into(), + ] + }); + + if val < existing.len() { + return existing[val].clone_ref(py); + } + + pyo3::PyCell::new(py, NewExisting { num: val }) + .unwrap() + .into() + } +} + +#[test] +fn test_new_existing() { + Python::with_gil(|py| { + let typeobj = py.get_type::(); + + let obj1 = typeobj.call1((0,)).unwrap(); + let obj2 = typeobj.call1((0,)).unwrap(); + let obj3 = typeobj.call1((1,)).unwrap(); + let obj4 = typeobj.call1((1,)).unwrap(); + let obj5 = typeobj.call1((2,)).unwrap(); + let obj6 = typeobj.call1((2,)).unwrap(); + + assert!(obj1.getattr("num").unwrap().extract::().unwrap() == 0); + assert!(obj2.getattr("num").unwrap().extract::().unwrap() == 0); + assert!(obj3.getattr("num").unwrap().extract::().unwrap() == 1); + assert!(obj4.getattr("num").unwrap().extract::().unwrap() == 1); + assert!(obj5.getattr("num").unwrap().extract::().unwrap() == 2); + assert!(obj6.getattr("num").unwrap().extract::().unwrap() == 2); + + assert!(obj1.is(obj2)); + assert!(obj3.is(obj4)); + assert!(!obj1.is(obj3)); + assert!(!obj1.is(obj5)); + assert!(!obj5.is(obj6)); + }); +}