-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create two shortest path options in own module #174
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
"""Utility functions for shortest path algorithms.""" | ||
from __future__ import annotations | ||
|
||
import heapq | ||
from collections import defaultdict | ||
from typing import Hashable | ||
|
||
import networkx as nx | ||
|
||
|
||
def tarjans_pq(G: nx.MultiDiGraph,
               root: int | str,
               weight_attr: str = 'weight') -> nx.MultiDiGraph:
    """Build a directed spanning tree in which every node drains to ``root``.

    Also described as a minimum spanning arborescence rooted at ``root``.

    How the search works: the tree starts as ``root`` alone. A priority queue
    holds every edge ``(u, v)`` whose head ``v`` is already in the tree,
    ordered by weight. The cheapest edge is popped; if its tail ``u`` is not
    in the tree yet, the edge is accepted (``v`` becomes the parent of ``u``)
    and all edges entering ``u`` are queued in turn. The loop ends when the
    queue is empty, i.e. when no further node can reach the tree.

    NOTE(review): this is a greedy, Prim-style growth rather than the
    contraction-based Chu-Liu/Edmonds/Tarjan procedure. On general directed
    graphs a greedy growth is not guaranteed to give the minimum total
    weight - confirm that this is acceptable for the intended graphs.

    Args:
        G (nx.MultiDiGraph): The input graph. It is not modified.
        root (int | str): The root node (i.e., that all vertices in the graph
            should flow to).
        weight_attr (str): The name of the edge attribute containing the edge
            weights. Defaults to 'weight'.

    Returns:
        nx.MultiDiGraph: The directed spanning tree. Each accepted edge
            carries the attributes of the edge that was actually selected
            (relevant when parallel edges have different weights). Nodes
            whose route to the root passes through an edge with
            ``edge_type == 'outlet'`` get an 'outlet' node attribute.

    Raises:
        KeyError: If ``root`` is not a node of ``G`` or an edge lacks
            ``weight_attr``.
        ValueError: If some node of ``G`` has no directed path to ``root``.
    """
    # Work on integer labels so heap entries always compare cleanly, whatever
    # the original labels are. relabel_nodes returns a new graph, so G itself
    # is left untouched.
    new_nodes = {node: i for i, node in enumerate(G.nodes)}
    node_mapping = {i: node for node, i in new_nodes.items()}
    G_ = nx.relabel_nodes(G, new_nodes)
    root = new_nodes[root]

    # One record per edge: (tail, head, attribute dict, weight). The record
    # index travels through the heap so that the exact edge selected can be
    # recovered later, even when parallel edges exist.
    edge_list = [(u, v, data, data[weight_attr])
                 for u, v, data in G_.edges(data=True)]

    # incoming[v] lists the indices of all edges entering v.
    incoming = defaultdict(list)
    for idx, (_, v, _, _) in enumerate(edge_list):
        incoming[v].append(idx)

    # Seed the queue with the edges entering the root. Entries are
    # (weight, tail, head, edge index); the unique index breaks every tie.
    in_edge_pq: list = [(edge_list[idx][3], edge_list[idx][0], root, idx)
                        for idx in incoming[root]]
    heapq.heapify(in_edge_pq)

    n = G.number_of_nodes()
    parent: dict = {}  # Parent pointers for the tree (tail -> head)
    mst_edges: list = []  # Indices into edge_list of the accepted edges
    outlets: dict = {}
    while in_edge_pq:
        _, u, v, idx = heapq.heappop(in_edge_pq)

        # The root never gets a parent (that would close a cycle), and a node
        # already in the tree keeps the cheaper edge it was attached with.
        if u == root or u in parent:
            continue

        # u is not in the tree yet, so accept the edge (u, v)
        parent[u] = v
        mst_edges.append(idx)

        if v in outlets:
            # u drains through v, so it inherits v's outlet
            outlets[u] = outlets[v]
        elif edge_list[idx][2].get('edge_type') == 'outlet':
            # u is attached through an outlet edge: u is its own outlet
            outlets[u] = node_mapping[u]

        # Queue the edges entering u; their tails can now reach the tree
        for nxt in incoming[u]:
            w, _, _, weight_new = edge_list[nxt]
            heapq.heappush(in_edge_pq, (weight_new, w, u, nxt))

    # Every node except the root receives exactly one parent entry when it is
    # attached, so a spanning tree over n nodes has exactly n - 1 entries.
    # Fewer entries mean some node has no directed path to the root.
    if len(parent) != n - 1:
        raise ValueError("Graph is not connected or has multiple roots.")

    new_graph = nx.MultiDiGraph()
    new_graph.add_edges_from((edge_list[idx][0],
                              edge_list[idx][1],
                              edge_list[idx][2])
                             for idx in mst_edges)

    # Copies node attributes and also adds nodes that have no tree edge
    # (a graph consisting of the root alone).
    new_graph.add_nodes_from(G_.nodes(data=True))

    nx.set_node_attributes(new_graph, outlets, 'outlet')
    new_graph = nx.relabel_nodes(new_graph, node_mapping)
    new_graph.graph = G.graph.copy()
    return new_graph
|
||
def dijkstra_pq(G: nx.MultiDiGraph,
                outlets: list,
                weight_attr: str = 'weight') -> nx.MultiDiGraph:
    """Dijkstra's algorithm for shortest paths to outlets.

    This function calculates the shortest path from each node in the graph to
    the nearest outlet by searching backwards (along incoming edges) from all
    outlets at once. Nodes that cannot reach any outlet are dropped, as are
    edges that are not on a shortest path. Paths are tracked node-to-node, so
    all parallel edges between two consecutive path nodes are kept.

    Args:
        G (nx.MultiDiGraph): The input graph. It is not modified.
        outlets (list): A list of outlet nodes.
        weight_attr (str): The name of the edge attribute containing the edge
            weights. Defaults to 'weight'.

    Returns:
        nx.MultiDiGraph: A copy of the graph reduced to the shortest paths,
            where each node has an 'outlet' attribute (the outlet it drains
            to) and a 'shortest_path' attribute (the path length).

    Raises:
        nx.NetworkXError: If an outlet is not a node of ``G``.
        ValueError: If no node is left after removing nodes without a path
            to an outlet (e.g. ``outlets`` is empty).
    """
    graph = G.copy()

    missing = [outlet for outlet in outlets if outlet not in graph]
    if missing:
        raise nx.NetworkXError(f"Outlets are not in the graph: {missing}")

    # Best known distance to an outlet; infinity until a path is found
    shortest_paths = dict.fromkeys(graph.nodes, float('inf'))

    # outlet_of[node] is the outlet the node drains to, downstream[node] is
    # the next node on its path. Together they replace a full path list per
    # node, which would cost quadratic memory on long chains.
    outlet_of: dict[Hashable, Hashable] = {}
    downstream: dict[Hashable, Hashable] = {}

    # Heap entries are (distance, sequence number, node). The sequence number
    # settles ties so that node labels are never compared with each other,
    # which would fail for mixed label types (e.g. ints and strings).
    heap: list = []
    for outlet in outlets:
        shortest_paths[outlet] = 0
        outlet_of[outlet] = outlet
        heap.append((0, len(heap), outlet))
    sequence = len(heap)

    while heap:
        # Pop the node with the smallest distance
        dist, _, node = heapq.heappop(heap)

        # Skip entries made obsolete by a shorter path found afterwards
        if dist > shortest_paths[node]:
            continue

        # Relax every edge entering the current node
        for neighbor, _, edge_data in graph.in_edges(node, data=True):
            alt_dist = dist + edge_data[weight_attr]
            if alt_dist >= shortest_paths[neighbor]:
                continue

            shortest_paths[neighbor] = alt_dist
            outlet_of[neighbor] = outlet_of[node]
            downstream[neighbor] = node
            heapq.heappush(heap, (alt_dist, sequence, neighbor))
            sequence += 1

    # Remove nodes with no path to an outlet
    graph.remove_nodes_from([node for node in graph.nodes
                             if node not in outlet_of])

    if graph.number_of_nodes() == 0:
        raise ValueError("""No nodes with path to outlet, """)

    # Assign the outlet and path length once per node
    for node, outlet in outlet_of.items():
        graph.nodes[node]['outlet'] = outlet
        graph.nodes[node]['shortest_path'] = shortest_paths[node]

    # Remove edges not on paths. Membership is tested on (u, v) only, so
    # parallel edges along a path are all retained.
    edges_to_keep = set(downstream.items())
    graph.remove_edges_from([(u, v, key)
                             for u, v, key in graph.edges(keys=True)
                             if (u, v) not in edges_to_keep])

    return graph
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from __future__ import annotations | ||
|
||
import networkx as nx | ||
import pytest | ||
|
||
from swmmanywhere.shortest_path_utils import dijkstra_pq, tarjans_pq | ||
|
||
|
||
def test_simple_graph():
    """Test case 1: Simple connected graph draining to a 'waste' node."""
    edge_attributes = {
        (2, 1): {'weight': 1, 'edge_type': 'outlet'},
        (3, 1): {'weight': 2, 'edge_type': 'street'},
        (4, 2): {'weight': 3, 'edge_type': 'street'},
        (4, 3): {'weight': 4, 'edge_type': 'street'},
        (1, 'waste'): {'weight': 0, 'edge_type': 'outlet'},
    }
    G = nx.MultiDiGraph()
    for (u, v), attributes in edge_attributes.items():
        G.add_edge(u, v, **attributes)

    root = 'waste'
    # Node 4 reaches node 1 more cheaply via 2, so (4, 3) is dropped
    expected_edges = {(1, 'waste', 0), (2, 1, 0), (4, 2, 0), (3, 1, 0)}

    mst = tarjans_pq(G, root)
    assert set(mst.edges) == expected_edges
    assert mst.size() == 4

    djg = dijkstra_pq(G, [root])
    assert set(djg.edges) == expected_edges
    assert djg.size() == 4
|
||
def test_disconnected():
    """Test case 2: Disconnected graph (nodes 4, 5, 6 cannot reach the root)."""
    connected_part = [(1, 'waste'), (2, 1), (3, 1)]
    isolated_part = [(5, 4), (6, 5)]

    G = nx.MultiDiGraph()
    G.add_edges_from(connected_part + isolated_part,
                     weight=1,
                     edge_type='street')
    root = 'waste'

    # A spanning tree cannot cover the isolated component
    with pytest.raises(ValueError):
        tarjans_pq(G, root)

    # The shortest path search simply drops the isolated component
    djg = dijkstra_pq(G, [root])
    assert set(djg.edges) == {(1, 'waste', 0), (2, 1, 0), (3, 1, 0)}
|
||
def test_parallel():
    """Test case 3: Graph with parallel edges (two edges from 2 to 1)."""
    keyed_edges = [
        (2, 1, 0),
        (2, 1, 1),
        (3, 1, 0),
        (4, 2, 0),
        (4, 3, 0),
    ]
    G = nx.MultiDiGraph()
    G.add_edges_from(keyed_edges, edge_type='street', weight=1)
    root = 1

    mst = tarjans_pq(G, root)
    assert set(mst.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
    assert mst.size() == 3

    djg = dijkstra_pq(G, [root])
    # Currently paths are defined as node-to-node and so ignore keys .. TODO?
    assert set(djg.edges) == {(2, 1, 0), (2, 1, 1), (3, 1, 0), (4, 2, 0)}
|
||
def test_selfloop():
    """Test case 4: Graph with loops (2 <-> 4 and 3 <-> 4)."""
    keyed_edges = [
        (2, 1, 0),
        (3, 1, 0),
        (4, 2, 0),
        (4, 3, 0),
        (2, 4, 0),
        (3, 4, 0),
    ]
    G = nx.MultiDiGraph()
    G.add_edges_from(keyed_edges, edge_type='street', weight=1)
    root = 1

    # The back edges (2, 4) and (3, 4) must not appear in either result
    expected_edges = {(2, 1, 0), (4, 2, 0), (3, 1, 0)}

    mst = tarjans_pq(G, root)
    assert set(mst.edges) == expected_edges
    assert mst.size() == 3

    djg = dijkstra_pq(G, [root])
    assert set(djg.edges) == expected_edges
    assert djg.size() == 3
|
||
def test_custom_weight():
    """Test case 5: Edge weights stored under a custom attribute ('cost')."""
    keyed_edges = [
        (2, 1, 0),
        (3, 1, 0),
        (4, 2, 0),
        (4, 3, 0),
    ]
    G = nx.MultiDiGraph()
    G.add_edges_from(keyed_edges, edge_type='street', cost=1)
    root = 1
    expected_edges = {(2, 1, 0), (4, 2, 0), (3, 1, 0)}

    mst = tarjans_pq(G, root, weight_attr='cost')
    assert set(mst.edges) == expected_edges
    assert mst.size() == 3

    djg = dijkstra_pq(G, [root], weight_attr='cost')
    assert set(djg.edges) == expected_edges
    assert djg.size() == 3
|
||
def test_empty():
    """Test case 6: Empty graph, so the requested root does not exist."""
    empty_graph = nx.MultiDiGraph()
    root = 1

    # The root cannot be looked up among the (absent) nodes
    with pytest.raises(KeyError):
        tarjans_pq(empty_graph, root)

    # The outlet is not a node of the graph
    with pytest.raises(nx.NetworkXError):
        dijkstra_pq(empty_graph, [root])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a top-level explanation of what this loop is doing and how it finds the MST.