diff --git a/pyo3-macros-backend/src/intopyobject.rs b/pyo3-macros-backend/src/intopyobject.rs index 9cc263b769b..3b4b2d376bb 100644 --- a/pyo3-macros-backend/src/intopyobject.rs +++ b/pyo3-macros-backend/src/intopyobject.rs @@ -5,7 +5,10 @@ use quote::{format_ident, quote, quote_spanned}; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned as _; -use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result, Token}; +use syn::{ + parenthesized, parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result, + Token, +}; /// Attributes for deriving `IntoPyObject` scoped on containers. enum ContainerPyO3Attribute { @@ -72,6 +75,95 @@ impl ContainerOptions { } } +#[derive(Debug, Clone)] +struct ItemOption { + field: Option, + span: Span, +} + +impl ItemOption { + fn span(&self) -> Span { + self.span + } +} + +enum FieldAttribute { + Item(ItemOption), +} + +impl Parse for FieldAttribute { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(attributes::kw::attribute) { + let attr: attributes::kw::attribute = input.parse()?; + bail_spanned!(attr.span => "`attribute` is not supported by `IntoPyObject`"); + } else if lookahead.peek(attributes::kw::item) { + let attr: attributes::kw::item = input.parse()?; + if input.peek(syn::token::Paren) { + let content; + let _ = parenthesized!(content in input); + let key = content.parse()?; + if !content.is_empty() { + return Err( + content.error("expected at most one argument: `item` or `item(key)`") + ); + } + Ok(FieldAttribute::Item(ItemOption { + field: Some(key), + span: attr.span, + })) + } else { + Ok(FieldAttribute::Item(ItemOption { + field: None, + span: attr.span, + })) + } + } else { + Err(lookahead.error()) + } + } +} + +#[derive(Clone, Debug, Default)] +struct FieldAttributes { + item: Option, +} + +impl FieldAttributes { + /// Extract the field attributes. + fn from_attrs(attrs: &[Attribute]) -> Result { + let mut options = FieldAttributes::default(); + + for attr in attrs { + if let Some(pyo3_attrs) = get_pyo3_options(attr)? { + pyo3_attrs + .into_iter() + .try_for_each(|opt| options.set_option(opt))?; + } + } + Ok(options) + } + + fn set_option(&mut self, option: FieldAttribute) -> syn::Result<()> { + macro_rules! set_option { + ($key:ident) => { + { + ensure_spanned!( + self.$key.is_none(), + $key.span() => concat!("`", stringify!($key), "` may only be specified once") + ); + self.$key = Some($key); + } + }; + } + + match option { + FieldAttribute::Item(item) => set_option!(item), + } + Ok(()) + } +} + struct IntoPyObjectImpl { target: TokenStream, output: TokenStream, @@ -82,6 +174,7 @@ struct IntoPyObjectImpl { struct NamedStructField<'a> { ident: &'a syn::Ident, field: &'a syn::Field, + item: Option, } struct TupleStructField<'a> { @@ -132,22 +225,28 @@ impl<'a> Container<'a> { ) -> Result { let style = match fields { Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => { - if unnamed.unnamed.iter().count() == 1 { + let mut tuple_fields = unnamed + .unnamed + .iter() + .map(|field| { + let attrs = FieldAttributes::from_attrs(&field.attrs)?; + ensure_spanned!( + attrs.item.is_none(), + attrs.item.unwrap().span() => "`item` is not permitted on tuple struct elements." + ); + Ok(TupleStructField { field }) + }) + .collect::>>()?; + if tuple_fields.len() == 1 { // Always treat a 1-length tuple struct as "transparent", even without the // explicit annotation. - let field = unnamed.unnamed.iter().next().unwrap(); + let TupleStructField { field } = tuple_fields.pop().unwrap(); ContainerType::TupleNewtype(field) } else if options.transparent.is_some() { bail_spanned!( fields.span() => "transparent structs and variants can only have 1 field" ); } else { - let tuple_fields = unnamed - .unnamed - .iter() - .map(|field| Ok(TupleStructField { field })) - .collect::>>()?; - ContainerType::Tuple(tuple_fields) } } @@ -159,6 +258,11 @@ impl<'a> Container<'a> { ); let field = named.named.iter().next().unwrap(); + let attrs = FieldAttributes::from_attrs(&field.attrs)?; + ensure_spanned!( + attrs.item.is_none(), + attrs.item.unwrap().span() => "`transparent` structs may not have `item` for the inner field" + ); ContainerType::StructNewtype(field) } else { let struct_fields = named @@ -170,7 +274,13 @@ impl<'a> Container<'a> { .as_ref() .expect("Named fields should have identifiers"); - Ok(NamedStructField { ident, field }) + let attrs = FieldAttributes::from_attrs(&field.attrs)?; + + Ok(NamedStructField { + ident, + field, + item: attrs.item, + }) }) .collect::>>()?; ContainerType::Struct(struct_fields) @@ -267,7 +377,12 @@ impl<'a> Container<'a> { .iter() .enumerate() .map(|(i, f)| { - let key = f.ident.unraw().to_string(); + let key = f + .item + .as_ref() + .and_then(|item| item.field.as_ref()) + .map(|item| item.value()) + .unwrap_or_else(|| f.ident.unraw().to_string()); let value = Ident::new(&format!("arg{i}"), f.field.ty.span()); quote! { #pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?; diff --git a/tests/test_frompy_intopy_roundtrip.rs b/tests/test_frompy_intopy_roundtrip.rs new file mode 100644 index 00000000000..fca0088b800 --- /dev/null +++ b/tests/test_frompy_intopy_roundtrip.rs @@ -0,0 +1,180 @@ +#![cfg(feature = "macros")] + +use pyo3::types::{PyDict, PyString}; +use pyo3::{prelude::*, IntoPyObject}; +use std::collections::HashMap; +use std::hash::Hash; + +#[macro_use] +#[path = "../src/tests/common.rs"] +mod common; + +#[derive(Debug, Clone, IntoPyObject, FromPyObject)] +pub struct A<'py> { + #[pyo3(item)] + s: String, + #[pyo3(item)] + t: Bound<'py, PyString>, + #[pyo3(item("foo"))] + p: Bound<'py, PyAny>, +} + +#[test] +fn test_named_fields_struct() { + Python::with_gil(|py| { + let a = A { + s: "Hello".into(), + t: PyString::new(py, "World"), + p: 42i32.into_pyobject(py).unwrap().into_any(), + }; + let pya = a.clone().into_pyobject(py).unwrap(); + let new_a = pya.extract::>().unwrap(); + + assert_eq!(a.s, new_a.s); + assert_eq!(a.t.to_str().unwrap(), new_a.t.to_str().unwrap()); + assert_eq!( + a.p.extract::().unwrap(), + new_a.p.extract::().unwrap() + ); + }); +} + +#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[pyo3(transparent)] +pub struct B { + test: String, +} + +#[test] +fn test_transparent_named_field_struct() { + Python::with_gil(|py| { + let b = B { + test: "test".into(), + }; + let pyb = b.clone().into_pyobject(py).unwrap(); + let new_b = pyb.extract::().unwrap(); + assert_eq!(b, new_b); + }); +} + +#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +#[pyo3(transparent)] +pub struct D { + test: T, +} + +#[test] +fn test_generic_transparent_named_field_struct() { + Python::with_gil(|py| { + let d = D { + test: String::from("test"), + }; + let pyd = d.clone().into_pyobject(py).unwrap(); + let new_d = pyd.extract::>().unwrap(); + assert_eq!(d, new_d); + + let d = D { test: 1usize }; + let pyd = d.clone().into_pyobject(py).unwrap(); + let new_d = pyd.extract::>().unwrap(); + assert_eq!(d, new_d); + }); +} + +#[derive(Debug, IntoPyObject, FromPyObject)] +pub struct GenericWithBound(HashMap); + +#[test] +fn test_generic_with_bound() { + Python::with_gil(|py| { + let mut hash_map = HashMap::::new(); + hash_map.insert("1".into(), 1); + hash_map.insert("2".into(), 2); + let map = GenericWithBound(hash_map).into_pyobject(py).unwrap(); + assert_eq!(map.len(), 2); + assert_eq!( + map.get_item("1") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); + assert_eq!( + map.get_item("2") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 2 + ); + assert!(map.get_item("3").unwrap().is_none()); + }); +} + +#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +pub struct Tuple(String, usize); + +#[test] +fn test_tuple_struct() { + Python::with_gil(|py| { + let tup = Tuple(String::from("test"), 1); + let tuple = tup.clone().into_pyobject(py).unwrap(); + let new_tup = tuple.extract::().unwrap(); + assert_eq!(tup, new_tup); + }); +} + +#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +pub struct TransparentTuple(String); + +#[test] +fn test_transparent_tuple_struct() { + Python::with_gil(|py| { + let tup = TransparentTuple(String::from("test")); + let tuple = tup.clone().into_pyobject(py).unwrap(); + let new_tup = tuple.extract::().unwrap(); + assert_eq!(tup, new_tup); + }); +} + +#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)] +pub enum Foo { + TupleVar(usize, String), + StructVar { + #[pyo3(item)] + test: char, + }, + #[pyo3(transparent)] + TransparentTuple(usize), + #[pyo3(transparent)] + TransparentStructVar { + a: Option, + }, +} + +#[test] +fn test_enum() { + Python::with_gil(|py| { + let tuple_var = Foo::TupleVar(1, "test".into()); + let foo = tuple_var.clone().into_pyobject(py).unwrap(); + assert_eq!(tuple_var, foo.extract::().unwrap()); + + let struct_var = Foo::StructVar { test: 'b' }; + let foo = struct_var + .clone() + .into_pyobject(py) + .unwrap() + .downcast_into::() + .unwrap(); + + assert_eq!(struct_var, foo.extract::().unwrap()); + + let transparent_tuple = Foo::TransparentTuple(1); + let foo = transparent_tuple.clone().into_pyobject(py).unwrap(); + assert_eq!(transparent_tuple, foo.extract::().unwrap()); + + let transparent_struct_var = Foo::TransparentStructVar { a: None }; + let foo = transparent_struct_var.clone().into_pyobject(py).unwrap(); + assert_eq!(transparent_struct_var, foo.extract::().unwrap()); + }); +} diff --git a/tests/ui/invalid_intopy_derive.rs b/tests/ui/invalid_intopy_derive.rs index 9bbfc9b10cc..310309992d4 100644 --- a/tests/ui/invalid_intopy_derive.rs +++ b/tests/ui/invalid_intopy_derive.rs @@ -87,4 +87,23 @@ enum UnitEnum { Unit, } +#[derive(IntoPyObject)] +struct TupleAttribute(#[pyo3(attribute)] String, usize); + +#[derive(IntoPyObject)] +struct TupleItem(#[pyo3(item)] String, usize); + +#[derive(IntoPyObject)] +struct StructAttribute { + #[pyo3(attribute)] + foo: String, +} + +#[derive(IntoPyObject)] +#[pyo3(transparent)] +struct StructTransparentItem { + #[pyo3(item)] + foo: String, +} + fn main() {} diff --git a/tests/ui/invalid_intopy_derive.stderr b/tests/ui/invalid_intopy_derive.stderr index b7fc0f5ae30..cf125d9c073 100644 --- a/tests/ui/invalid_intopy_derive.stderr +++ b/tests/ui/invalid_intopy_derive.stderr @@ -101,3 +101,27 @@ error: cannot derive `IntoPyObject` for empty variants | 87 | Unit, | ^^^^ + +error: `attribute` is not supported by `IntoPyObject` + --> tests/ui/invalid_intopy_derive.rs:91:30 + | +91 | struct TupleAttribute(#[pyo3(attribute)] String, usize); + | ^^^^^^^^^ + +error: `item` is not permitted on tuple struct elements. + --> tests/ui/invalid_intopy_derive.rs:94:25 + | +94 | struct TupleItem(#[pyo3(item)] String, usize); + | ^^^^ + +error: `attribute` is not supported by `IntoPyObject` + --> tests/ui/invalid_intopy_derive.rs:98:12 + | +98 | #[pyo3(attribute)] + | ^^^^^^^^^ + +error: `transparent` structs may not have `item` for the inner field + --> tests/ui/invalid_intopy_derive.rs:105:12 + | +105 | #[pyo3(item)] + | ^^^^