Skip to content

Commit

Permalink
WIP: Union args.
Browse files Browse the repository at this point in the history
Add union proc-macro to derive extraction for wrapper enums.
  • Loading branch information
sebpuetz committed Jul 23, 2020
1 parent b05eb48 commit 6456f88
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 3 deletions.
16 changes: 16 additions & 0 deletions examples/union/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
authors = ["PyO3 Authors"]
name = "union"
version = "0.1.0"
description = ""
edition = "2018"

[dependencies]

[dependencies.pyo3]
path = "../../"
features = ["extension-module"]

[lib]
name = "union"
crate-type = ["cdylib"]
25 changes: 25 additions & 0 deletions examples/union/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use pyo3::prelude::*;
use pyo3::types::{PyLong, PyString};
use pyo3::wrap_pyfunction;

#[pyfunction]
pub fn foo<'a>(inp: Union<'a>) {
match inp {
Union::Str(s) => println!("{}", s.to_string_lossy()),
Union::Int(i) => println!("{}", i.repr().unwrap()),
Union::StringList(s_list) => println!("{:?}", s_list),
}
}

#[union]
pub enum Union<'a> {
Str(&'a PyString),
Int(&'a PyLong),
StringList(Vec<String>),
}

#[pymodule]
pub fn union(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(foo))?;
Ok(())
}
2 changes: 2 additions & 0 deletions pyo3-derive-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ mod pyfunction;
mod pyimpl;
mod pymethod;
mod pyproto;
mod union;
mod utils;

pub use module::{add_fn_to_module, process_functions_in_module, py_init};
pub use pyclass::{build_py_class, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionAttr};
pub use pyimpl::{build_py_methods, impl_methods};
pub use pyproto::build_py_proto;
pub use union::build_wrapper_enum;
pub use utils::get_doc;
70 changes: 70 additions & 0 deletions pyo3-derive-backend/src/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::spanned::Spanned;
use syn::{Fields, ItemEnum, Type};

pub fn build_wrapper_enum(tokens: &mut ItemEnum) -> syn::Result<TokenStream> {
let ident = Ident::new(&tokens.ident.to_string(), tokens.ident.span());
let mut var_extracts = Vec::new();
let mut union_types = String::new();
for (i, var) in tokens.variants.iter().enumerate() {
// TODO allow rename for err-msg
union_types.push_str(&var.ident.to_string());
// TODO support named fields
match var.fields {
Fields::Unnamed(_) => (),
_ => {
return Err(syn::Error::new(
var.span(),
"Currently only NewType variants allowed.",
))
}
}
let var_ident = &var.ident;
// TODO allow variants with multiple fields
if var.fields.len() != 1 {
return Err(syn::Error::new(
var.span(),
"Currently only NewType variants allowed",
));
}
let ty: &Type = &var.fields.iter().next().unwrap().ty;
if let Type::Reference(ty_ref) = ty {
let elem = ty_ref.elem.as_ref();
// TODO hard-coded ob
// TODO #ident::#var_ident seems wrong
var_extracts.push(quote!(
if let Ok(ob) = #elem::try_from(ob) {
return Ok(#ident::#var_ident(ob));
}
));
} else if let Type::Path(_) = ty {
var_extracts.push(quote!(
if let Ok(ob) = ::pyo3::FromPyObject::extract(ob) {
return Ok(#ident::#var_ident(ob));
}
))
} else {
return Err(syn::Error::new(ty.span(), "Expected reference"));
}

if i != tokens.variants.len() - 1 {
union_types.push_str(", ")
}
}
let union = if tokens.variants.len() > 1 {
format!("Union[{}]", union_types)
} else {
union_types
};
Ok(quote!(
impl<'source> ::pyo3::FromPyObject<'source> for #ident<'source> {
fn extract(ob: &'source ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
#(#var_extracts);*;
let type_name = ob.get_type().name();
let err_msg = format!("Can't convert {} to {}", type_name, #union);
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
}
}
))
}
15 changes: 13 additions & 2 deletions pyo3cls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use pyo3_derive_backend::{
build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc,
process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr,
build_py_class, build_py_function, build_py_methods, build_py_proto, build_wrapper_enum,
get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr,
};
use quote::quote;
use syn::parse_macro_input;
Expand Down Expand Up @@ -91,3 +91,14 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
)
.into()
}

#[proc_macro_attribute]
pub fn union(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(item as syn::ItemEnum);
let expanded = build_wrapper_enum(&mut ast).unwrap_or_else(|e| e.to_compile_error());
quote!(
#ast
#expanded
)
.into()
}
2 changes: 1 addition & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ pub use crate::{FromPy, FromPyObject, IntoPy, IntoPyPointer, PyTryFrom, PyTryInt
// PyModule is only part of the prelude because we need it for the pymodule function
pub use crate::types::{PyAny, PyModule};
#[cfg(feature = "macros")]
pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto};
pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto, union};

0 comments on commit 6456f88

Please sign in to comment.