diff --git a/src/fashionfail/models/prediction_utils.py b/src/fashionfail/models/prediction_utils.py index f76d9e4..42ae301 100644 --- a/src/fashionfail/models/prediction_utils.py +++ b/src/fashionfail/models/prediction_utils.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd @@ -77,7 +77,7 @@ def load_tpu_preds(path_to_preds: str, preprocess: bool = True) -> pd.DataFrame: def _filter_preds_for_classes( row: pd.Series, class_ids: List[int] -) -> tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]: """ Filter prediction attributes based on class IDs.