Skip to content

Commit

Permalink
Merge pull request #1641 from 1tgr/for-each
Browse files Browse the repository at this point in the history
Simplify code generated for for_each_method_def and for_each_proto_slot
  • Loading branch information
davidhewitt authored May 29, 2021
2 parents 5446fe2 + 1ba3217 commit 8bf3ade
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 80 deletions.
41 changes: 19 additions & 22 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -757,18 +757,17 @@ impl pyo3::class::impl_::PyClassImpl for MyClass {
type BaseType = PyAny;
type ThreadChecker = pyo3::class::impl_::ThreadCheckerStub<MyClass>;

fn for_each_method_def(visitor: &mut dyn FnMut(&pyo3::class::PyMethodDefType)) {
fn for_each_method_def(visitor: &mut dyn FnMut(&[pyo3::class::PyMethodDefType])) {
use pyo3::class::impl_::*;
let collector = PyClassImplCollector::<MyClass>::new();
collector.py_methods().iter()
.chain(collector.py_class_descriptors())
.chain(collector.object_protocol_methods())
.chain(collector.async_protocol_methods())
.chain(collector.context_protocol_methods())
.chain(collector.descr_protocol_methods())
.chain(collector.mapping_protocol_methods())
.chain(collector.number_protocol_methods())
.for_each(visitor)
visitor(collector.py_methods());
visitor(collector.py_class_descriptors());
visitor(collector.object_protocol_methods());
visitor(collector.async_protocol_methods());
visitor(collector.context_protocol_methods());
visitor(collector.descr_protocol_methods());
visitor(collector.mapping_protocol_methods());
visitor(collector.number_protocol_methods());
}
fn get_new() -> Option<pyo3::ffi::newfunc> {
use pyo3::class::impl_::*;
Expand All @@ -780,21 +779,19 @@ impl pyo3::class::impl_::PyClassImpl for MyClass {
let collector = PyClassImplCollector::<Self>::new();
collector.call_impl()
}
fn for_each_proto_slot(visitor: &mut dyn FnMut(&pyo3::ffi::PyType_Slot)) {
fn for_each_proto_slot(visitor: &mut dyn FnMut(&[pyo3::ffi::PyType_Slot])) {
// Implementation which uses dtolnay specialization to load all slots.
use pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.object_protocol_slots()
.iter()
.chain(collector.number_protocol_slots())
.chain(collector.iter_protocol_slots())
.chain(collector.gc_protocol_slots())
.chain(collector.descr_protocol_slots())
.chain(collector.mapping_protocol_slots())
.chain(collector.sequence_protocol_slots())
.chain(collector.async_protocol_slots())
.chain(collector.buffer_protocol_slots())
.for_each(visitor);
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
visitor(collector.gc_protocol_slots());
visitor(collector.descr_protocol_slots());
visitor(collector.mapping_protocol_slots());
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
}

fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> {
Expand Down
51 changes: 24 additions & 27 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,14 @@ fn impl_class(
quote! {}
};

let (impl_inventory, iter_py_methods) = match methods_type {
PyClassMethodsType::Specialization => (None, quote! { collector.py_methods().iter() }),
let (impl_inventory, for_each_py_method) = match methods_type {
PyClassMethodsType::Specialization => (None, quote! { visitor(collector.py_methods()); }),
PyClassMethodsType::Inventory => (
Some(impl_methods_inventory(&cls)),
quote! {
pyo3::inventory::iter::<<Self as pyo3::class::impl_::HasMethodsInventory>::Methods>
.into_iter()
.flat_map(pyo3::class::impl_::PyMethodsInventory::get)
for inventory in pyo3::inventory::iter::<<Self as pyo3::class::impl_::HasMethodsInventory>::Methods>() {
visitor(pyo3::class::impl_::PyMethodsInventory::get(inventory));
}
},
),
};
Expand Down Expand Up @@ -436,18 +436,17 @@ fn impl_class(
type BaseType = #base;
type ThreadChecker = #thread_checker;

fn for_each_method_def(visitor: &mut dyn FnMut(&pyo3::class::PyMethodDefType)) {
fn for_each_method_def(visitor: &mut dyn FnMut(&[pyo3::class::PyMethodDefType])) {
use pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
#iter_py_methods
.chain(collector.py_class_descriptors())
.chain(collector.object_protocol_methods())
.chain(collector.async_protocol_methods())
.chain(collector.context_protocol_methods())
.chain(collector.descr_protocol_methods())
.chain(collector.mapping_protocol_methods())
.chain(collector.number_protocol_methods())
.for_each(visitor)
#for_each_py_method;
visitor(collector.py_class_descriptors());
visitor(collector.object_protocol_methods());
visitor(collector.async_protocol_methods());
visitor(collector.context_protocol_methods());
visitor(collector.descr_protocol_methods());
visitor(collector.mapping_protocol_methods());
visitor(collector.number_protocol_methods());
}
fn get_new() -> Option<pyo3::ffi::newfunc> {
use pyo3::class::impl_::*;
Expand All @@ -460,21 +459,19 @@ fn impl_class(
collector.call_impl()
}

fn for_each_proto_slot(visitor: &mut dyn FnMut(&pyo3::ffi::PyType_Slot)) {
fn for_each_proto_slot(visitor: &mut dyn FnMut(&[pyo3::ffi::PyType_Slot])) {
// Implementation which uses dtolnay specialization to load all slots.
use pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.object_protocol_slots()
.iter()
.chain(collector.number_protocol_slots())
.chain(collector.iter_protocol_slots())
.chain(collector.gc_protocol_slots())
.chain(collector.descr_protocol_slots())
.chain(collector.mapping_protocol_slots())
.chain(collector.sequence_protocol_slots())
.chain(collector.async_protocol_slots())
.chain(collector.buffer_protocol_slots())
.for_each(visitor);
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
visitor(collector.gc_protocol_slots());
visitor(collector.descr_protocol_slots());
visitor(collector.mapping_protocol_slots());
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
}

fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> {
Expand Down
4 changes: 2 additions & 2 deletions src/class/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ pub trait PyClassImpl: Sized {
/// can be accessed by multiple threads by `threading` module.
type ThreadChecker: PyClassThreadChecker<Self>;

fn for_each_method_def(_visitor: &mut dyn FnMut(&PyMethodDefType)) {}
fn for_each_method_def(_visitor: &mut dyn FnMut(&[PyMethodDefType])) {}
fn get_new() -> Option<ffi::newfunc> {
None
}
fn get_call() -> Option<ffi::PyCFunctionWithKeywords> {
None
}
fn for_each_proto_slot(_visitor: &mut dyn FnMut(&ffi::PyType_Slot)) {}
fn for_each_proto_slot(_visitor: &mut dyn FnMut(&[ffi::PyType_Slot])) {}
fn get_buffer() -> Option<&'static PyBufferProcs> {
None
}
Expand Down
45 changes: 25 additions & 20 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,11 @@ where

// protocol methods
let mut has_gc_methods = false;
T::for_each_proto_slot(&mut |slot| {
has_gc_methods |= slot.slot == ffi::Py_tp_clear;
has_gc_methods |= slot.slot == ffi::Py_tp_traverse;
slots.0.push(*slot);
T::for_each_proto_slot(&mut |proto_slots| {
has_gc_methods |= proto_slots
.iter()
.any(|slot| slot.slot == ffi::Py_tp_clear || slot.slot == ffi::Py_tp_traverse);
slots.0.extend_from_slice(proto_slots);
});

slots.push(0, ptr::null_mut());
Expand Down Expand Up @@ -312,17 +313,17 @@ fn py_class_flags(has_gc_methods: bool, is_gc: bool, is_basetype: bool) -> c_uin
}

fn py_class_method_defs(
for_each_method_def: &dyn Fn(&mut dyn FnMut(&PyMethodDefType)),
for_each_method_def: &dyn Fn(&mut dyn FnMut(&[PyMethodDefType])),
) -> Vec<ffi::PyMethodDef> {
let mut defs = Vec::new();

for_each_method_def(&mut |def| match def {
PyMethodDefType::Method(def)
| PyMethodDefType::Class(def)
| PyMethodDefType::Static(def) => {
defs.push(def.as_method_def().unwrap());
}
_ => (),
for_each_method_def(&mut |method_defs| {
defs.extend(method_defs.iter().filter_map(|def| match def {
PyMethodDefType::Method(def)
| PyMethodDefType::Class(def)
| PyMethodDefType::Static(def) => Some(def.as_method_def().unwrap()),
_ => None,
}));
});

if !defs.is_empty() {
Expand Down Expand Up @@ -385,18 +386,22 @@ const PY_GET_SET_DEF_INIT: ffi::PyGetSetDef = ffi::PyGetSetDef {
#[allow(clippy::collapsible_if)] // for if cfg!
fn py_class_properties(
is_dummy: bool,
for_each_method_def: &dyn Fn(&mut dyn FnMut(&PyMethodDefType)),
for_each_method_def: &dyn Fn(&mut dyn FnMut(&[PyMethodDefType])),
) -> Vec<ffi::PyGetSetDef> {
let mut defs = std::collections::HashMap::new();

for_each_method_def(&mut |def| match def {
PyMethodDefType::Getter(getter) => {
getter.copy_to(defs.entry(getter.name).or_insert(PY_GET_SET_DEF_INIT));
}
PyMethodDefType::Setter(setter) => {
setter.copy_to(defs.entry(setter.name).or_insert(PY_GET_SET_DEF_INIT));
for_each_method_def(&mut |method_defs| {
for def in method_defs {
match def {
PyMethodDefType::Getter(getter) => {
getter.copy_to(defs.entry(getter.name).or_insert(PY_GET_SET_DEF_INIT));
}
PyMethodDefType::Setter(setter) => {
setter.copy_to(defs.entry(setter.name).or_insert(PY_GET_SET_DEF_INIT));
}
_ => (),
}
}
_ => (),
});

let mut props: Vec<_> = defs.values().cloned().collect();
Expand Down
22 changes: 13 additions & 9 deletions src/type_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl LazyStaticType {
py: Python,
type_object: *mut ffi::PyTypeObject,
name: &str,
for_each_method_def: &dyn Fn(&mut dyn FnMut(&PyMethodDefType)),
for_each_method_def: &dyn Fn(&mut dyn FnMut(&[PyMethodDefType])),
) {
// We might want to fill the `tp_dict` with python instances of `T`
// itself. In order to do so, we must first initialize the type object
Expand Down Expand Up @@ -147,17 +147,21 @@ impl LazyStaticType {
// means that another thread can continue the initialization in the
// meantime: at worst, we'll just make a useless computation.
let mut items = vec![];
for_each_method_def(&mut |def| {
if let PyMethodDefType::ClassAttribute(attr) = def {
items.push((
extract_cstr_or_leak_cstring(
for_each_method_def(&mut |method_defs| {
items.extend(method_defs.iter().filter_map(|def| {
if let PyMethodDefType::ClassAttribute(attr) = def {
let key = extract_cstr_or_leak_cstring(
attr.name,
"class attribute name cannot contain nul bytes",
)
.unwrap(),
(attr.meth.0)(py),
));
}
.unwrap();

let val = (attr.meth.0)(py);
Some((key, val))
} else {
None
}
}));
});

// Now we hold the GIL and we can assume it won't be released until we
Expand Down

0 comments on commit 8bf3ade

Please sign in to comment.