Skip to content

Commit

Permalink
Add type sanity checks in __eq__ method
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Mar 17, 2024
1 parent c5d7d66 commit e5bdd96
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 0 deletions.
144 changes: 144 additions & 0 deletions epyt_flow/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def running_mse(y_pred: np.ndarray, y: np.ndarray):
`float`
Running MSE.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")

e_sq = np.square(y - y_pred)
r_mse = list(esq for esq in e_sq)

Expand Down Expand Up @@ -51,6 +64,22 @@ def mape(y_pred: np.ndarray, y: np.ndarray, epsilon: float = .05) -> float:
`float`
MAPE score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if not isinstance(epsilon, float):
raise TypeError("'epsilon' must be an instance of 'float' " +
f"but not of '{type(epsilon)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")

y_ = y + epsilon
y_pred_ = y_pred + epsilon
return np.mean(np.abs((y_ - y_pred_) / y_))
Expand All @@ -76,6 +105,22 @@ def smape(y_pred: np.ndarray, y: np.ndarray, epsilon: float = .05) -> float:
`float`
SMAPE score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if not isinstance(epsilon, float):
raise TypeError("'epsilon' must be an instance of 'float' " +
f"but not of '{type(epsilon)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")

y_ = y + epsilon
y_pred_ = y_pred + epsilon
return 2. * np.mean(np.abs(y_ - y_pred_) / (np.abs(y_) + np.abs(y_pred_)))
Expand All @@ -101,6 +146,22 @@ def mase(y_pred: np.ndarray, y: np.ndarray, epsilon: float = .05) -> float:
`float`
MASE score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if not isinstance(epsilon, float):
raise TypeError("'epsilon' must be an instance of 'float' " +
f"but not of '{type(epsilon)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")

try:
y_ = y + epsilon
y_pred_ = y_pred + epsilon
Expand Down Expand Up @@ -129,6 +190,15 @@ def f1_micro_score(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
F1 score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")

return skelarn_f1_scpre(y, y_pred, average="micro")


Expand All @@ -148,6 +218,15 @@ def roc_auc_score(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
ROC AUC score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")

return skelarn_roc_auc_score(y, y_pred)


Expand All @@ -167,6 +246,21 @@ def true_positive_rate(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
True positive rate.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")
if set(np.unique(y_pred)) != set([0, 1]):
raise ValueError("Labels must be either '0' or '1'")

tp = np.sum((y == 1) & (y_pred == 1))
fn = np.sum((y == 1) & (y_pred == 0))

Expand All @@ -189,6 +283,21 @@ def true_negative_rate(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
True negative rate.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) > 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) > 1:
raise ValueError("'y' must be a 1d array")
if set(np.unique(y_pred)) != set([0, 1]):
raise ValueError("Labels must be either '0' or '1'")

tn = np.sum((y == 0) & (y_pred == 0))
fp = np.sum((y == 0) & (y_pred == 1))

Expand All @@ -211,6 +320,17 @@ def precision_score(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
Precision score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if set(np.unique(y_pred)) != set([0, 1]):
raise ValueError("Labels must be either '0' or '1'")

tp = np.sum([np.all((y[i] == 1) & (y_pred[i] == 1)) for i in range(len(y))])
fp = np.sum([np.any((y[i] == 0) & (y_pred[i] == 1)) for i in range(len(y))])

Expand All @@ -233,6 +353,15 @@ def accuracy_score(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
Accuracy score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")

tp = np.sum([np.all(y[i] == y_pred[i]) for i in range(len(y))])
return tp / len(y)

Expand All @@ -253,6 +382,21 @@ def f1_score(y_pred: np.ndarray, y: np.ndarray) -> float:
`float`
F1-score.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y_pred)}'")
if not isinstance(y, np.ndarray):
raise TypeError("'y' must be an instance of 'numpy.ndarray' " +
f"but not of '{type(y)}'")
if y_pred.shape != y.shape:
raise ValueError(f"Shape mismatch: {y_pred.shape} vs. {y.shape}")
if len(y_pred.shape) != 1:
raise ValueError("'y_pred' must be a 1d array")
if len(y.shape) != 1:
raise ValueError("'y' must be a 1d array")
if set(np.unique(y_pred)) != set([0, 1]):
raise ValueError("Labels must be either '0' or '1'")

tp = np.sum((y == 1) & (y_pred == 1))
fp = np.sum((y == 0) & (y_pred == 1))
fn = np.sum((y == 1) & (y_pred == 0))
Expand Down
3 changes: 3 additions & 0 deletions epyt_flow/simulation/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,7 @@ def __str__(self) -> str:
return f"start_time: {self.__start_time} end_time: {self.__end_time}"

def __eq__(self, other) -> bool:
if not isinstance(other, Event):
raise TypeError(f"Can not compare 'Event' instance with '{type(other)}' instance")

return self.__start_time == other.start_time and self.__end_time == other.end_time
3 changes: 3 additions & 0 deletions epyt_flow/simulation/events/leakages.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def get_attributes(self) -> dict:
if self.link_id is None else None}

def __eq__(self, other) -> bool:
if not isinstance(other, Leakage):
raise TypeError(f"Can not compare 'Leakage' instance with '{type(other)}' instance")

return super().__eq__(other) and self.__link_id == other.link_id \
and self.__diameter == other.diameter and self.__profile == other.profile \
and self.__node_id == other.node_id
Expand Down
8 changes: 8 additions & 0 deletions epyt_flow/simulation/events/sensor_reading_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def get_attributes(self) -> dict:
return super().get_attributes() | {"new_sensor_values": self.__new_sensor_values}

def __eq__(self, other) -> bool:
if not isinstance(other, SensorOverrideAttack):
raise TypeError("Can not compare 'SensorOverrideAttack' instance " +
f"with '{type(other)}' instance")

return super().__eq__(other) and self.__new_sensor_values == other.new_sensor_values

def __str__(self) -> str:
Expand Down Expand Up @@ -155,6 +159,10 @@ def get_attributes(self) -> dict:
self.__sensor_data_time_window_end}

def __eq__(self, other) -> bool:
if not isinstance(other, SensorReplayAttack):
raise TypeError("Can not compare 'SensorReplayAttack' instance " +
f"with '{type(other)}' instance")

return super().__eq__(other) and self.__new_sensor_values == other.new_sensor_values

def __str__(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions epyt_flow/simulation/events/sensor_reading_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def get_attributes(self) -> dict:
return {"sensor_id": self.__sensor_id, "sensor_type": self.__sensor_type}

def __eq__(self, other) -> bool:
if not isinstance(other, SensorReadingEvent):
raise TypeError("Can not compare 'SensorReadingEvent' instance " +
f"with '{type(other)}' instance")

return super().__eq__(other) and self.__sensor_id == other.sensor_id \
and self.__sensor_type == other.sensor_type

Expand Down
3 changes: 3 additions & 0 deletions epyt_flow/simulation/scada/scada_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ def get_attributes(self) -> dict:
"tanks_level_data_raw": self.__tanks_level_data_raw}

def __eq__(self, other) -> bool:
if not isinstance(other, ScadaData):
raise TypeError(f"Can not compare 'ScadaData' instance to '{type(other)}' instance")

try:
return self.__sensor_config == other.sensor_config \
and self.__sensor_noise == other.sensor_noise \
Expand Down
4 changes: 4 additions & 0 deletions epyt_flow/simulation/sensor_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,10 @@ def get_attributes(self) -> dict:
"tank_level_sensors": self.__tank_level_sensors}

def __eq__(self, other) -> bool:
if not isinstance(other, SensorConfig):
raise TypeError("Can not compare 'SensorConfig' instance " +
f"with '{type(other)}' instance")

return self.__nodes == other.nodes and self.__links == other.links \
and self.__valves == other.valves and self.__pumps == other.pumps \
and self.__tanks == other.tanks \
Expand Down
4 changes: 4 additions & 0 deletions epyt_flow/uncertainty/model_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def get_attributes(self) -> dict:
"parameters_uncertainty": self.__parameters}

def __eq__(self, other) -> bool:
if not isinstance(other, ModelUncertainty):
raise TypeError("Can not compare 'ModelUncertainty' instance " +
f"with '{type(other)}' instance")

return self.__pipe_length == other.pipe_length \
and self.__pipe_roughness == other.pipe_roughness \
and self.__pipe_diameter == other.pipe_diameter \
Expand Down

0 comments on commit e5bdd96

Please sign in to comment.