Skip to content

Commit

Permalink
Support Bound in pymodule and pyfunction macros
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored and LilyFoote committed Feb 25, 2024
1 parent e0e3981 commit 031f033
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 46 deletions.
2 changes: 1 addition & 1 deletion guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn parent_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {

fn register_child_module(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> {
let child_module = PyModule::new_bound(py, "child_module")?;
child_module.add_function(&wrap_pyfunction!(func, child_module.as_gil_ref())?.as_borrowed())?;
child_module.add_function(wrap_pyfunction_bound!(func, &child_module)?)?;
parent_module.add_submodule(child_module.as_gil_ref())?;
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion guide/src/python_from_rust.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ fn main() -> PyResult<()> {
Python::with_gil(|py| {
// Create new module
let foo_module = PyModule::new_bound(py, "foo")?;
foo_module.add_function(&wrap_pyfunction!(add_one, foo_module.as_gil_ref())?.as_borrowed())?;
foo_module.add_function(wrap_pyfunction_bound!(add_one, &foo_module)?)?;

// Import and get sys.modules
let sys = PyModule::import_bound(py, "sys")?;
Expand Down
11 changes: 5 additions & 6 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
/// module
pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream> {
let options = PyModuleOptions::from_attrs(&mut function.attrs)?;
process_functions_in_module(&options, &mut function)?;
process_functions_in_module(&mut function)?;
let krate = get_pyo3_crate(&options.krate);
let ident = &function.sig.ident;
let vis = &function.vis;
Expand All @@ -215,13 +215,13 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
use #krate::impl_::pymodule as impl_;

fn __pyo3_pymodule(module: &#krate::Bound<'_, #krate::types::PyModule>) -> #krate::PyResult<()> {
#ident(module.py(), module.as_gil_ref())
#ident(module.py(), ::std::convert::Into::into(impl_::BoundModule(module)))
}

impl #ident::MakeDef {
const fn make_def() -> impl_::ModuleDef {
const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
unsafe {
const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
impl_::ModuleDef::new(
#ident::__PYO3_NAME,
#doc,
Expand Down Expand Up @@ -260,9 +260,8 @@ fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenS
}

/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
fn process_functions_in_module(func: &mut syn::ItemFn) -> Result<()> {
let mut stmts: Vec<syn::Stmt> = Vec::new();
let krate = get_pyo3_crate(&options.krate);

for mut stmt in func.block.stmts.drain(..) {
if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
Expand All @@ -272,7 +271,7 @@ fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn
let name = &func.sig.ident;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
#module_name.add_function(#krate::impl_::pyfunction::_wrap_pyfunction(&#name::DEF, #module_name)?)?;
#module_name.add_function(#module_name.wrap_pyfunction(&#name::DEF)?)?;
};
stmts.extend(statements);
}
Expand Down
6 changes: 3 additions & 3 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,12 @@ pub fn impl_wrap_pyfunction(
#[doc(hidden)]
#vis mod #name {
pub(crate) struct MakeDef;
pub const DEF: #krate::impl_::pyfunction::PyMethodDef = MakeDef::DEF;
pub const DEF: #krate::impl_::pymethods::PyMethodDef = MakeDef::DEF;

pub fn add_to_module(module: &#krate::Bound<'_, #krate::types::PyModule>) -> #krate::PyResult<()> {
use #krate::prelude::PyModuleMethods;
use ::std::convert::Into;
module.add_function(&#krate::types::PyCFunction::internal_new(&DEF, module.as_gil_ref().into())?)
module.add_function(#krate::types::PyCFunction::internal_new(&DEF, module.as_gil_ref().into())?)
}
}

Expand All @@ -284,7 +284,7 @@ pub fn impl_wrap_pyfunction(
const _: () = {
use #krate as _pyo3;
impl #name::MakeDef {
const DEF: #krate::impl_::pyfunction::PyMethodDef = #methoddef;
const DEF: #krate::impl_::pymethods::PyMethodDef = #methoddef;
}

#[allow(non_snake_case)]
Expand Down
44 changes: 43 additions & 1 deletion src/derive_utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Functionality for the code generated by the derive backend
use crate::{types::PyModule, Python};
use crate::impl_::pymethods::PyMethodDef;
use crate::{
types::{PyCFunction, PyModule},
Bound, PyResult, Python,
};

/// Enum to abstract over the arguments of Python function wrappers.
pub enum PyFunctionArguments<'a> {
Expand All @@ -18,6 +22,10 @@ impl<'a> PyFunctionArguments<'a> {
}
}
}

pub fn wrap_pyfunction(self, method_def: &'a PyMethodDef) -> PyResult<&PyCFunction> {
Ok(PyCFunction::internal_new(method_def, self)?.into_gil_ref())
}
}

impl<'a> From<Python<'a>> for PyFunctionArguments<'a> {
Expand All @@ -31,3 +39,37 @@ impl<'a> From<&'a PyModule> for PyFunctionArguments<'a> {
PyFunctionArguments::PyModule(module)
}
}

/// Enum to abstract over the arguments of Python function wrappers.
pub enum PyFunctionArgumentsBound<'a, 'py> {
Python(Python<'py>),
PyModule(&'a Bound<'py, PyModule>),
}

impl<'a, 'py> PyFunctionArgumentsBound<'a, 'py> {
pub fn into_py_and_maybe_module(self) -> (Python<'py>, Option<&'a Bound<'py, PyModule>>) {
match self {
PyFunctionArgumentsBound::Python(py) => (py, None),
PyFunctionArgumentsBound::PyModule(module) => {
let py = module.py();
(py, Some(module))
}
}
}

pub fn wrap_pyfunction(self, method_def: &PyMethodDef) -> PyResult<Bound<'py, PyCFunction>> {
PyCFunction::internal_new_bound(method_def, self)
}
}

impl<'a, 'py> From<Python<'py>> for PyFunctionArgumentsBound<'a, 'py> {
fn from(py: Python<'py>) -> PyFunctionArgumentsBound<'a, 'py> {
PyFunctionArgumentsBound::Python(py)
}
}

impl<'a, 'py> From<&'a Bound<'py, PyModule>> for PyFunctionArgumentsBound<'a, 'py> {
fn from(module: &'a Bound<'py, PyModule>) -> PyFunctionArgumentsBound<'a, 'py> {
PyFunctionArgumentsBound::PyModule(module)
}
}
1 change: 0 additions & 1 deletion src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub(crate) mod not_send;
pub mod panic;
pub mod pycell;
pub mod pyclass;
pub mod pyfunction;
pub mod pymethods;
pub mod pymodule;
#[doc(hidden)]
Expand Down
10 changes: 0 additions & 10 deletions src/impl_/pyfunction.rs

This file was deleted.

23 changes: 23 additions & 0 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ impl<T: PyTypeInfo> PyAddToModule for T {
}
}

pub struct BoundModule<'a, 'py>(pub &'a Bound<'py, PyModule>);

impl<'a> From<BoundModule<'a, 'a>> for &'a PyModule {
#[inline]
fn from(bound: BoundModule<'a, 'a>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a, 'py> From<BoundModule<'a, 'py>> for &'a Bound<'py, PyModule> {
#[inline]
fn from(bound: BoundModule<'a, 'py>) -> Self {
bound.0
}
}

impl From<BoundModule<'_, '_>> for Py<PyModule> {
#[inline]
fn from(bound: BoundModule<'_, '_>) -> Self {
bound.0.clone().unbind()
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};
Expand Down
33 changes: 31 additions & 2 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,42 @@ macro_rules! py_run_impl {
macro_rules! wrap_pyfunction {
($function:path) => {
&|py_or_module| {
use $crate::derive_utils::PyFunctionArguments;
use $function as wrapped_pyfunction;
$crate::impl_::pyfunction::_wrap_pyfunction(&wrapped_pyfunction::DEF, py_or_module)
let function_arguments: PyFunctionArguments<'_> =
::std::convert::Into::into(py_or_module);
function_arguments.wrap_pyfunction(&wrapped_pyfunction::DEF)
}
};
($function:path, $py_or_module:expr) => {{
use $crate::derive_utils::PyFunctionArguments;
use $function as wrapped_pyfunction;
$crate::impl_::pyfunction::_wrap_pyfunction(&wrapped_pyfunction::DEF, $py_or_module)
let function_arguments: PyFunctionArguments<'_> = ::std::convert::Into::into($py_or_module);
function_arguments.wrap_pyfunction(&wrapped_pyfunction::DEF)
}};
}

/// Wraps a Rust function annotated with [`#[pyfunction]`](macro@crate::pyfunction).
///
/// This can be used with [`PyModule::add_function`](crate::types::PyModule::add_function) to add free
/// functions to a [`PyModule`](crate::types::PyModule) - see its documentation for more information.
#[macro_export]
macro_rules! wrap_pyfunction_bound {
($function:path) => {
&|py_or_module| {
use $crate::derive_utils::PyFunctionArgumentsBound;
use $function as wrapped_pyfunction;
let function_arguments: PyFunctionArgumentsBound<'_, '_> =
::std::convert::Into::into($py_or_module);
function_arguments.wrap_pyfunction(&wrapped_pyfunction::DEF)
}
};
($function:path, $py_or_module:expr) => {{
use $crate::derive_utils::PyFunctionArgumentsBound;
use $function as wrapped_pyfunction;
let function_arguments: PyFunctionArgumentsBound<'_, '_> =
::std::convert::Into::into($py_or_module);
function_arguments.wrap_pyfunction(&wrapped_pyfunction::DEF)
}};
}

Expand Down
2 changes: 1 addition & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub use crate::PyNativeType;
pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, FromPyObject};

#[cfg(feature = "macros")]
pub use crate::wrap_pyfunction;
pub use crate::{wrap_pyfunction, wrap_pyfunction_bound};

pub use crate::types::any::PyAnyMethods;
pub use crate::types::boolobject::PyBoolMethods;
Expand Down
64 changes: 53 additions & 11 deletions src/types/function.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::derive_utils::PyFunctionArguments;
use crate::derive_utils::{PyFunctionArguments, PyFunctionArgumentsBound};
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::methods::PyMethodDefDestructor;
use crate::py_result_ext::PyResultExt;
use crate::types::capsule::PyCapsuleMethods;
use crate::types::module::PyModuleMethods;
use crate::{
ffi,
impl_::pymethods::{self, PyMethodDef},
Expand Down Expand Up @@ -33,17 +34,25 @@ impl PyCFunction {
doc: &'static str,
py_or_module: PyFunctionArguments<'a>,
) -> PyResult<&'a Self> {
Self::new_with_keywords_bound(fun, name, doc, py_or_module).map(Bound::into_gil_ref)
Self::internal_new(
&PyMethodDef::cfunction_with_keywords(
name,
pymethods::PyCFunctionWithKeywords(fun),
doc,
),
py_or_module,
)
.map(Bound::into_gil_ref)
}

/// Create a new built-in function with keywords (*args and/or **kwargs).
pub fn new_with_keywords_bound<'a>(
pub fn new_with_keywords_bound<'py>(
fun: ffi::PyCFunctionWithKeywords,
name: &'static str,
doc: &'static str,
py_or_module: PyFunctionArguments<'a>,
) -> PyResult<Bound<'a, Self>> {
Self::internal_new(
py_or_module: PyFunctionArgumentsBound<'_, 'py>,
) -> PyResult<Bound<'py, Self>> {
Self::internal_new_bound(
&PyMethodDef::cfunction_with_keywords(
name,
pymethods::PyCFunctionWithKeywords(fun),
Expand All @@ -67,17 +76,21 @@ impl PyCFunction {
doc: &'static str,
py_or_module: PyFunctionArguments<'a>,
) -> PyResult<&'a Self> {
Self::new_bound(fun, name, doc, py_or_module).map(Bound::into_gil_ref)
Self::internal_new(
&PyMethodDef::noargs(name, pymethods::PyCFunction(fun), doc),
py_or_module,
)
.map(Bound::into_gil_ref)
}

/// Create a new built-in function which takes no arguments.
pub fn new_bound<'a>(
pub fn new_bound<'py>(
fun: ffi::PyCFunction,
name: &'static str,
doc: &'static str,
py_or_module: PyFunctionArguments<'a>,
) -> PyResult<Bound<'a, Self>> {
Self::internal_new(
py_or_module: PyFunctionArgumentsBound<'_, 'py>,
) -> PyResult<Bound<'py, Self>> {
Self::internal_new_bound(
&PyMethodDef::noargs(name, pymethods::PyCFunction(fun), doc),
py_or_module,
)
Expand Down Expand Up @@ -189,6 +202,35 @@ impl PyCFunction {
.downcast_into_unchecked()
}
}

#[doc(hidden)]
pub(crate) fn internal_new_bound<'py>(
method_def: &PyMethodDef,
py_or_module: PyFunctionArgumentsBound<'_, 'py>,
) -> PyResult<Bound<'py, Self>> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let (mod_ptr, module_name): (_, Option<Py<PyString>>) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
(mod_ptr, Some(m.name()?.into_py(py)))
} else {
(std::ptr::null_mut(), None)
};
let (def, destructor) = method_def.as_method_def()?;

// FIXME: stop leaking the def and destructor
let def = Box::into_raw(Box::new(def));
std::mem::forget(destructor);

let module_name_ptr = module_name
.as_ref()
.map_or(std::ptr::null_mut(), Py::as_ptr);

unsafe {
ffi::PyCFunction_NewEx(def, mod_ptr, module_name_ptr)
.assume_owned_or_err(py)
.downcast_into_unchecked()
}
}
}

fn closure_capsule_name() -> &'static CStr {
Expand Down
Loading

0 comments on commit 031f033

Please sign in to comment.