Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create two shortest path options in own module #174

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions swmmanywhere/shortest_path_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Utility functions for shortest path algorithms."""
from __future__ import annotations

import heapq
from collections import defaultdict
from typing import Hashable

import networkx as nx


def tarjans_pq(G: nx.MultiDiGraph,
root: int | str,
weight_attr: str = 'weight') -> nx.MultiDiGraph:
"""Tarjan's algorithm for a directed minimum spanning tree.

Also known as a minimum spanning arborescence, this algorithm finds the
minimum directed spanning tree rooted at a given vertex (root) in a directed
graph.

Args:
G (nx.MultiDiGraph): The input graph.
root (int | str): The root node (i.e., that all vertices in the graph
should flow to).
weight_attr (str): The name of the edge attribute containing the edge
weights. Defaults to 'weight'.

Returns:
nx.MultiDiGraph: The directed minimum spanning tree.
"""
# Copy the graph and relabel the nodes
G_ = G.copy()
new_nodes = {node:i for i,node in enumerate(G.nodes)}
node_mapping = {i:node for i,node in enumerate(G.nodes)}
G_ = nx.relabel_nodes(G_, new_nodes)

# Extract the new root label, edges and weights
root = new_nodes[root]
edges = [(u, v, d[weight_attr]) for u, v, d in G_.edges(data=True)]

# Initialize data structures
graph = defaultdict(list)
for u, v, weight in edges:
graph[v].append((u, weight))

n = len(G.nodes)
parent = {} # Parent pointers for the MST
in_edge_pq: list = [] # Priority queue to store incoming edges

# Initialize the priority queue with edges incoming to the root
for u, weight in graph[root]:
heapq.heappush(in_edge_pq, (weight, u, root))

mst_edges = []
mst_weight = 0
outlets: dict = {}
while in_edge_pq:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a top level explanation of what this loop is doing and how ti finds the MST.

weight, u, v = heapq.heappop(in_edge_pq)

if u not in parent:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid too much indentation, you can check the opposite and then continue:

if u in parent:
    continue

# If v is not in the MST yet, add the edge (u, v) to the MST
parent[u] = v
...

Visually, it reduces complexity.

# If v is not in the MST yet, add the edge (u, v) to the MST
parent[u] = v
mst_edges.append((u, v))
mst_weight += weight

if v in outlets:
outlets[u] = outlets[v]

elif G_.get_edge_data(u,v)[0]['edge_type'] == 'outlet':
outlets[u] = node_mapping[u]

# Add incoming edges to v to the priority queue
for w, weight_new in graph[u]:
heapq.heappush(in_edge_pq, (weight_new, w, u))

# Check if all vertices are reachable from the root
if len(parent) != n - 1:
Comment on lines +75 to +76
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow this. Why checking the length checks that all vertices are reachable from the root? I mean, why such a length should be equal to n-1, when n is the total number of nodes? Probably is because I do not fully understand what these trees are, so don't worry if it's obvious it should this like this.

raise ValueError("Graph is not connected or has multiple roots.")

new_graph = nx.MultiDiGraph()

for u,v in mst_edges:
d= G_.get_edge_data(u,v)[0]
new_graph.add_edge(u,v,**d)

for u, d in G_.nodes(data=True):
new_graph.nodes[u].update(d)

nx.set_node_attributes(new_graph, outlets, 'outlet')
new_graph = nx.relabel_nodes(new_graph, node_mapping)
new_graph.graph = G.graph.copy()
return new_graph

def dijkstra_pq(G: nx.MultiDiGraph,
outlets: list,
weight_attr: str = 'weight') -> nx.MultiDiGraph:
"""Dijkstra's algorithm for shortest paths to outlets.

This function calculates the shortest paths from each node in the graph to
the nearest outlet. The graph is modified to include the outlet
and the shortest path length.

Args:
G (nx.MultiDiGraph): The input graph.
outlets (list): A list of outlet nodes.
weight_attr (str): The name of the edge attribute containing the edge
weights. Defaults to 'weight'.

Returns:
nx.MultiDiGraph: The graph with the shortest paths to outlets.
"""
G = G.copy()
# Initialize the dictionary with infinity for all nodes
shortest_paths = {node: float('inf') for node in G.nodes}

# Initialize the dictionary to store the paths
paths: dict[Hashable,list] = {node: [] for node in G.nodes}

# Set the shortest path length to 0 for outlets
for outlet in outlets:
shortest_paths[outlet] = 0
paths[outlet] = [outlet]

# Initialize a min-heap with (distance, node) tuples
heap = [(0, outlet) for outlet in outlets]
while heap:
# Pop the node with the smallest distance
dist, node = heapq.heappop(heap)

# For each neighbor of the current node
for neighbor, _, edge_data in G.in_edges(node, data=True):
# Calculate the distance through the current node
alt_dist = dist + edge_data[weight_attr]
# If the alternative distance is shorter

if alt_dist >= shortest_paths[neighbor]:
continue

# Update the shortest path length
shortest_paths[neighbor] = alt_dist
# Update the path
paths[neighbor] = paths[node] + [neighbor]
# Push the neighbor to the heap
heapq.heappush(heap, (alt_dist, neighbor))

# Remove nodes with no path to an outlet
for node in [node for node, path in paths.items() if not path]:
G.remove_node(node)
del paths[node], shortest_paths[node]

if len(G.nodes) == 0:
raise ValueError("""No nodes with path to outlet, """)

edges_to_keep: set = set()

for path in paths.values():
# Assign outlet
outlet = path[0]
for node in path:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this exclude path[0] or that one should have as outlet itself?

G.nodes[node]['outlet'] = outlet
G.nodes[node]['shortest_path'] = shortest_paths[node]

# Store path
edges_to_keep.update(zip(path[1:], path[:-1]))

# Remove edges not on paths
new_graph = G.copy()
for u,v in G.edges():
if (u,v) not in edges_to_keep:
new_graph.remove_edge(u,v)

return new_graph
105 changes: 105 additions & 0 deletions tests/test_shortest_path_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

import networkx as nx
import pytest

from swmmanywhere.shortest_path_utils import dijkstra_pq, tarjans_pq


def test_simple_graph():
"""Test case 1: Simple connected graph."""
G = nx.MultiDiGraph()
G.add_edges_from([(2, 1), (3, 1), (4, 2), (4, 3), (1,'waste')])
nx.set_edge_attributes(G,
{(2, 1,0):{'weight': 1, 'edge_type':'outlet'},
(3, 1,0):{'weight': 2, 'edge_type':'street'},
(4, 2,0):{'weight': 3, 'edge_type':'street'},
(4, 3,0):{'weight': 4, 'edge_type':'street'},
(1,'waste',0):{'weight': 0, 'edge_type':'outlet'}})
root = 'waste'
mst = tarjans_pq(G, root)
assert set(mst.edges) == {(1,'waste',0),(2,1,0), (4,2,0), (3,1,0)}
assert mst.size() == 4

djg = dijkstra_pq(G, [root])
assert set(djg.edges) == {(1,'waste',0),(2,1,0), (4,2,0), (3,1,0)}
assert djg.size() == 4

def test_disconnected():
"""Test case 2: Disconnected graph."""
G = nx.MultiDiGraph()
G.add_edges_from([(1,'waste'), (2,1), (3,1), (5,4), (6,5)],
weight=1,
edge_type = 'street')
root = 'waste'
with pytest.raises(ValueError):
tarjans_pq(G, root)

djg = dijkstra_pq(G, [root])
assert set(djg.edges) == {(1, 'waste', 0), (2, 1, 0), (3, 1, 0)}

def test_parallel():
"""Test case 3: Graph with parallel edges."""
G = nx.MultiDiGraph()
G.add_edges_from([(2, 1, 0),
(2, 1, 1),
(3, 1, 0),
(4, 2, 0),
(4, 3, 0)],
edge_type='street',
weight=1)
root = 1
mst = tarjans_pq(G, root)
assert set(mst.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
assert mst.size() == 3

djg = dijkstra_pq(G, [root])
# Currently paths are defined as node-to-node and so ignore keys .. TODO?
assert set(djg.edges) == {(2, 1, 0), (2, 1, 1), (3, 1, 0), (4, 2, 0)}

def test_selfloop():
"""Test case 4: Graph with self-loops."""
G = nx.MultiDiGraph()
G.add_edges_from([(2, 1, 0),
(3, 1, 0),
(4, 2, 0),
(4, 3, 0),
(2, 4, 0)],
edge_type='street',
weight=1)
G.add_edge(3,4,weight=1,edge_type='street')
root = 1
mst = tarjans_pq(G, root)
assert set(mst.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
assert mst.size() == 3

djg = dijkstra_pq(G, [root])
assert set(djg.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
assert djg.size() == 3

def test_custom_weight():
"""Test case 5: Graph with custom weight attribute."""
G = nx.MultiDiGraph()
G.add_edges_from([(2, 1, 0),
(3, 1, 0),
(4, 2, 0),
(4, 3, 0)],
edge_type='street',
cost=1)
root = 1
mst = tarjans_pq(G, root, weight_attr='cost')
assert set(mst.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
assert mst.size() == 3

djg = dijkstra_pq(G, [root], weight_attr='cost')
assert set(djg.edges) == {(2, 1, 0), (4, 2, 0), (3, 1, 0)}
assert djg.size() == 3

def test_empty():
"""Test case 6: Empty graph."""
G = nx.MultiDiGraph()
root = 1
with pytest.raises(KeyError):
tarjans_pq(G, root)
with pytest.raises(nx.NetworkXError):
dijkstra_pq(G, [root])