diff --git a/dlclive/benchmark.py b/dlclive/benchmark.py index 4cb4fb1..565a0ba 100644 --- a/dlclive/benchmark.py +++ b/dlclive/benchmark.py @@ -12,7 +12,7 @@ import sys import warnings import subprocess -import typing +from typing import List, Optional, Tuple, Union import pickle import colorcet as cc from PIL import ImageColor @@ -148,22 +148,22 @@ def get_system_info() -> dict: def benchmark( - model_path, - video_path, - tf_config=None, - resize=None, - pixels=None, - cropping=None, - dynamic=(False, 0.5, 10), - n_frames=1000, - print_rate=False, - display=False, - pcutoff=0.0, - display_radius=3, - cmap="bmy", - save_poses=False, - save_video=False, - output=None, + model_path: str, + video_path: str, + tf_config: Optional[tf.ConfigProto] = None, + resize: Optional[float] = None, + pixels: Optional[int] = None, + cropping: Optional[List[int]] = None, + dynamic: Tuple[bool, float, int] = (False, 0.5, 10), + n_frames: int = 1000, + print_rate: bool = False, + display: bool = False, + pcutoff: float = 0.0, + display_radius: int = 3, + cmap: str = "bmy", + save_poses: bool = False, + save_video: bool = False, + output: Optional[str] = None, ) -> typing.Tuple[np.ndarray, tuple, bool, dict]: """ Analyze DeepLabCut-live exported model on a video: Calculate inference time, @@ -516,22 +516,22 @@ def save_inf_times( def benchmark_videos( - model_path, - video_path, - output=None, - n_frames=1000, - tf_config=None, - resize=None, - pixels=None, - cropping=None, - dynamic=(False, 0.5, 10), - print_rate=False, - display=False, - pcutoff=0.5, - display_radius=3, - cmap="bmy", - save_poses=False, - save_video=False, + model_path: str, + video_path: Union[str, List[str]], + output: Optional[str] = None, + n_frames: int = 1000, + tf_config: Optional[tf.ConfigProto] = None, + resize: Optional[Union[float, List[float]]] = None, + pixels: Optional[Union[int, List[int]]] = None, + cropping: Optional[List[int]] = None, + dynamic: Tuple[bool, float, int] = (False, 0.5, 10), + print_rate: bool = False, + display: bool = False, + pcutoff: float = 0.5, + display_radius: int = 3, + cmap: str = "bmy", + save_poses: bool = False, + save_video: bool = False, ): """Analyze videos using DeepLabCut-live exported models. Analyze multiple videos and/or multiple options for the size of the video diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 210671e..f0c8e07 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -181,7 +181,7 @@ def parameterization(self) -> dict: """ return {param: getattr(self, param) for param in self.PARAMETERS} - def process_frame(self, frame): + def process_frame(self, frame: np.ndarray) -> np.ndarray: """ Crops an image according to the object's cropping and dynamic properties. @@ -237,7 +237,7 @@ def process_frame(self, frame): return frame - def init_inference(self, frame=None, **kwargs): + def init_inference(self, frame=None, **kwargs) -> np.ndarray: """ Load model and perform inference on first frame -- the first inference is usually very slow. @@ -376,7 +376,7 @@ def init_inference(self, frame=None, **kwargs): return pose - def get_pose(self, frame=None, **kwargs): + def get_pose(self, frame=None, **kwargs) -> np.ndarray: """ Get the pose of an image