Skip to content

Commit

Permalink
Merge pull request #391 from mggg/fix-default-updaters-bug
Browse files Browse the repository at this point in the history
Fix default updaters bug in GeographicPartition
  • Loading branch information
gabeschoenbach authored Apr 14, 2022
2 parents 094a02c + 2638415 commit 18cde80
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
15 changes: 8 additions & 7 deletions gerrychain/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,29 @@ class Partition:
'_cache'
)

default_updaters = {"cut_edges": cut_edges}

def __init__(
self, graph=None, assignment=None, updaters=None, parent=None, flips=None,
use_cut_edges=True
use_default_updaters=True
):
"""
:param graph: Underlying graph.
:param assignment: Dictionary assigning nodes to districts.
:param updaters: Dictionary of functions to track data about the partition.
The keys are stored as attributes on the partition class,
which the functions compute.
:param use_cut_edges: If `False`, do not include `cut_edges` updater by default
and do not calculate edge flows.
:param use_default_updaters: If `False`, do not include default updaters.
"""
if parent is None:
self._first_time(graph, assignment, updaters, use_cut_edges)
self._first_time(graph, assignment, updaters, use_default_updaters)
else:
self._from_parent(parent, flips)

self._cache = dict()
self.subgraphs = SubgraphView(self.graph, self.parts)

def _first_time(self, graph, assignment, updaters, use_cut_edges):
def _first_time(self, graph, assignment, updaters, use_default_updaters):
if isinstance(graph, Graph):
self.graph = FrozenGraph(graph)
elif isinstance(graph, networkx.Graph):
Expand All @@ -71,8 +72,8 @@ def _first_time(self, graph, assignment, updaters, use_cut_edges):
if updaters is None:
updaters = dict()

if use_cut_edges:
self.updaters = {"cut_edges": cut_edges}
if use_default_updaters:
self.updaters = self.default_updaters
else:
self.updaters = {}

Expand Down
22 changes: 22 additions & 0 deletions tests/partition/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,25 @@ def test_partition_has_default_updaters(example_partition):

def test_partition_has_keys(example_partition):
assert "cut_edges" in set(example_partition.keys())


def test_geographic_partition_has_keys(example_geographic_partition):
keys = set(example_geographic_partition.updaters.keys())

assert "perimeter" in keys
assert "exterior_boundaries" in keys
assert "interior_boundaries" in keys
assert "boundary_nodes" in keys
assert "cut_edges" in keys
assert "area" in keys
assert "cut_edges_by_part" in keys


def test_partition_has_default_updaters(example_geographic_partition):
assert hasattr(example_geographic_partition, "perimeter")
assert hasattr(example_geographic_partition, "exterior_boundaries")
assert hasattr(example_geographic_partition, "interior_boundaries")
assert hasattr(example_geographic_partition, "boundary_nodes")
assert hasattr(example_geographic_partition, "cut_edges")
assert hasattr(example_geographic_partition, "area")
assert hasattr(example_geographic_partition, "cut_edges_by_part")

0 comments on commit 18cde80

Please sign in to comment.