Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Bound for classmethod and pass_module #3831

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ This is the equivalent of the Python decorator `@classmethod`.
#[pymethods]
impl MyClass {
#[classmethod]
fn cls_method(cls: &PyType) -> PyResult<i32> {
fn cls_method(cls: &Bound<'_, PyType>) -> PyResult<i32> {
Ok(10)
}
}
Expand Down Expand Up @@ -719,10 +719,10 @@ To create a constructor which takes a positional class argument, you can combine
impl BaseClass {
#[new]
#[classmethod]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
fn py_new(cls: &Bound<'_, PyType>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
let subclass_attr: Bound<'_, PyAny> = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.unbind()))
}
}
```
Expand Down Expand Up @@ -928,7 +928,7 @@ impl MyClass {
// similarly for classmethod arguments, use $cls
#[classmethod]
#[pyo3(text_signature = "($cls, e, f)")]
fn my_class_method(cls: &PyType, e: i32, f: i32) -> i32 {
fn my_class_method(cls: &Bound<'_, PyType>, e: i32, f: i32) -> i32 {
e + f
}
#[staticmethod]
Expand Down
3 changes: 2 additions & 1 deletion guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ The `#[pyo3]` attribute can be used to modify properties of the generated Python

```rust
use pyo3::prelude::*;
use pyo3::types::PyString;

#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
module.name()
}

Expand Down
14 changes: 11 additions & 3 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,21 @@ impl FnType {
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::from_ref_to_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyType>()
),
}
}
FnType::FnModule(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::from_ref_to_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyModule>()
),
}
}
}
Expand Down Expand Up @@ -409,7 +417,7 @@ impl<'a> FnSpec<'a> {
// will error on incorrect type.
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
sig.paren_token.span.join() => "Expected `&PyType` or `Py<PyType>` as the first argument to `#[classmethod]`"
sig.paren_token.span.join() => "Expected `&Bound<PyType>` or `Py<PyType>` as the first argument to `#[classmethod]`"
),
};
FnType::FnClass(span)
Expand Down
20 changes: 20 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ struct AssertingBaseClass;

#[pymethods]
impl AssertingBaseClass {
#[new]
#[classmethod]
fn new(cls: &Bound<'_, PyType>, expected_type: Bound<'_, PyType>) -> PyResult<Self> {
if !cls.is(&expected_type) {
return Err(PyValueError::new_err(format!(
"{:?} != {:?}",
cls, expected_type
)));
}
Ok(Self)
}
}

#[pyclass(subclass)]
#[derive(Clone, Debug)]
struct AssertingBaseClassGilRef;

#[pymethods]
impl AssertingBaseClassGilRef {
#[new]
#[classmethod]
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
Expand All @@ -65,6 +84,7 @@ pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<AssertingBaseClassGilRef>()?;
m.add_class::<ClassWithoutConstructor>()?;
Ok(())
}
11 changes: 11 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def test_new_classmethod():
_ = AssertingSubClass(expected_type=str)


def test_new_classmethod_gil_ref():
class AssertingSubClass(pyclasses.AssertingBaseClassGilRef):
pass

# The `AssertingBaseClass` constructor errors if it is not passed the
# relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)


class ClassWithoutConstructorPy:
def __new__(cls):
raise TypeError("No constructor defined")
Expand Down
53 changes: 52 additions & 1 deletion src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
use crate::types::{any::PyAnyMethods, PyModule, PyType};
use crate::{
ffi, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, Python,
ffi, Bound, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow;
use std::ffi::CStr;
Expand Down Expand Up @@ -466,3 +468,52 @@ pub trait AsyncIterResultOptionKind {
}

impl<Value, Error> AsyncIterResultOptionKind for Result<Option<Value>, Error> {}

/// Used in `#[classmethod]` to pass the class object to the method
/// and also in `#[pyfunction(pass_module)]`.
///
/// This is a wrapper to avoid implementing `From<Bound>` for GIL Refs.
///
/// Once the GIL Ref API is fully removed, it should be possible to simplify
/// this to just `&'a Bound<'py, T>` and `From` implementations.
pub struct BoundRef<'a, 'py, T>(pub &'a Bound<'py, T>);

impl<'a, 'py> BoundRef<'a, 'py, PyAny> {
pub unsafe fn from_ref_to_ptr(py: Python<'py>, ptr: &'a *mut ffi::PyObject) -> Self {
BoundRef(Bound::from_ref_to_ptr(py, ptr))
}

pub unsafe fn downcast_unchecked<T>(self) -> BoundRef<'a, 'py, T> {
BoundRef(self.0.downcast_unchecked::<T>())
}
}

// GIL Ref implementations for &'a T ran into trouble with orphan rules,
// so explicit implementations are used instead for the two relevant types.
impl<'a> From<BoundRef<'a, 'a, PyType>> for &'a PyType {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyType>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a> From<BoundRef<'a, 'a, PyModule>> for &'a PyModule {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyModule>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a, 'py, T> From<BoundRef<'a, 'py, T>> for &'a Bound<'py, T> {
#[inline]
fn from(bound: BoundRef<'a, 'py, T>) -> Self {
bound.0
}
}

impl<T> From<BoundRef<'_, '_, T>> for Py<T> {
#[inline]
fn from(bound: BoundRef<'_, '_, T>) -> Self {
bound.0.clone().unbind()
}
}
8 changes: 8 additions & 0 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ impl<'py> Bound<'py, PyAny> {
) -> PyResult<Self> {
Py::from_owned_ptr_or_err(py, ptr).map(|obj| Self(py, ManuallyDrop::new(obj)))
}

#[inline]
pub(crate) unsafe fn from_ref_to_ptr<'a>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really get this name. How about from_borrowed_ptr?

Same on BoundRef

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, I couldn't work out a good name either. I didn't feel borrowed is quite right here because that's got a very specific meaning in the Python C API to refer to a pointer return value which doesn't carry ownership.

How about just from_raw?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think this API will remain crate-private for the foreseeable future, so the name doesn't have to be perfect 😄)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, since this is internal its not super important. I'm also fine with from_raw. I think I also get now that you were intending to reference this (unusual) construct & *mut _, but yesterday i read it like it was turning a reference into a pointer, while it was doing kind of the opposite 🙈

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see why the from_x_to_y bit reads wrong.

On reflection I think Bound::from_raw isn't quite the right name for this API. I still can't think of anything better. from_shared_ref_ptr? from_borrow_on_raw_pointer? from_ptr_shared_ref?

I'm kinda ok with a clumsy name here given this is an internal API really only intended to give a lifetime 'a in &'a Bound<'_, PyAny>.

I think of the options I wrote above, from_ptr_shared_ref might make the most sense to me. Any of those which you like?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I have a particular preference among these. If i had to pick i would probably choose between from_shared_ref_ptr and from_ptr_shared_ref. So I'm fine with your preference. Another option that came to my mind was something like ref_from_ptr or shared_from_ptr, but I dont't feel strongly about these either.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I quite like ref_from_ptr, let's go with that 👍

_py: Python<'py>,
ptr: &'a *mut ffi::PyObject,
) -> &'a Self {
&*(ptr as *const *mut ffi::PyObject).cast::<Bound<'py, PyAny>>()
}
}

impl<'py, T> Bound<'py, T>
Expand Down
4 changes: 2 additions & 2 deletions src/tests/hygiene/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down Expand Up @@ -770,7 +770,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down
14 changes: 13 additions & 1 deletion tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,18 @@ impl ClassWithFromPyWithMethods {
argument
}
#[classmethod]
fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize {
fn classmethod(
_cls: &Bound<'_, PyType>,
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
) -> usize {
argument
}

#[classmethod]
fn classmethod_gil_ref(
_cls: &PyType,
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
) -> usize {
argument
}

Expand All @@ -322,6 +333,7 @@ fn test_pymethods_from_py_with() {

assert instance.instance_method(arg) == 2
assert instance.classmethod(arg) == 2
assert instance.classmethod_gil_ref(arg) == 2
assert instance.staticmethod(arg) == 2
"#
);
Expand Down
20 changes: 15 additions & 5 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ impl ClassMethod {

#[classmethod]
/// Test class method.
fn method(cls: &PyType) -> PyResult<String> {
fn method(cls: &Bound<'_, PyType>) -> PyResult<String> {
Ok(format!("{}.method()!", cls.as_gil_ref().qualname()?))
}

#[classmethod]
/// Test class method.
fn method_gil_ref(cls: &PyType) -> PyResult<String> {
Ok(format!("{}.method()!", cls.qualname()?))
}

Expand Down Expand Up @@ -108,8 +114,12 @@ struct ClassMethodWithArgs {}
#[pymethods]
impl ClassMethodWithArgs {
#[classmethod]
fn method(cls: &PyType, input: &PyString) -> PyResult<String> {
Ok(format!("{}.method({})", cls.qualname()?, input))
fn method(cls: &Bound<'_, PyType>, input: &PyString) -> PyResult<String> {
Ok(format!(
"{}.method({})",
cls.as_gil_ref().qualname()?,
input
))
}
}

Expand Down Expand Up @@ -915,7 +925,7 @@ impl r#RawIdents {
}

#[classmethod]
pub fn r#class_method(_: &PyType, r#type: PyObject) -> PyObject {
pub fn r#class_method(_: &Bound<'_, PyType>, r#type: PyObject) -> PyObject {
r#type
}

Expand Down Expand Up @@ -1082,7 +1092,7 @@ issue_1506!(

#[classmethod]
fn issue_1506_class(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
_py: Python<'_>,
_arg: &PyAny,
_args: &PyTuple,
Expand Down
Loading
Loading