Skip to content

Commit

Permalink
Add option to specify preset colors to graph_greedy_color (#1060)
Browse files Browse the repository at this point in the history
* Add option to specify preset colors to graph_greedy_color

This commit adds a new argument to the graph_greed_color function, which
enables a user to provide a callback function that specifies a color to
use for particular nodes.

* Fix type hint stub
  • Loading branch information
mtreinish authored Jan 21, 2024
1 parent ddb0cda commit e39ecc6
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
---
features:
- |
Added a new keyword argument, ``preset_color_fn``, to :func:`.graph_greedy_color`
which is used to provide preset colors for specific nodes when computing the graph
coloring. You can optionally pass a callable to that argument which will
be passed node index from the graph and is either expected to return an
integer color to use for that node, or `None` to indicate there is no
preset color for that node. For example:
.. jupyter-execute::
import rustworkx as rx
from rustworkx.visualization import mpl_draw
graph = rx.generators.generalized_petersen_graph(5, 2)
def preset_colors(node_index):
if node_index == 0:
return 3
coloring = rx.graph_greedy_color(graph, preset_color_fn=preset_colors)
colors = [coloring[node] for node in graph.node_indices()]
layout = rx.shell_layout(graph, nlist=[[0, 1, 2, 3, 4],[6, 7, 8, 9, 5]])
mpl_draw(graph, node_color=colors, pos=layout)
- |
Added a new function ``greedy_node_color_with_preset_colors`` to the
rustworkx-core module ``coloring``. This new function is identical to the
``rustworkx_core::coloring::greedy_node_color`` except it has a second
preset parameter which is passed a callable which is used to provide preset
colors for particular node ids.
132 changes: 103 additions & 29 deletions rustworkx-core/src/coloring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// under the License.

use std::cmp::Reverse;
use std::convert::Infallible;
use std::hash::Hash;

use crate::dictmap::*;
Expand Down Expand Up @@ -96,6 +97,50 @@ where
Some(colors)
}

fn inner_greedy_node_color<G, F, E>(
graph: G,
mut preset_color_fn: F,
) -> Result<DictMap<G::NodeId, usize>, E>
where
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
G::NodeId: Hash + Eq + Send + Sync,
F: FnMut(G::NodeId) -> Result<Option<usize>, E>,
{
let mut colors: DictMap<G::NodeId, usize> = DictMap::with_capacity(graph.node_count());
let mut node_vec: Vec<G::NodeId> = Vec::with_capacity(graph.node_count());
let mut sort_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(graph.node_count());
for k in graph.node_identifiers() {
if let Some(color) = preset_color_fn(k)? {
colors.insert(k, color);
continue;
}
node_vec.push(k);
sort_map.insert(k, graph.edges(k).count());
}
node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k)));

for node in node_vec {
let mut neighbor_colors: HashSet<usize> = HashSet::new();
for edge in graph.edges(node) {
let target = edge.target();
let existing_color = match colors.get(&target) {
Some(color) => color,
None => continue,
};
neighbor_colors.insert(*existing_color);
}
let mut current_color: usize = 0;
loop {
if !neighbor_colors.contains(&current_color) {
break;
}
current_color += 1;
}
colors.insert(node, current_color);
}
Ok(colors)
}

/// Color a graph using a greedy graph coloring algorithm.
///
/// This function uses a `largest-first` strategy as described in:
Expand Down Expand Up @@ -135,36 +180,65 @@ where
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
G::NodeId: Hash + Eq + Send + Sync,
{
let mut colors: DictMap<G::NodeId, usize> = DictMap::with_capacity(graph.node_count());
let mut node_vec: Vec<G::NodeId> = graph.node_identifiers().collect();

let mut sort_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(graph.node_count());
for k in node_vec.iter() {
sort_map.insert(*k, graph.edges(*k).count());
}
node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k)));

for node in node_vec {
let mut neighbor_colors: HashSet<usize> = HashSet::new();
for edge in graph.edges(node) {
let target = edge.target();
let existing_color = match colors.get(&target) {
Some(color) => color,
None => continue,
};
neighbor_colors.insert(*existing_color);
}
let mut current_color: usize = 0;
loop {
if !neighbor_colors.contains(&current_color) {
break;
}
current_color += 1;
}
colors.insert(node, current_color);
}
inner_greedy_node_color(graph, |_| Ok::<Option<usize>, Infallible>(None)).unwrap()
}

colors
/// Color a graph using a greedy graph coloring algorithm with preset colors
///
/// This function uses a `largest-first` strategy as described in:
///
/// Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs,
/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4.
///
/// to color the nodes with higher degree first.
///
/// The coloring problem is NP-hard and this is a heuristic algorithm
/// which may not return an optimal solution.
///
/// Arguments:
///
/// * `graph` - The graph object to run the algorithm on
/// * `preset_color_fn` - A callback function that will recieve the node identifier
/// for each node in the graph and is expected to return an `Option<usize>`
/// (wrapped in a `Result`) that is `None` if the node has no preset and
/// the usize represents the preset color.
///
/// # Example
/// ```rust
///
/// use petgraph::graph::Graph;
/// use petgraph::graph::NodeIndex;
/// use petgraph::Undirected;
/// use rustworkx_core::dictmap::*;
/// use std::convert::Infallible;
/// use rustworkx_core::coloring::greedy_node_color_with_preset_colors;
///
/// let preset_color_fn = |node_idx: NodeIndex| -> Result<Option<usize>, Infallible> {
/// if node_idx.index() == 0 {
/// Ok(Some(1))
/// } else {
/// Ok(None)
/// }
/// };
///
/// let g = Graph::<(), (), Undirected>::from_edges(&[(0, 1), (0, 2)]);
/// let colors = greedy_node_color_with_preset_colors(&g, preset_color_fn).unwrap();
/// let mut expected_colors = DictMap::new();
/// expected_colors.insert(NodeIndex::new(0), 1);
/// expected_colors.insert(NodeIndex::new(1), 0);
/// expected_colors.insert(NodeIndex::new(2), 0);
/// assert_eq!(colors, expected_colors);
/// ```
pub fn greedy_node_color_with_preset_colors<G, F, E>(
graph: G,
preset_color_fn: F,
) -> Result<DictMap<G::NodeId, usize>, E>
where
G: NodeCount + IntoNodeIdentifiers + IntoEdges,
G::NodeId: Hash + Eq + Send + Sync,
F: FnMut(G::NodeId) -> Result<Option<usize>, E>,
{
inner_greedy_node_color(graph, preset_color_fn)
}

/// Color edges of a graph using a greedy approach.
Expand Down
4 changes: 3 additions & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def graph_katz_centrality(

# Coloring

def graph_greedy_color(graph: PyGraph, /) -> dict[int, int]: ...
def graph_greedy_color(
graph: PyGraph, /, preset_color_fn: Callable[[int], int | None] | None = ...
) -> dict[int, int]: ...
def graph_greedy_edge_color(graph: PyGraph, /) -> dict[int, int]: ...
def graph_is_bipartite(graph: PyGraph) -> bool: ...
def digraph_is_bipartite(graph: PyDiGraph) -> bool: ...
Expand Down
32 changes: 27 additions & 5 deletions src/coloring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
// License for the specific language governing permissions and limitations
// under the License.

use crate::{digraph, graph};
use crate::{digraph, graph, NodeIndex};

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::Python;
use rustworkx_core::coloring::{
greedy_edge_color, greedy_node_color, misra_gries_edge_color, two_color,
greedy_edge_color, greedy_node_color, greedy_node_color_with_preset_colors,
misra_gries_edge_color, two_color,
};

/// Color a :class:`~.PyGraph` object using a greedy graph coloring algorithm.
Expand All @@ -30,6 +31,13 @@ use rustworkx_core::coloring::{
/// may not return an optimal solution.
///
/// :param PyGraph: The input PyGraph object to color
/// :param preset_color_fn: An optional callback function that is used to manually
/// specify a color to use for particular nodes in the graph. If specified
/// this takes a callable that will be passed a node index and is expected to
/// either return an integer representing a color or ``None`` to indicate there
/// is no preset. Note if you do use a callable there is no validation that
/// the preset values are valid colors. You can generate an invalid coloring
/// if you the specified function returned invalid colors for any nodes.
///
/// :returns: A dictionary where keys are node indices and the value is
/// the color
Expand All @@ -52,9 +60,23 @@ use rustworkx_core::coloring::{
/// .. [1] Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs,
/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4.
#[pyfunction]
#[pyo3(text_signature = "(graph, /)")]
pub fn graph_greedy_color(py: Python, graph: &graph::PyGraph) -> PyResult<PyObject> {
let colors = greedy_node_color(&graph.graph);
#[pyo3(text_signature = "(graph, /, preset_color_fn=None)")]
pub fn graph_greedy_color(
py: Python,
graph: &graph::PyGraph,
preset_color_fn: Option<PyObject>,
) -> PyResult<PyObject> {
let colors = match preset_color_fn {
Some(preset_color_fn) => {
let callback = |node_idx: NodeIndex| -> PyResult<Option<usize>> {
preset_color_fn
.call1(py, (node_idx.index(),))
.map(|x| x.extract(py).ok())
};
greedy_node_color_with_preset_colors(&graph.graph, callback)?
}
None => greedy_node_color(&graph.graph),
};
let out_dict = PyDict::new(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
Expand Down
42 changes: 42 additions & 0 deletions tests/graph/test_coloring.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,48 @@ def test_simple_graph_large_degree(self):
res = rustworkx.graph_greedy_color(graph)
self.assertEqual({0: 0, 1: 1, 2: 1}, res)

def test_simple_graph_with_preset(self):
def preset(node_idx):
if node_idx == 0:
return 1
return None

graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, 1)
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, 1)
res = rustworkx.graph_greedy_color(graph, preset)
self.assertEqual({0: 1, 1: 0, 2: 0}, res)

def test_simple_graph_large_degree_with_preset(self):
def preset(node_idx):
if node_idx == 0:
return 1
return None

graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, 1)
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, 1)
graph.add_edge(node_a, node_c, 1)
graph.add_edge(node_a, node_c, 1)
graph.add_edge(node_a, node_c, 1)
graph.add_edge(node_a, node_c, 1)
res = rustworkx.graph_greedy_color(graph, preset)
self.assertEqual({0: 1, 1: 0, 2: 0}, res)

def test_preset_raises_exception(self):
def preset(node_idx):
raise OverflowError("I am invalid")

graph = rustworkx.generators.path_graph(5)
with self.assertRaises(OverflowError):
rustworkx.graph_greedy_color(graph, preset)


class TestGraphEdgeColoring(unittest.TestCase):
def test_graph(self):
Expand Down

0 comments on commit e39ecc6

Please sign in to comment.