diff --git a/.gitignore b/.gitignore index 48140f8..d46794c 100644 --- a/.gitignore +++ b/.gitignore @@ -167,4 +167,6 @@ cython_debug/ testing_wip.py # Output csv files -*.csv \ No newline at end of file +a.csv +b.csv +c.csv \ No newline at end of file diff --git a/src/neighbourhood.py b/src/neighbourhood.py index 703c57d..e957329 100644 --- a/src/neighbourhood.py +++ b/src/neighbourhood.py @@ -18,17 +18,6 @@ class ReportingNeighbourhoods: neighbourhood_id_to_reporting_unit_ids: Dict[str, Set[str]] neighbourhood_id_to_reference_group: Dict[str, ReportingUnitInfo] - def get_reference_group( - self, reporting_unit_id: str - ) -> Optional[ReportingUnitInfo]: - neighbourhood_id = self.reporting_unit_id_to_neighbourhood_id.get( - reporting_unit_id - ) - if neighbourhood_id is None: - return None - - return self.neighbourhood_id_to_reference_group[neighbourhood_id] - class NeighbourhoodData: data: pl.LazyFrame diff --git a/test/data/neighbourhood_files/invalid.csv b/test/data/neighbourhood_files/invalid.csv new file mode 100644 index 0000000..83db119 --- /dev/null +++ b/test/data/neighbourhood_files/invalid.csv @@ -0,0 +1,2 @@ +neighbourhood_code,ambiguous +123,no \ No newline at end of file diff --git a/test/data/neighbourhood_files/valid.csv b/test/data/neighbourhood_files/valid.csv new file mode 100644 index 0000000..c0f0788 --- /dev/null +++ b/test/data/neighbourhood_files/valid.csv @@ -0,0 +1,3 @@ +zip_code,neighbourhood_code,ambiguous +1234AB,WK123,no +1235AB,WK123,no \ No newline at end of file diff --git a/test/data/neighbourhood_files/valid.parquet b/test/data/neighbourhood_files/valid.parquet new file mode 100644 index 0000000..9467055 Binary files /dev/null and b/test/data/neighbourhood_files/valid.parquet differ diff --git a/test/data/neighbourhood_files/wrong_filetype.txt b/test/data/neighbourhood_files/wrong_filetype.txt new file mode 100644 index 0000000..e69de29 diff --git a/test/test_neighbourhood.py b/test/test_neighbourhood.py new file mode 100644 index 0000000..c7dac34 --- /dev/null +++ b/test/test_neighbourhood.py @@ -0,0 +1,141 @@ +import pytest +from eml_types import CandidateIdentifier, PartyIdentifier, ReportingUnitInfo +from neighbourhood import NeighbourhoodData, ReportingNeighbourhoods +from typing import Optional, Dict + +read_test_cases = [ + ("./test/data/neighbourhood_files/valid.parquet", True), + ("./test/data/neighbourhood_files/valid.csv", True), + ("./test/data/neighbourhood_files/THIS_FILE_DOES_NOT_EXIST.parquet", False), + ("./test/data/THIS_FOLDER_DOES_NOT_EXIST/valid.parquet", False), + ("./test/data/neighbourhood_files/wrong_filetype.txt", False), + ("./test/data/neighbourhood_files/invalid.csv", False), +] + + +@pytest.mark.parametrize("path_to_read, should_load", read_test_cases) +def test_read_neighbourhood_file(path_to_read: str, should_load: bool) -> None: + result = NeighbourhoodData.from_path(path_to_read) + if should_load: + assert isinstance(result, NeighbourhoodData) + else: + assert result is None + + +fetch_test_cases = [ + ("./test/data/neighbourhood_files/valid.csv", "1234AB", "WK123"), + ("./test/data/neighbourhood_files/valid.csv", "1235AB", "WK123"), + ("./test/data/neighbourhood_files/valid.csv", "1234AC", None), +] + + +@pytest.mark.parametrize("path_to_read, zip_code, expected", fetch_test_cases) +def test_fetch_neighbourhood_code( + path_to_read: str, zip_code: str, expected: str +) -> None: + data = NeighbourhoodData.from_path(path_to_read) + assert data is not None and data.fetch_neighbourhood_code(zip_code) == expected + + +reporting_neighbourhoods_test_cases = [ + ( + "./test/data/neighbourhood_files/valid.csv", + {"SB1": "1234AB", "SB2": "1235AB", "SB3": "9999XX", "SB4": None}, + { + "SB1": ReportingUnitInfo( + reporting_unit_id="SB1", + reporting_unit_name="SB1", + cast=0, + total_counted=0, + rejected_votes={}, + uncounted_votes={}, + votes_per_party={PartyIdentifier(id=1, name=None): 10}, + votes_per_candidate={ + CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 8, + CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 2, + }, + ), + "SB2": ReportingUnitInfo( + reporting_unit_id="SB2", + reporting_unit_name="SB2", + cast=0, + total_counted=0, + rejected_votes={}, + uncounted_votes={}, + votes_per_party={PartyIdentifier(id=1, name=None): 12}, + votes_per_candidate={ + CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 11, + CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 1, + }, + ), + "SB3": ReportingUnitInfo( + reporting_unit_id="SB3", + reporting_unit_name="SB3", + cast=0, + total_counted=0, + rejected_votes={}, + uncounted_votes={}, + votes_per_party={PartyIdentifier(id=1, name=None): 100}, + votes_per_candidate={ + CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 80, + CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 20, + }, + ), + "SB4": ReportingUnitInfo( + reporting_unit_id="SB4", + reporting_unit_name="SB4", + cast=0, + total_counted=0, + rejected_votes={}, + uncounted_votes={}, + votes_per_party={PartyIdentifier(id=1, name=None): 200}, + votes_per_candidate={ + CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 160, + CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 40, + }, + ), + }, + ReportingNeighbourhoods( + reporting_unit_id_to_neighbourhood_id={ + "SB1": "WK123", + "SB2": "WK123", + "SB3": None, + "SB4": None, + }, + neighbourhood_id_to_reporting_unit_ids={"WK123": set(["SB1", "SB2"])}, + neighbourhood_id_to_reference_group={ + "WK123": ReportingUnitInfo( + reporting_unit_id="WK123", + reporting_unit_name=f"Reference group for WK123", + cast=0, + total_counted=0, + rejected_votes={}, + uncounted_votes={}, + votes_per_party={PartyIdentifier(id=1, name=None): 22}, + votes_per_candidate={ + CandidateIdentifier(PartyIdentifier(id=1, name=None), 1): 19, + CandidateIdentifier(PartyIdentifier(id=1, name=None), 2): 3, + }, + ) + }, + ), + ) +] + + +@pytest.mark.parametrize( + "path_to_read, reporting_unit_zips, reporting_unit_info, expected", + reporting_neighbourhoods_test_cases, +) +def test_fetch_reporting_neighbourhoods( + path_to_read: str, + reporting_unit_zips: Dict[str, Optional[str]], + reporting_unit_info: Dict[str, ReportingUnitInfo], + expected: ReportingNeighbourhoods, +) -> None: + neighbourhood_data = NeighbourhoodData.from_path(path_to_read) + assert neighbourhood_data is not None + reporting_neighourhoods = neighbourhood_data.fetch_reporting_neighbourhoods( + reporting_unit_zips, reporting_unit_info + ) + assert reporting_neighourhoods == expected