Skip to content

Commit

Permalink
Allow passing neighbourhood data and save zip in EmlMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismostert committed Feb 20, 2024
1 parent 1eedc7f commit a1d51ae
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 5 deletions.
Binary file added data/zip_to_neighbourhood_2023.parquet
Binary file not shown.
3 changes: 2 additions & 1 deletion src/eml.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EML:
MINIMUM_DEVIATION_FACTOR: ClassVar[int] = 10
MINIMUM_VOTES: ClassVar[int] = 20

def run_protocol(self) -> Dict[str, CheckResult]:
def run_protocol(self, neighbourhood_data=None) -> Dict[str, CheckResult]:
protocol_results = {}

for polling_station_id, polling_station in self.reporting_units_info.items():
Expand Down Expand Up @@ -80,6 +80,7 @@ def run_protocol(self) -> Dict[str, CheckResult]:
potentially_switched_candidates=protocol_checks.get_potentially_switched_candidates(
self.main_unit_info,
polling_station,
neighbourhood_data,
amount_of_reporting_units=self.metadata.reporting_unit_amount,
minimum_reporting_units=EML.MINIMUM_REPORTING_UNITS,
minimum_deviation_factor=EML.MINIMUM_DEVIATION_FACTOR,
Expand Down
3 changes: 2 additions & 1 deletion src/eml_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class EmlMetadata:
election_date: Optional[str]
contest_identifier: Optional[str]
reporting_unit_amount: int
reporting_unit_names: Dict[Optional[str], Optional[str]]
reporting_unit_names: Dict[str, Optional[str]]
reporting_unit_zips: Dict[str, Optional[str]]


@dataclass(frozen=True, order=True)
Expand Down
15 changes: 13 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from eml import EML
from odt import ODT
from neighbourhood import NeighbourhoodData
import csv_write


def create_csv_files(path_to_xml, dest_a, dest_b, dest_c, path_to_odt=None) -> None:
def create_csv_files(
path_to_xml,
dest_a,
dest_b,
dest_c,
path_to_odt=None,
path_to_neighbourhood_data=None,
) -> None:
# Parse the eml from the path and run all checks in the protocol
eml = EML.from_xml(path_to_xml)

check_results = eml.run_protocol()
# Load the optional neighbourhood data from the given path
neighbourhood_data = NeighbourhoodData.from_path(path_to_neighbourhood_data)

# Hand the loaded NeighbourhoodData object (not the raw path string) to the
# protocol; the switched-candidate check expects Optional[NeighbourhoodData].
check_results = eml.run_protocol(neighbourhood_data=neighbourhood_data)
eml_metadata = eml.metadata

# If odt_path is specified we try to read the file and extract the relevant
Expand Down
39 changes: 39 additions & 0 deletions src/neighbourhood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import polars as pl
from pathlib import Path
from typing import Optional, List, Dict
from dataclasses import dataclass


@dataclass
class ReportingNeighbourhoods:
    """Pair of lookup tables relating reporting units and neighbourhoods.

    NOTE(review): judging by the field names the two mappings are each
    other's inverse and the strings are identifiers/codes -- confirm against
    the code that builds instances of this class.
    """

    # Forward lookup: one (possibly missing) neighbourhood per reporting unit.
    reporting_unit_to_neighbourhood: Dict[str, Optional[str]]

    # Reverse lookup: the list of reporting units for each neighbourhood.
    neighbourhood_to_reporting_units: Dict[str, List[str]]


class NeighbourhoodData:
    """Wrapper around a lazily scanned neighbourhood table.

    Instances built through ``from_path`` hold a polars ``LazyFrame`` whose
    columns are exactly ``zip_code``, ``neighbourhood_code`` and
    ``ambiguous`` (in that order).
    """

    # Quoted so the annotation is not evaluated when the class is created.
    data: "pl.LazyFrame"

    def __init__(self, data) -> None:
        """Store the given frame as-is; no validation happens here."""
        self.data = data

    @staticmethod
    def from_path(str_path: Optional[str]) -> Optional["NeighbourhoodData"]:
        """Load neighbourhood data from a ``.csv`` or ``.parquet`` file.

        Loading is best-effort: every failure yields ``None`` rather than an
        exception, so callers can treat the data as optional.

        Args:
            str_path: Path to the data file, or ``None`` when no neighbourhood
                data was supplied.

        Returns:
            A ``NeighbourhoodData`` wrapping the lazily scanned file, or
            ``None`` when ``str_path`` is ``None``, the suffix is neither
            ``.csv`` nor ``.parquet``, scanning fails, or the columns are not
            exactly ``["zip_code", "neighbourhood_code", "ambiguous"]``.
        """
        if str_path is None:
            return None

        try:
            path = Path(str_path)
            if path.suffix == ".csv":
                data = pl.scan_csv(path)
            elif path.suffix == ".parquet":
                data = pl.scan_parquet(path)
            else:
                return None
            # NOTE(review): resolving the columns of a lazy frame may need to
            # read the file, so it is kept inside the try to preserve the
            # "return None on any failure" behaviour -- confirm with the
            # polars version in use.
            columns = data.columns
        except Exception:
            # Deliberate catch-all: an unusable file means "no data".
            return None

        if columns != ["zip_code", "neighbourhood_code", "ambiguous"]:
            return None

        return NeighbourhoodData(data=data)
2 changes: 2 additions & 0 deletions src/protocol_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from typing import Dict, Optional, TypeVar
from itertools import product as cartesian_product
from neighbourhood import NeighbourhoodData

T = TypeVar("T")
N = TypeVar("N", int, float)
Expand Down Expand Up @@ -139,6 +140,7 @@ def get_expected_candidate_votes(
def get_potentially_switched_candidates(
main_unit: ReportingUnitInfo,
reporting_unit: ReportingUnitInfo,
neighbourhood_data: Optional[NeighbourhoodData],
amount_of_reporting_units: int,
minimum_reporting_units: int,
minimum_deviation_factor: int,
Expand Down
34 changes: 33 additions & 1 deletion src/xml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CandidateIdentifier,
InvalidEmlException,
)
import re

NAMESPACE = {
"eml": "urn:oasis:names:tc:evs:schema:eml",
Expand All @@ -16,6 +17,7 @@
"xal": "urn:oasis:names:tc:ciq:xsdschema:xAL:2.0",
"xnl": "urn:oasis:names:tc:ciq:xsdschema:xNL:2.0",
}
ZIP_REGEX = re.compile(r"\(postcode: (\d{4} \w{2})\)")


def get_text(xml_element: Optional[XmlElement]) -> Optional[str]:
Expand All @@ -39,6 +41,31 @@ def get_attrib(xml_element: Optional[XmlElement], attrib_name: str) -> Optional[
return xml_element.attrib.get(attrib_name) if xml_element is not None else None


def get_mandatory_attrib(xml_element: Optional[XmlElement], attrib_name: str) -> str:
    """Return the value of an attribute that must be present on an element.

    Args:
        xml_element: Element to read the attribute from, or ``None``.
        attrib_name: Name of the required attribute.

    Returns:
        The attribute's value.

    Raises:
        ValueError: If ``xml_element`` is ``None``.
        AttributeError: If the element has no attribute ``attrib_name``.
    """
    if xml_element is not None:
        value = xml_element.attrib.get(attrib_name)
        if value is not None:
            return value

        raise AttributeError(
            f"Element {xml_element} did not have attribute {attrib_name} but was mandatory"
        )

    raise ValueError("Could not find specified XML element")


def extract_zip_from_name(reporting_unit_name: Optional[str]) -> Optional[str]:
    """Pull the zip code out of a reporting unit name.

    The name is searched with the module-level ``ZIP_REGEX``; the single
    captured group is returned with its spaces removed.

    Args:
        reporting_unit_name: Name of the reporting unit, or ``None``.

    Returns:
        The captured zip code without spaces, or ``None`` when the name is
        ``None``, the pattern does not occur in it, or the match does not
        have exactly one group.
    """
    if reporting_unit_name is None:
        return None

    found = ZIP_REGEX.search(reporting_unit_name)
    captured = found.groups() if found is not None else ()

    # Anything other than exactly one captured group means "no zip code".
    if len(captured) != 1:
        return None

    (zip_code,) = captured
    return zip_code.replace(" ", "")


def parse_xml(file_name: Union[str, IO[bytes]]) -> XmlElement:
# EML should be checked so that it validates using the XSD

Expand Down Expand Up @@ -103,7 +130,11 @@ def get_metadata(root: XmlElement) -> EmlMetadata:

reporting_units = root.findall(".//eml:ReportingUnitIdentifier", NAMESPACE)
reporting_unit_names = {
get_attrib(elem, "Id"): get_text(elem) for elem in reporting_units
get_mandatory_attrib(elem, "Id"): get_text(elem) for elem in reporting_units
}
reporting_unit_zips = {
reporting_unit_id: extract_zip_from_name(reporting_unit_name)
for (reporting_unit_id, reporting_unit_name) in reporting_unit_names.items()
}

return EmlMetadata(
Expand All @@ -117,6 +148,7 @@ def get_metadata(root: XmlElement) -> EmlMetadata:
contest_identifier=contest_identifier,
reporting_unit_amount=len(reporting_units),
reporting_unit_names=reporting_unit_names,
reporting_unit_zips=reporting_unit_zips,
)


Expand Down
1 change: 1 addition & 0 deletions test/test_eml.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@
reporting_unit_names={
"0505::SB1": "Stembureau Binnenstad (postcode: 3331 DA)"
},
reporting_unit_zips={"0505::SB1": "3331DA"},
),
),
]
Expand Down

0 comments on commit a1d51ae

Please sign in to comment.