Skip to content

Commit

Permalink
Merge pull request #9 from mtreinish/in_out_edges_method
Browse files Browse the repository at this point in the history
Add extra methods to PyDAG class
  • Loading branch information
mtreinish authored Jan 27, 2020
2 parents 1f073e2 + 896219e commit 29418f7
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 15 deletions.
66 changes: 52 additions & 14 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,18 @@ retworkx API
.. py:class:: PyDAG
A class for creating direct acyclic graphs.

The PyDAG class is constructed using the Rust library `daggy`_ which is
itself built on the Rust library `petgraph`_. The limitations and quirks
with both libraries dictate how this operates. The biggest thing to be
aware of when using the PyDAG class is that while node and edge indexes
are used for accessing elements on the DAG, node removal can change the
index of a node `petgraph`_. The limitations and quirks
with both libraries dictate how this operates. The biggest thing to be
aware of when using the PyDAG class is that while node and edge indexes
are used for accessing elements on the DAG, node removal can change the
indexes of nodes. Basically when a node in the middle of the dag is
removed the last index is moved to fill that spot. This means either
you have to track that event, or on node removal update the indexes for
the nodes you care about.
The PyDAG class is constructed using the Rust library `petgraph`_ around
the ``StableGraph`` type. The limitations and quirks with this library and
type dictate how this operates. The biggest thing to be aware of when using
the PyDAG class is that an integer node and edge index is used for accessing
elements on the DAG, not the data/weight of nodes and edges.

.. py:method:: __init__(self):
Initialize an empty DAG.

.. py:method:: __len__(self):
Return the number of nodes in the graph. Use via ``len()`` function

.. py:method:: edges(self):
Return a list of all edge data.

Expand Down Expand Up @@ -49,6 +44,14 @@ retworkx API
:returns: A list of the node data for all the parent neighbor nodes
:rtype: list

.. py:method:: get_node_data(self, node):
Return the node data for a given node index

:param int node: The index for the node

:returns: The data object set for that node
:raises IndexError: when an invalid node index is provided

.. py:method:: get_edge_data(self, node_a, node_b):
Return the edge data for the edge between 2 nodes.

Expand Down Expand Up @@ -158,6 +161,34 @@ retworkx API
:raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be
found to a neighbor node

.. py:method:: in_edges(self, node):
Get the index and edge data for all parents of a node.

This will return a list of tuples with the parent index the node index
and the edge data. This can be used to recreate add_edge() calls.

:param int node: The index of the node to get the edges for

:returns in_edges: A list of tuples of the form:
(parent_index, node_index, edge_data)
:rtype: list
:raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be
found to a neighbor node

.. py:method:: out_edges(self, node):
Get the index and edge data for all children of a node.

This will return a list of tuples with the child index the node index
and the edge data. This can be used to recreate add_edge() calls.

:param int node: The index of the node to get the edges for

:returns out_edges: A list of tuples of the form:
(node_index, child_index, edge_data)
:rtype: list
:raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be
found to a neighbor node

.. py:method:: in_degree(self, node):
Get the degree of a node for inbound edges.

Expand All @@ -166,6 +197,14 @@ retworkx API
:returns degree: The inbound degree for the specified node
:rtype: int

.. py:method:: out_degree(self, node):
Get the degree of a node for outbound edges.

:param int node: The index of the node to find the outbound degree of

:returns degree: The outbound degree for the specified node
:rtype: int

.. py:method:: remove_edge(self, parent, child):
Remove an edge between 2 nodes.

Expand All @@ -183,7 +222,6 @@ retworkx API

:param int edge: The index of the edge to remove

.. _daggy: https://github.com/mitchmindtree/daggy
.. _petgraph: https://github.com/bluss/petgraph

.. py:function:: dag_longest_path_length(graph):
Expand Down
71 changes: 70 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ use std::collections::HashMap;
use std::iter;
use std::ops::{Index, IndexMut};

use pyo3::class::PyMappingProtocol;
use pyo3::create_exception;
use pyo3::exceptions::Exception;
use pyo3::exceptions::{Exception, IndexError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::wrap_pyfunction;
Expand Down Expand Up @@ -294,6 +295,15 @@ impl PyDAG {
Ok(data)
}

pub fn get_node_data(&self, node: usize) -> PyResult<&PyObject> {
let index = NodeIndex::new(node);
let node = match self.graph.node_weight(index) {
Some(node) => node,
None => return Err(IndexError::py_err("No node found for index")),
};
Ok(node)
}

pub fn get_all_edge_data(
&self,
py: Python,
Expand Down Expand Up @@ -451,6 +461,51 @@ impl PyDAG {
}
Ok(out_dict.into())
}

pub fn in_edges(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
let index = NodeIndex::new(node);
let dir = petgraph::Direction::Incoming;
let neighbors = self.graph.neighbors_directed(index, dir);
let mut out_list: Vec<PyObject> = Vec::new();
for neighbor in neighbors {
let edge = match self.graph.find_edge(neighbor, index) {
Some(edge) => edge,
None => {
return Err(NoEdgeBetweenNodes::py_err(
"No edge found between nodes",
))
}
};
let edge_w = self.graph.edge_weight(edge);
let triplet =
(neighbor.index(), node, edge_w.unwrap()).to_object(py);
out_list.push(triplet)
}
Ok(PyList::new(py, out_list).into())
}

pub fn out_edges(&mut self, py: Python, node: usize) -> PyResult<PyObject> {
let index = NodeIndex::new(node);
let dir = petgraph::Direction::Outgoing;
let neighbors = self.graph.neighbors_directed(index, dir);
let mut out_list: Vec<PyObject> = Vec::new();
for neighbor in neighbors {
let edge = match self.graph.find_edge(index, neighbor) {
Some(edge) => edge,
None => {
return Err(NoEdgeBetweenNodes::py_err(
"No edge found between nodes",
))
}
};
let edge_w = self.graph.edge_weight(edge);
let triplet =
(node, neighbor.index(), edge_w.unwrap()).to_object(py);
out_list.push(triplet)
}
Ok(PyList::new(py, out_list).into())
}

// pub fn add_nodes_from(&self) -> PyResult<()> {
//
// }
Expand All @@ -466,6 +521,20 @@ impl PyDAG {
let neighbors = self.graph.neighbors_directed(index, dir);
neighbors.count()
}

pub fn out_degree(&self, node: usize) -> usize {
let index = NodeIndex::new(node);
let dir = petgraph::Direction::Outgoing;
let neighbors = self.graph.neighbors_directed(index, dir);
neighbors.count()
}
}

#[pyproto]
impl PyMappingProtocol for PyDAG {
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
}
}

fn must_check_for_cycle(dag: &PyDAG, a: NodeIndex, b: NodeIndex) -> bool {
Expand Down
33 changes: 33 additions & 0 deletions tests/test_adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def test_neighbor_dir_surrounded(self):
res = dag.adj_direction(node_b, True)
self.assertEqual({node_a: {'a': 1}}, res)

def test_single_neighbor_dir_out_edges(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', {'a': 1})
node_c = dag.add_child(node_a, 'c', {'a': 2})
res = dag.out_edges(node_a)
self.assertEqual([(node_a, node_c, {'a': 2}),
(node_a, node_b, {'a': 1})], res)

def test_neighbor_dir_surrounded_in_out_edges(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', {'a': 1})
node_c = dag.add_child(node_b, 'c', {'a': 2})
res = dag.out_edges(node_b)
self.assertEqual([(node_b, node_c, {'a': 2})], res)
res = dag.in_edges(node_b)
self.assertEqual([(node_a, node_b, {'a': 1})], res)

def test_no_neighbor(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
Expand All @@ -62,3 +81,17 @@ def test_in_direction_none(self):
for i in range(5):
dag.add_child(node_a, i, None)
self.assertEqual(0, dag.in_degree(node_a))

def test_out_direction(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
for i in range(5):
dag.add_parent(node_a, i, None)
self.assertEqual(0, dag.out_degree(node_a))

def test_out_direction_none(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
for i in range(5):
dag.add_child(node_a, i, None)
self.assertEqual(5, dag.out_degree(node_a))
22 changes: 22 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,25 @@ def test_topo_sort(self):
dag.add_parent(3, 'A parent', None)
res = retworkx.topological_sort(dag)
self.assertEqual([6, 0, 5, 4, 3, 2, 1], res)

def test_get_node_data(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
self.assertEqual('b', dag.get_node_data(node_b))

def test_get_node_data_bad_index(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
self.assertRaises(IndexError, dag.get_node_data, 42)

def test_pydag_length(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
node_b = dag.add_child(node_a, 'b', "Edgy")
self.assertEqual(2, len(dag))

def test_pydag_length_empty(self):
dag = retworkx.PyDAG()
self.assertEqual(0, len(dag))

0 comments on commit 29418f7

Please sign in to comment.