diff --git a/src/zino/planned_maintenance.py b/src/zino/planned_maintenance.py index af6543d90..252402bf1 100644 --- a/src/zino/planned_maintenance.py +++ b/src/zino/planned_maintenance.py @@ -1,10 +1,17 @@ import logging from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Dict, Optional, Protocol +from typing import TYPE_CHECKING, Dict, Optional, Protocol, Union from pydantic.main import BaseModel -from zino.statemodels import Event, EventState, MatchType, PlannedMaintenance +from zino.statemodels import ( + DeviceMaintenance, + Event, + EventState, + MatchType, + PlannedMaintenance, + PortStateMaintenance, +) if TYPE_CHECKING: from zino.state import ZinoState @@ -22,7 +29,7 @@ def __call__(self) -> None: ... class PlannedMaintenances(BaseModel): - planned_maintenances: Dict[int, PlannedMaintenance] = {} + planned_maintenances: Dict[int, Union[DeviceMaintenance, PortStateMaintenance]] = {} last_pm_id: int = 0 last_run: Optional[datetime] = datetime.fromtimestamp(0) _observers: list[PlannedMaintenanceObserver] = [] @@ -40,7 +47,7 @@ def create_planned_maintenance( pm_class: type[PlannedMaintenance], match_type: MatchType, match_expression: str, - match_device: Optional[str], + match_device: Optional[str] = None, ) -> PlannedMaintenance: """Creates a planned maintenance, adds it to the planned_maintenances dict and returns it diff --git a/src/zino/statemodels.py b/src/zino/statemodels.py index 06ab2997c..12070fa19 100644 --- a/src/zino/statemodels.py +++ b/src/zino/statemodels.py @@ -393,7 +393,7 @@ class PlannedMaintenance(BaseModel): end_time: datetime.datetime type: PmType match_type: MatchType - match_device: Optional[str] + match_device: Optional[str] = None match_expression: str log: List[LogEntry] = [] event_ids: List[int] = [] @@ -430,7 +430,7 @@ def _get_or_create_events(self, state: "ZinoState") -> list[Event]: class DeviceMaintenance(PlannedMaintenance): - type: PmType = PmType.DEVICE + type: Literal[PmType.DEVICE] = PmType.DEVICE def matches_event(self, event: Event, state: "ZinoState") -> bool: """Returns true if `event` will be affected by this planned maintenance""" @@ -470,7 +470,7 @@ def _get_or_create_events(self, state: "ZinoState") -> list[Event]: class PortStateMaintenance(PlannedMaintenance): - type: PmType = PmType.PORTSTATE + type: Literal[PmType.PORTSTATE] = PmType.PORTSTATE def matches_event(self, event: Event, state: "ZinoState") -> bool: """Returns true if `event` will be affected by this planned maintenance""" diff --git a/tests/planned_maintenance_test.py b/tests/planned_maintenance_test.py index 03306ae2e..eade9083c 100644 --- a/tests/planned_maintenance_test.py +++ b/tests/planned_maintenance_test.py @@ -148,6 +148,16 @@ def test_event_opened_after_pm_was_initiated_should_be_set_to_ignored(self, stat assert state.events.checkout(event.id).state == EventState.IGNORED +def test_pms_should_be_parsed_as_correct_subclass_when_read_from_file(tmp_path, state, active_portstate_pm, active_pm): + dumpfile = tmp_path / "dump.json" + state.dump_state_to_file(dumpfile) + read_state = ZinoState.load_state_from_file(str(dumpfile)) + read_device_pm = read_state.planned_maintenances[active_pm.id] + read_portstate_pm = read_state.planned_maintenances[active_portstate_pm.id] + assert isinstance(read_device_pm, DeviceMaintenance) + assert isinstance(read_portstate_pm, PortStateMaintenance) + + @pytest.fixture def pms(): return PlannedMaintenances()