Skip to content

Commit

Permalink
Add test cases for neighbourhood functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismostert committed Feb 22, 2024
1 parent 7393a46 commit bb0720c
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 12 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,6 @@ cython_debug/
testing_wip.py

# Output csv files
*.csv
a.csv
b.csv
c.csv
11 changes: 0 additions & 11 deletions src/neighbourhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@ class ReportingNeighbourhoods:
neighbourhood_id_to_reporting_unit_ids: Dict[str, Set[str]]
neighbourhood_id_to_reference_group: Dict[str, ReportingUnitInfo]

def get_reference_group(
self, reporting_unit_id: str
) -> Optional[ReportingUnitInfo]:
neighbourhood_id = self.reporting_unit_id_to_neighbourhood_id.get(
reporting_unit_id
)
if neighbourhood_id is None:
return None

return self.neighbourhood_id_to_reference_group[neighbourhood_id]


class NeighbourhoodData:
data: pl.LazyFrame
Expand Down
2 changes: 2 additions & 0 deletions test/data/neighbourhood_files/invalid.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
neighbourhood_code,ambiguous
123,no
3 changes: 3 additions & 0 deletions test/data/neighbourhood_files/valid.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
zip_code,neighbourhood_code,ambiguous
1234AB,WK123,no
1235AB,WK123,no
Binary file added test/data/neighbourhood_files/valid.parquet
Binary file not shown.
Empty file.
141 changes: 141 additions & 0 deletions test/test_neighbourhood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import pytest
from eml_types import CandidateIdentifier, PartyIdentifier, ReportingUnitInfo
from neighbourhood import NeighbourhoodData, ReportingNeighbourhoods
from typing import Optional, Dict

read_test_cases = [
("./test/data/neighbourhood_files/valid.parquet", True),
("./test/data/neighbourhood_files/valid.csv", True),
("./test/data/neighbourhood_files/THIS_FILE_DOES_NOT_EXIST.parquet", False),
("./test/data/THIS_FOLDER_DOES_NOT_EXIST/valid.parquet", False),
("./test/data/neighbourhood_files/wrong_filetype.txt", False),
("./test/data/neighbourhood_files/invalid.csv", False),
]


@pytest.mark.parametrize("path_to_read, should_load", read_test_cases)
def test_read_neighbourhood_file(path_to_read: str, should_load: bool) -> None:
result = NeighbourhoodData.from_path(path_to_read)
if should_load:
assert isinstance(result, NeighbourhoodData)
else:
assert result is None


fetch_test_cases = [
("./test/data/neighbourhood_files/valid.csv", "1234AB", "WK123"),
("./test/data/neighbourhood_files/valid.csv", "1235AB", "WK123"),
("./test/data/neighbourhood_files/valid.csv", "1234AC", None),
]


@pytest.mark.parametrize("path_to_read, zip_code, expected", fetch_test_cases)
def test_fetch_neighbourhood_code(
path_to_read: str, zip_code: str, expected: str
) -> None:
data = NeighbourhoodData.from_path(path_to_read)
assert data is not None and data.fetch_neighbourhood_code(zip_code) == expected


reporting_neighbourhoods_test_cases = [
(
"./test/data/neighbourhood_files/valid.csv",
{"SB1": "1234AB", "SB2": "1235AB", "SB3": "9999XX", "SB4": None},
{
"SB1": ReportingUnitInfo(
reporting_unit_id="SB1",
reporting_unit_name="SB1",
cast=0,
total_counted=0,
rejected_votes={},
uncounted_votes={},
votes_per_party={PartyIdentifier(id=1, name=None): 10},
votes_per_candidate={
CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 8,
CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 2,
},
),
"SB2": ReportingUnitInfo(
reporting_unit_id="SB2",
reporting_unit_name="SB2",
cast=0,
total_counted=0,
rejected_votes={},
uncounted_votes={},
votes_per_party={PartyIdentifier(id=1, name=None): 12},
votes_per_candidate={
CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 11,
CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 1,
},
),
"SB3": ReportingUnitInfo(
reporting_unit_id="SB3",
reporting_unit_name="SB3",
cast=0,
total_counted=0,
rejected_votes={},
uncounted_votes={},
votes_per_party={PartyIdentifier(id=1, name=None): 100},
votes_per_candidate={
CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 80,
CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 20,
},
),
"SB4": ReportingUnitInfo(
reporting_unit_id="SB4",
reporting_unit_name="SB4",
cast=0,
total_counted=0,
rejected_votes={},
uncounted_votes={},
votes_per_party={PartyIdentifier(id=1, name=None): 200},
votes_per_candidate={
CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 160,
CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 40,
},
),
},
ReportingNeighbourhoods(
reporting_unit_id_to_neighbourhood_id={
"SB1": "WK123",
"SB2": "WK123",
"SB3": None,
"SB4": None,
},
neighbourhood_id_to_reporting_unit_ids={"WK123": set(["SB1", "SB2"])},
neighbourhood_id_to_reference_group={
"WK123": ReportingUnitInfo(
reporting_unit_id="WK123",
reporting_unit_name=f"Reference group for WK123",
cast=0,
total_counted=0,
rejected_votes={},
uncounted_votes={},
votes_per_party={PartyIdentifier(id=1, name=None): 22},
votes_per_candidate={
CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 19,
CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 3,
},
)
},
),
)
]


@pytest.mark.parametrize(
"path_to_read, reporting_unit_zips, reporting_unit_info, expected",
reporting_neighbourhoods_test_cases,
)
def test_fetch_reporting_neighbourhoods(
path_to_read: str,
reporting_unit_zips: Dict[str, Optional[str]],
reporting_unit_info: Dict[str, ReportingUnitInfo],
expected: ReportingNeighbourhoods,
) -> None:
neighbourhood_data = NeighbourhoodData.from_path(path_to_read)
assert neighbourhood_data is not None
reporting_neighourhoods = neighbourhood_data.fetch_reporting_neighbourhoods(
reporting_unit_zips, reporting_unit_info
)
assert reporting_neighourhoods == expected

0 comments on commit bb0720c

Please sign in to comment.