Skip to content

Commit

Permalink
parse pyo3(item) and pyo3(attribute)
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Oct 25, 2024
1 parent 9de2ff5 commit 49d9ce3
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 11 deletions.
137 changes: 126 additions & 11 deletions pyo3-macros-backend/src/intopyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use quote::{format_ident, quote, quote_spanned};
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned as _;
use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result, Token};
use syn::{
parenthesized, parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Index, Result,
Token,
};

/// Attributes for deriving `IntoPyObject` scoped on containers.
enum ContainerPyO3Attribute {
Expand Down Expand Up @@ -72,6 +75,95 @@ impl ContainerOptions {
}
}

#[derive(Debug, Clone)]
struct ItemOption {
field: Option<syn::LitStr>,
span: Span,
}

impl ItemOption {
fn span(&self) -> Span {
self.span
}
}

enum FieldAttribute {
Item(ItemOption),
}

impl Parse for FieldAttribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::attribute) {
let attr: attributes::kw::attribute = input.parse()?;
bail_spanned!(attr.span => "`attribute` is not supported by `IntoPyObject`");
} else if lookahead.peek(attributes::kw::item) {
let attr: attributes::kw::item = input.parse()?;
if input.peek(syn::token::Paren) {
let content;
let _ = parenthesized!(content in input);
let key = content.parse()?;
if !content.is_empty() {
return Err(
content.error("expected at most one argument: `item` or `item(key)`")
);
}
Ok(FieldAttribute::Item(ItemOption {
field: Some(key),
span: attr.span,
}))
} else {
Ok(FieldAttribute::Item(ItemOption {
field: None,
span: attr.span,
}))
}
} else {
Err(lookahead.error())
}
}
}

#[derive(Clone, Debug, Default)]
struct FieldAttributes {
item: Option<ItemOption>,
}

impl FieldAttributes {
/// Extract the field attributes.
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
let mut options = FieldAttributes::default();

for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
pyo3_attrs
.into_iter()
.try_for_each(|opt| options.set_option(opt))?;
}
}
Ok(options)
}

fn set_option(&mut self, option: FieldAttribute) -> syn::Result<()> {
macro_rules! set_option {
($key:ident) => {
{
ensure_spanned!(
self.$key.is_none(),
$key.span() => concat!("`", stringify!($key), "` may only be specified once")
);
self.$key = Some($key);
}
};
}

match option {
FieldAttribute::Item(item) => set_option!(item),
}
Ok(())
}
}

struct IntoPyObjectImpl {
target: TokenStream,
output: TokenStream,
Expand All @@ -82,6 +174,7 @@ struct IntoPyObjectImpl {
struct NamedStructField<'a> {
ident: &'a syn::Ident,
field: &'a syn::Field,
item: Option<ItemOption>,
}

struct TupleStructField<'a> {
Expand Down Expand Up @@ -132,22 +225,28 @@ impl<'a> Container<'a> {
) -> Result<Self> {
let style = match fields {
Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
if unnamed.unnamed.iter().count() == 1 {
let mut tuple_fields = unnamed
.unnamed
.iter()
.map(|field| {
let attrs = FieldAttributes::from_attrs(&field.attrs)?;
ensure_spanned!(
attrs.item.is_none(),
attrs.item.unwrap().span() => "`item` is not permitted on tuple struct elements."
);
Ok(TupleStructField { field })
})
.collect::<Result<Vec<_>>>()?;
if tuple_fields.len() == 1 {
// Always treat a 1-length tuple struct as "transparent", even without the
// explicit annotation.
let field = unnamed.unnamed.iter().next().unwrap();
let TupleStructField { field } = tuple_fields.pop().unwrap();
ContainerType::TupleNewtype(field)
} else if options.transparent.is_some() {
bail_spanned!(
fields.span() => "transparent structs and variants can only have 1 field"
);
} else {
let tuple_fields = unnamed
.unnamed
.iter()
.map(|field| Ok(TupleStructField { field }))
.collect::<Result<Vec<_>>>()?;

ContainerType::Tuple(tuple_fields)
}
}
Expand All @@ -159,6 +258,11 @@ impl<'a> Container<'a> {
);

let field = named.named.iter().next().unwrap();
let attrs = FieldAttributes::from_attrs(&field.attrs)?;
ensure_spanned!(
attrs.item.is_none(),
attrs.item.unwrap().span() => "`transparent` structs may not have `item` for the inner field"
);
ContainerType::StructNewtype(field)
} else {
let struct_fields = named
Expand All @@ -170,7 +274,13 @@ impl<'a> Container<'a> {
.as_ref()
.expect("Named fields should have identifiers");

Ok(NamedStructField { ident, field })
let attrs = FieldAttributes::from_attrs(&field.attrs)?;

Ok(NamedStructField {
ident,
field,
item: attrs.item,
})
})
.collect::<Result<Vec<_>>>()?;
ContainerType::Struct(struct_fields)
Expand Down Expand Up @@ -267,7 +377,12 @@ impl<'a> Container<'a> {
.iter()
.enumerate()
.map(|(i, f)| {
let key = f.ident.unraw().to_string();
let key = f
.item
.as_ref()
.and_then(|item| item.field.as_ref())
.map(|item| item.value())
.unwrap_or_else(|| f.ident.unraw().to_string());
let value = Ident::new(&format!("arg{i}"), f.field.ty.span());
quote! {
#pyo3_path::types::PyDictMethods::set_item(&dict, #key, #value)?;
Expand Down
180 changes: 180 additions & 0 deletions tests/test_frompy_intopy_roundtrip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#![cfg(feature = "macros")]

use pyo3::types::{PyDict, PyString};
use pyo3::{prelude::*, IntoPyObject};
use std::collections::HashMap;
use std::hash::Hash;

#[macro_use]
#[path = "../src/tests/common.rs"]
mod common;

#[derive(Debug, Clone, IntoPyObject, FromPyObject)]
pub struct A<'py> {
#[pyo3(item)]
s: String,
#[pyo3(item)]
t: Bound<'py, PyString>,
#[pyo3(item("foo"))]
p: Bound<'py, PyAny>,
}

#[test]
fn test_named_fields_struct() {
Python::with_gil(|py| {
let a = A {
s: "Hello".into(),
t: PyString::new(py, "World"),
p: 42i32.into_pyobject(py).unwrap().into_any(),
};
let pya = a.clone().into_pyobject(py).unwrap();
let new_a = pya.extract::<A<'_>>().unwrap();

assert_eq!(a.s, new_a.s);
assert_eq!(a.t.to_str().unwrap(), new_a.t.to_str().unwrap());
assert_eq!(
a.p.extract::<i32>().unwrap(),
new_a.p.extract::<i32>().unwrap()
);
});
}

#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)]
#[pyo3(transparent)]
pub struct B {
test: String,
}

#[test]
fn test_transparent_named_field_struct() {
Python::with_gil(|py| {
let b = B {
test: "test".into(),
};
let pyb = b.clone().into_pyobject(py).unwrap();
let new_b = pyb.extract::<B>().unwrap();
assert_eq!(b, new_b);
});
}

#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)]
#[pyo3(transparent)]
pub struct D<T> {
test: T,
}

#[test]
fn test_generic_transparent_named_field_struct() {
Python::with_gil(|py| {
let d = D {
test: String::from("test"),
};
let pyd = d.clone().into_pyobject(py).unwrap();
let new_d = pyd.extract::<D<String>>().unwrap();
assert_eq!(d, new_d);

let d = D { test: 1usize };
let pyd = d.clone().into_pyobject(py).unwrap();
let new_d = pyd.extract::<D<usize>>().unwrap();
assert_eq!(d, new_d);
});
}

#[derive(Debug, IntoPyObject, FromPyObject)]
pub struct GenericWithBound<K: Hash + Eq, V>(HashMap<K, V>);

#[test]
fn test_generic_with_bound() {
Python::with_gil(|py| {
let mut hash_map = HashMap::<String, i32>::new();
hash_map.insert("1".into(), 1);
hash_map.insert("2".into(), 2);
let map = GenericWithBound(hash_map).into_pyobject(py).unwrap();
assert_eq!(map.len(), 2);
assert_eq!(
map.get_item("1")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
assert_eq!(
map.get_item("2")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
assert!(map.get_item("3").unwrap().is_none());
});
}

#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)]
pub struct Tuple(String, usize);

#[test]
fn test_tuple_struct() {
Python::with_gil(|py| {
let tup = Tuple(String::from("test"), 1);
let tuple = tup.clone().into_pyobject(py).unwrap();
let new_tup = tuple.extract::<Tuple>().unwrap();
assert_eq!(tup, new_tup);
});
}

#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)]
pub struct TransparentTuple(String);

#[test]
fn test_transparent_tuple_struct() {
Python::with_gil(|py| {
let tup = TransparentTuple(String::from("test"));
let tuple = tup.clone().into_pyobject(py).unwrap();
let new_tup = tuple.extract::<TransparentTuple>().unwrap();
assert_eq!(tup, new_tup);
});
}

#[derive(Debug, Clone, PartialEq, IntoPyObject, FromPyObject)]
pub enum Foo {
TupleVar(usize, String),
StructVar {
#[pyo3(item)]
test: char,
},
#[pyo3(transparent)]
TransparentTuple(usize),
#[pyo3(transparent)]
TransparentStructVar {
a: Option<String>,
},
}

#[test]
fn test_enum() {
Python::with_gil(|py| {
let tuple_var = Foo::TupleVar(1, "test".into());
let foo = tuple_var.clone().into_pyobject(py).unwrap();
assert_eq!(tuple_var, foo.extract::<Foo>().unwrap());

let struct_var = Foo::StructVar { test: 'b' };
let foo = struct_var
.clone()
.into_pyobject(py)
.unwrap()
.downcast_into::<PyDict>()
.unwrap();

assert_eq!(struct_var, foo.extract::<Foo>().unwrap());

let transparent_tuple = Foo::TransparentTuple(1);
let foo = transparent_tuple.clone().into_pyobject(py).unwrap();
assert_eq!(transparent_tuple, foo.extract::<Foo>().unwrap());

let transparent_struct_var = Foo::TransparentStructVar { a: None };
let foo = transparent_struct_var.clone().into_pyobject(py).unwrap();
assert_eq!(transparent_struct_var, foo.extract::<Foo>().unwrap());
});
}
19 changes: 19 additions & 0 deletions tests/ui/invalid_intopy_derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,23 @@ enum UnitEnum {
Unit,
}

#[derive(IntoPyObject)]
struct TupleAttribute(#[pyo3(attribute)] String, usize);

#[derive(IntoPyObject)]
struct TupleItem(#[pyo3(item)] String, usize);

#[derive(IntoPyObject)]
struct StructAttribute {
#[pyo3(attribute)]
foo: String,
}

#[derive(IntoPyObject)]
#[pyo3(transparent)]
struct StructTransparentItem {
#[pyo3(item)]
foo: String,
}

fn main() {}
Loading

0 comments on commit 49d9ce3

Please sign in to comment.