Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add manual iterator implementations for custom vec iterables #1107

Merged
merged 6 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: Sequence[_T_co]) -> None: ...
def __array__(self, _dt: np.dtype | None = ...) -> np.ndarray: ...
def __iter__(self) -> Iterator[_T_co]: ...

class _RustworkxCustomHashMapIter(Generic[_S, _T_co], Mapping[_S, _T_co], ABC):
def __init__(self) -> None: ...
Expand Down
75 changes: 72 additions & 3 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
// There are two useful macros to quickly define a new custom return type:
//
// :`custom_vec_iter_impl` holds a `Vec<T>` and can be used as a
// read-only sequence/list. To use it, you should specify the name of the new type,
// the name of the vector that holds the data, the type `T` and a docstring.
// read-only sequence/list. To use it, you should specify the name of the new type for the
// iterable, a name for that new type's iterator, the name of the vector that holds the data, the
// type `T` and a docstring.
//
// e.g `custom_vec_iter_impl!(MyReadOnlyType, MyReadOnlyTypeIter, data, (usize, f64), "Docs");`
// defines a new type named `MyReadOnlyType` that holds a vector called `data`
Expand Down Expand Up @@ -473,7 +474,7 @@ impl PyConvertToPyArray for Vec<(usize, usize, PyObject)> {
}

macro_rules! custom_vec_iter_impl {
($name:ident, $data:ident, $T:ty, $doc:literal) => {
($name:ident, $iter:ident, $data:ident, $T:ty, $doc:literal) => {
#[doc = $doc]
#[pyclass(module = "rustworkx", sequence)]
#[derive(Clone)]
Expand Down Expand Up @@ -580,6 +581,13 @@ macro_rules! custom_vec_iter_impl {
}
}

// Build a fresh iterator over this sequence, starting at the first element.
//
// `self_` is received as an owned `Py<Self>` handle, so it is moved straight into the
// iterator rather than being cloned (which would only add a refcount increment that is
// immediately undone when the original handle is dropped). The `Python` token is kept in
// the signature for compatibility but is no longer needed.
fn __iter__(self_: Py<Self>, _py: Python) -> $iter {
    $iter {
        inner: Some(self_),
        index: 0,
    }
}

fn __array__(&self, py: Python, _dt: Option<&PyArrayDescr>) -> PyResult<PyObject> {
// Note: we accept the dtype argument on the signature but
// effectively do nothing with it to let Numpy handle the conversion itself
Expand All @@ -594,11 +602,66 @@ macro_rules! custom_vec_iter_impl {
PyGCProtocol::__clear__(self)
}
}

#[doc = concat!("Custom iterator class for :class:`.", stringify!($name), "`")]
// No module because this isn't constructable from Python space, and is only exposed as an
// implementation detail.
#[pyclass]
pub struct $iter {
// Handle to the sequence being iterated; set to `None` by `__clear__`.
inner: Option<Py<$name>>,
// Position in the sequence's backing vector of the next element to yield.
index: usize,
}

#[pymethods]
impl $iter {
fn __next__(&mut self, py: Python) -> Option<Py<PyAny>> {
let data = self.inner.as_ref().unwrap().borrow(py);
if self.index < data.$data.len() {
let out = data.$data[self.index].clone().into_py(py);
self.index += 1;
Some(out)
} else {
None
}
}

// Iterator protocol hook: hands back the very same iterator object, so iteration
// continues from the current position rather than restarting.
fn __iter__(self_: Py<Self>) -> Py<Self> {
// Python iterators typically just return themselves from this, though in principle
// we could return a separate object that iterates starting from the same point.
self_
}

fn __len__(&self, py: Python) -> usize {
self.inner
.as_ref()
.unwrap()
.borrow(py)
.$data
.len()
.saturating_sub(self.index)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's a good trick to comply with __length_hint__ not returning a negative. Although as a usize I guess without this in release mode we would just overflow which wouldn't result in a ValueError although the hint would be wildly inaccurate (although PEP424 says the return is "not required to be accurate") for an empty or consumed iterator.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually wrote this initially for __len__ (which is required to be accurate) and reused it for __length_hint__, but then I remembered / realised that iterators don't typically implement __len__ so just renamed the method.

}

// Length hint for the iterator: delegates to `__len__`, i.e. the number of elements
// that have not been yielded yet.
fn __length_hint__(&self, py: Python) -> usize {
self.__len__(py)
}

// Garbage-collector traversal hook: reports the referenced sequence (if still held)
// to the visitor and forwards whatever the visitor returns.
fn __traverse__(&self, vis: PyVisit) -> Result<(), PyTraverseError> {
    match self.inner.as_ref() {
        Some(held) => vis.call(held),
        // Nothing is referenced any more, so there is nothing to visit.
        None => Ok(()),
    }
}

// Garbage-collector clear hook: releases the reference to the underlying sequence,
// leaving `inner` empty.
fn __clear__(&mut self) {
    drop(self.inner.take());
}
}
};
}

custom_vec_iter_impl!(
BFSSuccessors,
BFSSuccessorsIter,
bfs_successors,
(PyObject, Vec<PyObject>),
"A custom class for the return from :func:`rustworkx.bfs_successors`
Expand Down Expand Up @@ -651,6 +714,7 @@ impl PyGCProtocol for BFSSuccessors {

custom_vec_iter_impl!(
BFSPredecessors,
BFSPredecessorsIter,
bfs_predecessors,
(PyObject, Vec<PyObject>),
"A custom class for the return from :func:`rustworkx.bfs_predecessors`
Expand Down Expand Up @@ -703,6 +767,7 @@ impl PyGCProtocol for BFSPredecessors {

custom_vec_iter_impl!(
NodeIndices,
NodeIndicesIter,
nodes,
usize,
"A custom class for the return of node indices
Expand Down Expand Up @@ -735,6 +800,7 @@ impl PyGCProtocol for NodeIndices {}

custom_vec_iter_impl!(
EdgeList,
EdgeListIter,
edges,
(usize, usize),
"A custom class for the return of edge lists
Expand Down Expand Up @@ -773,6 +839,7 @@ impl PyGCProtocol for EdgeList {}

custom_vec_iter_impl!(
WeightedEdgeList,
WeightedEdgeListIter,
edges,
(usize, usize, PyObject),
"A custom class for the return of edge lists with weights
Expand Down Expand Up @@ -823,6 +890,7 @@ impl PyGCProtocol for WeightedEdgeList {

custom_vec_iter_impl!(
EdgeIndices,
EdgeIndicesIter,
edges,
usize,
"A custom class for the return of edge indices
Expand Down Expand Up @@ -875,6 +943,7 @@ impl PyDisplay for EdgeList {

custom_vec_iter_impl!(
Chains,
ChainsIter,
chains,
EdgeList,
"A custom class for the return of a list of list of edges.
Expand Down
Loading