From e835ff0ec30e0e391fdfcbef017241004ed9d896 Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Sat, 4 May 2024 09:39:40 +0200 Subject: [PATCH] handle `#[pyo3(from_py_with = ...)]` on dunder (`__magic__`) methods (#4117) * handle `#[pyo3(from_py_with = ...)]` on dunder (__magic__) methods * add newsfragment --- newsfragments/4117.fixed.md | 1 + pyo3-macros-backend/src/params.rs | 34 +++++++++++++++++--- pyo3-macros-backend/src/pymethod.rs | 42 ++++++++++++++---------- tests/test_class_basics.rs | 11 +++++++ tests/ui/deprecations.rs | 8 +++++ tests/ui/deprecations.stderr | 50 ++++++++++++++++------------- 6 files changed, 103 insertions(+), 43 deletions(-) create mode 100644 newsfragments/4117.fixed.md diff --git a/newsfragments/4117.fixed.md b/newsfragments/4117.fixed.md new file mode 100644 index 00000000000..c3bb4c144b6 --- /dev/null +++ b/newsfragments/4117.fixed.md @@ -0,0 +1 @@ +Correctly handle `#[pyo3(from_py_with = ...)]` attribute on dunder (`__magic__`) method arguments instead of silently ignoring it. diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index d9f77fa07bc..cab9d2a7d29 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -10,7 +10,7 @@ use syn::spanned::Spanned; pub struct Holders { holders: Vec, - gil_refs_checkers: Vec, + gil_refs_checkers: Vec, } impl Holders { @@ -32,14 +32,28 @@ impl Holders { &format!("gil_refs_checker_{}", self.gil_refs_checkers.len()), span, ); - self.gil_refs_checkers.push(gil_refs_checker.clone()); + self.gil_refs_checkers + .push(GilRefChecker::FunctionArg(gil_refs_checker.clone())); + gil_refs_checker + } + + pub fn push_from_py_with_checker(&mut self, span: Span) -> syn::Ident { + let gil_refs_checker = syn::Ident::new( + &format!("gil_refs_checker_{}", self.gil_refs_checkers.len()), + span, + ); + self.gil_refs_checkers + .push(GilRefChecker::FromPyWith(gil_refs_checker.clone())); gil_refs_checker } pub fn init_holders(&self, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path } = ctx; let holders = &self.holders; - let gil_refs_checkers = &self.gil_refs_checkers; + let gil_refs_checkers = self.gil_refs_checkers.iter().map(|checker| match checker { + GilRefChecker::FunctionArg(ident) => ident, + GilRefChecker::FromPyWith(ident) => ident, + }); quote! { #[allow(clippy::let_unit_value)] #(let mut #holders = #pyo3_path::impl_::extract_argument::FunctionArgumentHolder::INIT;)* @@ -50,11 +64,23 @@ impl Holders { pub fn check_gil_refs(&self) -> TokenStream { self.gil_refs_checkers .iter() - .map(|e| quote_spanned! { e.span() => #e.function_arg(); }) + .map(|checker| match checker { + GilRefChecker::FunctionArg(ident) => { + quote_spanned! { ident.span() => #ident.function_arg(); } + } + GilRefChecker::FromPyWith(ident) => { + quote_spanned! { ident.span() => #ident.from_py_with_arg(); } + } + }) .collect() } } +enum GilRefChecker { + FunctionArg(syn::Ident), + FromPyWith(syn::Ident), +} + /// Return true if the argument list is simply (*args, **kwds). pub fn is_forwarded_args(signature: &FunctionSignature<'_>) -> bool { matches!( diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index aac804316f8..1ef137cfcc8 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -1053,20 +1053,18 @@ impl Ty { ctx: &Ctx, ) -> TokenStream { let Ctx { pyo3_path } = ctx; - let name_str = arg.name().unraw().to_string(); match self { Ty::Object => extract_object( extract_error_mode, holders, - &name_str, + arg, quote! { #ident }, - arg.ty().span(), ctx ), Ty::MaybeNullObject => extract_object( extract_error_mode, holders, - &name_str, + arg, quote! { if #ident.is_null() { #pyo3_path::ffi::Py_None() @@ -1074,23 +1072,20 @@ impl Ty { #ident } }, - arg.ty().span(), ctx ), Ty::NonNullObject => extract_object( extract_error_mode, holders, - &name_str, + arg, quote! { #ident.as_ptr() }, - arg.ty().span(), ctx ), Ty::IPowModulo => extract_object( extract_error_mode, holders, - &name_str, + arg, quote! { #ident.as_ptr() }, - arg.ty().span(), ctx ), Ty::CompareOp => extract_error_mode.handle_error( @@ -1118,24 +1113,37 @@ impl Ty { fn extract_object( extract_error_mode: ExtractErrorMode, holders: &mut Holders, - name: &str, + arg: &FnArg<'_>, source_ptr: TokenStream, - span: Span, ctx: &Ctx, ) -> TokenStream { let Ctx { pyo3_path } = ctx; - let holder = holders.push_holder(Span::call_site()); - let gil_refs_checker = holders.push_gil_refs_checker(span); - let extracted = extract_error_mode.handle_error( + let gil_refs_checker = holders.push_gil_refs_checker(arg.ty().span()); + let name = arg.name().unraw().to_string(); + + let extract = if let Some(from_py_with) = + arg.from_py_with().map(|from_py_with| &from_py_with.value) + { + let from_py_with_checker = holders.push_from_py_with_checker(from_py_with.span()); + quote! { + #pyo3_path::impl_::extract_argument::from_py_with( + #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0, + #name, + #pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &#from_py_with_checker) as fn(_) -> _, + ) + } + } else { + let holder = holders.push_holder(Span::call_site()); quote! { #pyo3_path::impl_::extract_argument::extract_argument( #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0, &mut #holder, #name ) - }, - ctx, - ); + } + }; + + let extracted = extract_error_mode.handle_error(extract, ctx); quote! { #pyo3_path::impl_::deprecations::inspect_type(#extracted, &#gil_refs_checker) } diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 8c6a1c04915..8ff61bd2d6b 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -290,6 +290,10 @@ fn get_length(obj: &Bound<'_, PyAny>) -> PyResult { Ok(length) } +fn is_even(obj: &Bound<'_, PyAny>) -> PyResult { + obj.extract::().map(|i| i % 2 == 0) +} + #[pyclass] struct ClassWithFromPyWithMethods {} @@ -319,6 +323,10 @@ impl ClassWithFromPyWithMethods { fn staticmethod(#[pyo3(from_py_with = "get_length")] argument: usize) -> usize { argument } + + fn __contains__(&self, #[pyo3(from_py_with = "is_even")] obj: bool) -> bool { + obj + } } #[test] @@ -339,6 +347,9 @@ fn test_pymethods_from_py_with() { if has_gil_refs: assert instance.classmethod_gil_ref(arg) == 2 assert instance.staticmethod(arg) == 2 + + assert 42 in instance + assert 73 not in instance "# ); }) diff --git a/tests/ui/deprecations.rs b/tests/ui/deprecations.rs index dcc9b7b1d84..96f652d9679 100644 --- a/tests/ui/deprecations.rs +++ b/tests/ui/deprecations.rs @@ -38,6 +38,14 @@ impl MyClass { #[setter] fn set_bar_bound(&self, _value: &Bound<'_, PyAny>) {} + + fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool { + true + } + + fn __contains__(&self, #[pyo3(from_py_with = "extract_bound")] _value: i32) -> bool { + true + } } fn main() {} diff --git a/tests/ui/deprecations.stderr b/tests/ui/deprecations.stderr index e692702f23e..2b75ee23e10 100644 --- a/tests/ui/deprecations.stderr +++ b/tests/ui/deprecations.stderr @@ -16,6 +16,12 @@ error: use of deprecated struct `pyo3::PyCell`: `PyCell` was merged into `Bound` 23 | fn method_gil_ref(_slf: &PyCell) {} | ^^^^^^ +error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor + --> tests/ui/deprecations.rs:42:44 + | +42 | fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool { + | ^^^^^^^^^^^^^^^^^ + error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg`: use `&Bound<'_, T>` instead for this function argument --> tests/ui/deprecations.rs:18:33 | @@ -47,69 +53,69 @@ error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg` | ^ error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg`: use `&Bound<'_, T>` instead for this function argument - --> tests/ui/deprecations.rs:53:43 + --> tests/ui/deprecations.rs:61:43 | -53 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> { +61 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> { | ^ error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg`: use `&Bound<'_, T>` instead for this function argument - --> tests/ui/deprecations.rs:63:19 + --> tests/ui/deprecations.rs:71:19 | -63 | fn module_gil_ref(m: &PyModule) -> PyResult<()> { +71 | fn module_gil_ref(m: &PyModule) -> PyResult<()> { | ^ error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg`: use `&Bound<'_, T>` instead for this function argument - --> tests/ui/deprecations.rs:69:57 + --> tests/ui/deprecations.rs:77:57 | -69 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +77 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> { | ^ error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor - --> tests/ui/deprecations.rs:102:27 + --> tests/ui/deprecations.rs:110:27 | -102 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32, +110 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32, | ^^^^^^^^^^^^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::::function_arg`: use `&Bound<'_, T>` instead for this function argument - --> tests/ui/deprecations.rs:108:29 + --> tests/ui/deprecations.rs:116:29 | -108 | fn pyfunction_gil_ref(_any: &PyAny) {} +116 | fn pyfunction_gil_ref(_any: &PyAny) {} | ^ error: use of deprecated method `pyo3::deprecations::OptionGilRefs::>::function_arg`: use `Option<&Bound<'_, T>>` instead for this function argument - --> tests/ui/deprecations.rs:111:36 + --> tests/ui/deprecations.rs:119:36 | -111 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {} +119 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {} | ^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor - --> tests/ui/deprecations.rs:118:27 + --> tests/ui/deprecations.rs:126:27 | -118 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))] +126 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))] | ^^^^^^^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor - --> tests/ui/deprecations.rs:128:27 + --> tests/ui/deprecations.rs:136:27 | -128 | #[pyo3(from_py_with = "PyAny::len")] usize, +136 | #[pyo3(from_py_with = "PyAny::len")] usize, | ^^^^^^^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor - --> tests/ui/deprecations.rs:134:31 + --> tests/ui/deprecations.rs:142:31 | -134 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32), +142 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32), | ^^^^^^^^^^^^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor - --> tests/ui/deprecations.rs:141:27 + --> tests/ui/deprecations.rs:149:27 | -141 | #[pyo3(from_py_with = "extract_gil_ref")] +149 | #[pyo3(from_py_with = "extract_gil_ref")] | ^^^^^^^^^^^^^^^^^ error: use of deprecated method `pyo3::deprecations::GilRefs::>::is_python`: use `wrap_pyfunction_bound!` instead - --> tests/ui/deprecations.rs:154:13 + --> tests/ui/deprecations.rs:162:13 | -154 | let _ = wrap_pyfunction!(double, py); +162 | let _ = wrap_pyfunction!(double, py); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = note: this error originates in the macro `wrap_pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)