From bacef00be322e17d9a2e2db6126aba4b781c7563 Mon Sep 17 00:00:00 2001 From: lilin Date: Fri, 30 Dec 2022 15:10:39 +0800 Subject: [PATCH 01/12] [Enhance] add inferencer, update visualizer, add alias for some configs --- configs/recognition/i3d/metafile.yml | 2 + configs/recognition/slowfast/metafile.yml | 2 + configs/recognition/tsn/metafile.yml | 2 + demo/demo.py | 36 +- mmaction/apis/__init__.py | 1 + mmaction/apis/inferencers/__init__.py | 5 + .../apis/inferencers/actionrec_inferencer.py | 24 ++ .../inferencers/base_mmaction_inferencer.py | 346 ++++++++++++++++++ mmaction/datasets/transforms/loading.py | 57 ++- mmaction/engine/hooks/visualization_hook.py | 2 +- mmaction/visualization/action_visualizer.py | 181 +++++---- mmaction/visualization/video_backend.py | 42 +-- tests/apis/test_inferencer.py | 79 ++++ tests/visualization/test_video_backend.py | 58 +-- tools/visualizations/browse_dataset.py | 14 +- 15 files changed, 677 insertions(+), 174 deletions(-) create mode 100644 mmaction/apis/inferencers/__init__.py create mode 100644 mmaction/apis/inferencers/actionrec_inferencer.py create mode 100644 mmaction/apis/inferencers/base_mmaction_inferencer.py create mode 100644 tests/apis/test_inferencer.py diff --git a/configs/recognition/i3d/metafile.yml b/configs/recognition/i3d/metafile.yml index 63ad017343..f12ba591dc 100644 --- a/configs/recognition/i3d/metafile.yml +++ b/configs/recognition/i3d/metafile.yml @@ -7,6 +7,8 @@ Collections: Models: - Name: i3d_imagenet-pretrained-r50-nl-dot-product_8xb8-32x2x1-100e_kinetics400-rgb + Alias: + - i3d Config: configs/recognition/i3d/i3d_imagenet-pretrained-r50-nl-dot-product_8xb8-32x2x1-100e_kinetics400-rgb.py In Collection: I3D Metadata: diff --git a/configs/recognition/slowfast/metafile.yml b/configs/recognition/slowfast/metafile.yml index 94423659d1..7ba12c0e63 100644 --- a/configs/recognition/slowfast/metafile.yml +++ b/configs/recognition/slowfast/metafile.yml @@ -30,6 +30,8 @@ Models: Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/slowfast/slowfast_r50_8xb8-4x16x1-256e_kinetics400-rgb/slowfast_r50_8xb8-4x16x1-256e_kinetics400-rgb_20220901-701b0f6f.pth - Name: slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb + Alias: + - slowfast Config: configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py In Collection: SlowFast Metadata: diff --git a/configs/recognition/tsn/metafile.yml b/configs/recognition/tsn/metafile.yml index b4734c93a2..e618ed71cc 100644 --- a/configs/recognition/tsn/metafile.yml +++ b/configs/recognition/tsn/metafile.yml @@ -53,6 +53,8 @@ Models: Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x5-100e_kinetics400-rgb/tsn_imagenet-pretrained-r50_8xb32-1x1x5-100e_kinetics400-rgb_20220906-65d68713.pth - Name: tsn_imagenet-pretrained-r50_8xb32-1x1x8-100e_kinetics400-rgb + Alias: + - TSN Config: configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x8-100e_kinetics400-rgb.py In Collection: TSN Metadata: diff --git a/demo/demo.py b/demo/demo.py index 5cebcd3abe..1ac7cfd91f 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -4,7 +4,6 @@ from operator import itemgetter from typing import Optional, Tuple -import cv2 from mmengine import Config, DictAction from mmaction.apis import inference_recognizer, init_recognizer @@ -88,34 +87,9 @@ def get_output( if video_path.startswith(('http://', 'https://')): raise NotImplementedError - try: - import decord - except ImportError: - raise ImportError('Please install decord to enable output file.') - 
- # Channel Order is `BGR` - video = decord.VideoReader(video_path) - frames = [x.asnumpy()[..., ::-1] for x in video] - if target_resolution: - w, h = target_resolution - frame_h, frame_w, _ = frames[0].shape - if w == -1: - w = int(h / frame_h * frame_w) - if h == -1: - h = int(w / frame_w * frame_h) - frames = [cv2.resize(f, (w, h)) for f in frames] - # init visualizer out_type = 'gif' if osp.splitext(out_filename)[1] == '.gif' else 'video' - vis_backends_cfg = [ - dict( - type='LocalVisBackend', - out_type=out_type, - save_dir='demo', - fps=fps) - ] - visualizer = ActionVisualizer( - vis_backends=vis_backends_cfg, save_dir='place_holder') + visualizer = ActionVisualizer() visualizer.dataset_meta = dict(classes=labels) text_cfg = {'colors': font_color} @@ -124,11 +98,15 @@ def get_output( visualizer.add_datasample( out_filename, - frames, + video_path, data_sample, draw_pred=True, draw_gt=False, - text_cfg=text_cfg) + text_cfg=text_cfg, + fps=fps, + out_type=out_type, + out_path=osp.join('demo', out_filename), + target_resolution=target_resolution) def main(): diff --git a/mmaction/apis/__init__.py b/mmaction/apis/__init__.py index 110cbe9464..c4506d5af1 100644 --- a/mmaction/apis/__init__.py +++ b/mmaction/apis/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import (detection_inference, inference_recognizer, init_recognizer, pose_inference) +from .inferencers import * # NOQA __all__ = [ 'init_recognizer', 'inference_recognizer', 'detection_inference', diff --git a/mmaction/apis/inferencers/__init__.py b/mmaction/apis/inferencers/__init__.py new file mode 100644 index 0000000000..ed6bacbe1b --- /dev/null +++ b/mmaction/apis/inferencers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .actionrec_inferencer import ActionRecInferencer +from .base_mmaction_inferencer import BaseMMActionInferencer + +__all__ = ['ActionRecInferencer', 'BaseMMActionInferencer'] diff --git a/mmaction/apis/inferencers/actionrec_inferencer.py b/mmaction/apis/inferencers/actionrec_inferencer.py new file mode 100644 index 0000000000..744b793f3d --- /dev/null +++ b/mmaction/apis/inferencers/actionrec_inferencer.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +from mmaction.structures import ActionDataSample +from .base_mmaction_inferencer import BaseMMActionInferencer + + +class ActionRecInferencer(BaseMMActionInferencer): + + def pred2dict(self, data_sample: ActionDataSample) -> Dict: + """Extract elements necessary to represent a prediction into a + dictionary. It's better to contain only basic data elements such as + strings and numbers in order to guarantee it's json-serializable. + + Args: + data_sample (ActionDataSample): The data sample to be converted. + + Returns: + dict: The output dictionary. + """ + result = {} + result['pred_labels'] = data_sample.pred_labels.item + result['pred_scores'] = data_sample.pred_scores.item.tolist() + return result diff --git a/mmaction/apis/inferencers/base_mmaction_inferencer.py b/mmaction/apis/inferencers/base_mmaction_inferencer.py new file mode 100644 index 0000000000..c642c0de18 --- /dev/null +++ b/mmaction/apis/inferencers/base_mmaction_inferencer.py @@ -0,0 +1,346 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmengine +import numpy as np +from mmengine.dataset import Compose +from mmengine.fileio import list_from_file +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.structures import InstanceData + +from mmaction.utils import ConfigType, register_all_modules + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +class BaseMMActionInferencer(BaseInferencer): + """Base inferencer. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. For example, it could be + "slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb" or + "configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py". + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + label (str): label file for dataset. + input_format (str): Input video format, Choices are 'video', + 'rawframes', 'array'. 'video' means input data is a video file, + 'rawframes' means input data is a video frame folder, and 'array' + means input data is a np.ndarray. Defaults to 'video'. + pack_cfg (dict, optional): Config for `InferencerPackInput` to load + input. Defaults to empty dict. + """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', 'show', 'wait_time', 'vid_out_dir', 'draw_pred', 'fps', + 'out_type', 'target_resolution' + } + postprocess_kwargs: set = { + 'print_result', 'pred_out_file', 'return_datasample' + } + + def __init__(self, + model: Union[ModelType, str], + weights: Optional[str] = None, + device: Optional[str] = None, + scope: Optional[str] = 'mmaction2', + label: Optional[str] = None, + input_format: str = 'video', + pack_cfg: dict = {}) -> None: + # A global counter tracking the number of videos processed, for + # naming of the output videos + self.num_visualized_vids = 0 + self.input_format = input_format + self.pack_cfg = pack_cfg.copy() + register_all_modules() + super().__init__( + model=model, weights=weights, device=device, scope=scope) + + if label is not None: + self.visualizer.dataset_meta = dict(classes=list_from_file(label)) + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + draw_pred: bool = True, + vid_out_dir: str = '', + out_type: str = 'video', + print_result: bool = False, + pred_out_file: str = '', + target_resolution: Optional[Tuple[int]] = None, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Inference batch size. Defaults to 1. + show (bool): Whether to display the visualization results in a + popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. 
+ vid_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + out_type (str): Output type of visualization results. + Defaults to 'video'. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + pred_out_file: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + return super().__call__( + inputs, + return_datasamples, + batch_size, + return_vis=return_vis, + show=show, + wait_time=wait_time, + draw_pred=draw_pred, + vid_out_dir=vid_out_dir, + print_result=print_result, + pred_out_file=pred_out_file, + out_type=out_type, + target_resolution=target_resolution, + **kwargs) + + def _inputs_to_list(self, inputs: InputsType) -> list: + """Preprocess the inputs to a list. The main difference from mmengine + version is that we don't list a directory cause input could be a frame + folder. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + + Args: + inputs (InputsType): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. + """ + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline.""" + test_pipeline = cfg.test_dataloader.dataset.pipeline + # Alter data pipelines for decode + if self.input_format == 'array': + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='ArrayDecode') + test_pipeline = [ + x for x in test_pipeline if 'Init' not in x['type'] + ] + elif self.input_format == 'video': + if 'Init' not in test_pipeline[0]['type']: + test_pipeline = [dict(type='DecordInit')] + test_pipeline + else: + test_pipeline[0] = dict(type='DecordInit') + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='DecordDecode') + elif self.input_format == 'rawframes': + if 'Init' in test_pipeline[0]['type']: + test_pipeline = test_pipeline[1:] + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='RawFrameDecode') + # Alter data pipelines to close TTA, avoid OOM + # Use center crop instead of multiple crop + for i in range(len(test_pipeline)): + if test_pipeline[i]['type'] in ['ThreeCrop', 'TenCrop']: + test_pipeline[i]['type'] = 'CenterCrop' + # Use single clip for `Recognizer3D` + if cfg.model.type == 'Recognizer3D': + for i in range(len(test_pipeline)): + if test_pipeline[i]['type'] == 'SampleFrames': + test_pipeline[i]['num_clips'] = 1 + # Pack multiple types of input format + test_pipeline.insert( + 0, + dict( + type='InferencerPackInput', + input_format=self.input_format, + **self.pack_cfg)) + + return Compose(test_pipeline) + + def visualize( + self, + inputs: InputsType, + preds: PredType, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + 
draw_pred: bool = True, + fps: int = 30, + out_type: str = 'video', + target_resolution: Optional[Tuple[int]] = None, + vid_out_dir: str = '', + ) -> Union[List[np.ndarray], None]: + """Visualize predictions. + + Args: + inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer. + preds (List[Dict]): Predictions of the model. + return_vis (bool): Whether to return the visualization result. + Defaults to False. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + fps (int): Frames per second for saving video. Defaults to 4. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. + target_resolution (Tuple[int], optional): Set to + (desired_width desired_height) to have resized frames. If + either dimension is None, the frames are resized by keeping + the existing aspect ratio. Defaults to None. + vid_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + + Returns: + List[np.ndarray] or None: Returns visualization results only if + applicable. + """ + if self.visualizer is None or (not show and vid_out_dir == '' + and not return_vis): + return None + + if getattr(self, 'visualizer') is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + results = [] + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + frames = single_input + video_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + frames = single_input.copy() + video_num = str(self.num_visualized_vids).zfill(8) + video_name = f'{video_num}.mp4' + else: + raise ValueError('Unsupported input type: ' + f'{type(single_input)}') + + out_path = osp.join(vid_out_dir, video_name) if vid_out_dir != '' \ + else None + + self.visualizer.add_datasample( + video_name, + frames, + pred, + show_frames=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=draw_pred, + fps=fps, + out_type=out_type, + out_path=out_path, + target_resolution=target_resolution, + ) + results.append(frames) + self.num_visualized_vids += 1 + + return results + + def postprocess( + self, + preds: PredType, + visualization: Optional[List[np.ndarray]] = None, + return_datasample: bool = False, + print_result: bool = False, + pred_out_file: str = '', + ) -> Union[ResType, Tuple[ResType, np.ndarray]]: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (Optional[np.ndarray]): Visualized predictions. + return_datasample (bool): Whether to use Datasample to store + inference results. If False, dict will be used. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + pred_out_file: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization``. + + - ``visualization`` (Any): Returned by :meth:`visualize`. 
+ - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + result_dict = {} + results = preds + if not return_datasample: + results = [] + for pred in preds: + result = self.pred2dict(pred) + results.append(result) + # Add video to the results after printing and dumping + result_dict['predictions'] = results + if print_result: + print(result_dict) + if pred_out_file != '': + mmengine.dump(result_dict, pred_out_file) + result_dict['visualization'] = visualization + return result_dict + + def pred2dict(self, data_sample: InstanceData) -> Dict: + """Extract elements necessary to represent a prediction into a + dictionary. + + It's better to contain only basic data elements such as strings and + numbers in order to guarantee it's json-serializable. + """ + raise NotImplementedError diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py index 558579b87f..7bd3798c9c 100644 --- a/mmaction/datasets/transforms/loading.py +++ b/mmaction/datasets/transforms/loading.py @@ -4,7 +4,7 @@ import os import os.path as osp import shutil -from typing import Optional +from typing import Optional, Union import mmcv import numpy as np @@ -1398,6 +1398,61 @@ def __repr__(self): return repr_str +@TRANSFORMS.register_module() +class InferencerPackInput(BaseTransform): + + def __init__(self, + input_format='video', + filename_tmpl='img_{:05}.jpg', + modality='RGB', + start_index=1) -> None: + self.input_format = input_format + self.filename_tmpl = filename_tmpl + self.modality = modality + self.start_index = start_index + + def transform(self, video: Union[str, np.ndarray, dict]) -> dict: + if self.input_format == 'dict': + results = video + elif self.input_format == 'video': + results = dict( + filename=video, label=-1, start_index=0, modality='RGB') + elif self.input_format == 'rawframes': + import re + + # count the number of frames that match the format of + # `filename_tmpl` + # RGB pattern example: img_{:05}.jpg -> ^img_\d+.jpg$ + # Flow patteren example: {}_{:05d}.jpg -> ^x_\d+.jpg$ + pattern = f'^{self.filename_tmpl}$' + if self.modality == 'Flow': + pattern = pattern.replace('{}', 'x') + pattern = pattern.replace( + pattern[pattern.find('{'):pattern.find('}') + 1], '\\d+') + total_frames = len( + list( + filter(lambda x: re.match(pattern, x) is not None, + os.listdir(video)))) + results = dict( + frame_dir=video, + total_frames=total_frames, + label=-1, + start_index=self.start_index, + filename_tmpl=self.filename_tmpl, + modality=self.modality) + elif self.input_format == 'array': + modality_map = {2: 'Flow', 3: 'RGB'} + modality = modality_map.get(video.shape[-1]) + results = dict( + total_frames=video.shape[0], + label=-1, + start_index=0, + array=video, + modality=modality) + + return results + + @TRANSFORMS.register_module() class ArrayDecode(BaseTransform): """Load and decode frames with given indices from a 4D array. 
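For reference, a rough sketch of what the new InferencerPackInput transform hands to the rest of the pipeline (a sketch only; the frame-folder path below is made up, and the result keys mirror the transform added above):

    from mmaction.datasets.transforms.loading import InferencerPackInput

    # 'video' input: the transform only wraps the file path into a result dict.
    pack = InferencerPackInput(input_format='video')
    results = pack.transform('demo/demo.mp4')
    # {'filename': 'demo/demo.mp4', 'label': -1, 'start_index': 0, 'modality': 'RGB'}

    # 'rawframes' input: frames matching `filename_tmpl` are counted with the
    # regex built above, e.g. 'img_{:05}.jpg' -> '^img_\d+.jpg$'.
    pack = InferencerPackInput(
        input_format='rawframes', filename_tmpl='img_{:05}.jpg')
    results = pack.transform('path/to/frame_folder')  # hypothetical folder
    # {'frame_dir': 'path/to/frame_folder', 'total_frames': <count>, 'label': -1,
    #  'start_index': 1, 'filename_tmpl': 'img_{:05}.jpg', 'modality': 'RGB'}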
diff --git a/mmaction/engine/hooks/visualization_hook.py b/mmaction/engine/hooks/visualization_hook.py index e4756ca817..b1c3ac8b47 100644 --- a/mmaction/engine/hooks/visualization_hook.py +++ b/mmaction/engine/hooks/visualization_hook.py @@ -91,7 +91,7 @@ def _draw_samples(self, draw_args = self.draw_args if self.out_dir is not None: - draw_args['out_folder'] = self.file_client.join_path( + draw_args['out_path'] = self.file_client.join_path( self.out_dir, f'{sample_name}_{step}') self._visualizer.add_datasample( diff --git a/mmaction/visualization/action_visualizer.py b/mmaction/visualization/action_visualizer.py index fba9d6c600..8d8781ed9f 100644 --- a/mmaction/visualization/action_visualizer.py +++ b/mmaction/visualization/action_visualizer.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import os.path as osp -import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -import matplotlib.pyplot as plt import mmcv import numpy as np from mmengine.dist import master_only +from mmengine.fileio.io import isdir, isfile, join_path, list_dir_or_file from mmengine.visualization import Visualizer from mmaction.registry import VISBACKENDS, VISUALIZERS @@ -45,11 +43,6 @@ class ActionVisualizer(Visualizer): Args: name (str): Name of the instance. Defaults to 'visualizer'. - video (Union[np.ndarray, Sequence[np.ndarray]]): - the origin video to draw. The format should be RGB. - For np.ndarray input, the video shape should be (N, H, W, C). - For Sequence[np.ndarray] input, the shape of each frame in - the sequence should be (H, W, C). vis_backends (list, optional): Visual backend config list. Defaults to None. save_dir (str, optional): Save file dir for all storage backends. @@ -89,65 +82,65 @@ class ActionVisualizer(Visualizer): def __init__( self, name='visualizer', - video: Optional[np.ndarray] = None, vis_backends: Optional[List[Dict]] = None, save_dir: Optional[str] = None, fig_save_cfg=dict(frameon=False), - fig_show_cfg=dict(frameon=False, num='show') + fig_show_cfg=dict(frameon=False) ) -> None: - self._dataset_meta = None - self._vis_backends = dict() - - if save_dir is None: - warnings.warn('`Visualizer` backend is not initialized ' - 'because save_dir is None.') - elif vis_backends is not None: - assert len(vis_backends) > 0, 'empty list' - names = [ - vis_backend.get('name', None) for vis_backend in vis_backends - ] - if None in names: - if len(set(names)) > 1: - raise RuntimeError( - 'If one of them has a name attribute, ' - 'all backends must use the name attribute') - else: - type_names = [ - vis_backend['type'] for vis_backend in vis_backends - ] - if len(set(type_names)) != len(type_names): - raise RuntimeError( - 'The same vis backend cannot exist in ' - '`vis_backend` config. ' - 'Please specify the name field.') - - if None not in names and len(set(names)) != len(names): - raise RuntimeError('The name fields cannot be the same') + super().__init__( + name=name, + image=None, + vis_backends=vis_backends, + save_dir=save_dir, + fig_save_cfg=fig_save_cfg, + fig_show_cfg=fig_show_cfg) + + def _load_video(self, + video: Union[np.ndarray, Sequence[np.ndarray], str], + target_resolution: Optional[Tuple[int]] = None): + """Load video from multiple source and convert to target resolution. 
- save_dir = osp.join(save_dir, 'vis_data') - - for vis_backend in vis_backends: - name = vis_backend.pop('name', vis_backend['type']) - vis_backend.setdefault('save_dir', save_dir) - self._vis_backends[name] = VISBACKENDS.build(vis_backend) - - self.is_inline = 'inline' in plt.get_backend() + Args: + video (np.ndarray, str): The video to draw. + target_resolution (Tuple[int], optional): Set to + (desired_width desired_height) to have resized frames. If + either dimension is None, the frames are resized by keeping + the existing aspect ratio. Defaults to None. + """ + if isinstance(video, np.ndarray) or isinstance(video, list): + frames = video + elif isinstance(video, str): + # video file path + if isfile(video): + try: + import decord + except ImportError: + raise ImportError( + 'Please install decord to load video file.') + video = decord.VideoReader(video) + frames = [x.asnumpy()[..., ::-1] for x in video] + # rawframes folder path + elif isdir(video): + frame_list = sorted(list_dir_or_file(video, list_dir=False)) + frames = [mmcv.imread(join_path(video, x)) for x in frame_list] + else: + raise TypeError(f'type of video {type(video)} not supported') - self.fig_save = None - self.fig_show = None - self.fig_save_num = fig_save_cfg.get('num', None) - self.fig_show_num = fig_show_cfg.get('num', None) - self.fig_save_cfg = fig_save_cfg - self.fig_show_cfg = fig_show_cfg + if target_resolution is not None: + w, h = target_resolution + frame_h, frame_w, _ = frames[0].shape + if w == -1: + w = int(h / frame_h * frame_w) + if h == -1: + h = int(w / frame_w * frame_h) + frames = [mmcv.imresize(f, (w, h)) for f in frames] - (self.fig_save_canvas, self.fig_save, - self.ax_save) = self._initialize_fig(fig_save_cfg) - self.dpi = self.fig_save.get_dpi() + return frames @master_only def add_datasample(self, name: str, - video: Union[np.ndarray, Sequence[np.ndarray]], + video: Union[np.ndarray, Sequence[np.ndarray], str], data_sample: Optional[ActionDataSample] = None, draw_gt: bool = True, draw_pred: bool = True, @@ -156,18 +149,22 @@ def add_datasample(self, show_frames: bool = False, text_cfg: dict = dict(), wait_time: float = 0.1, - out_folder: Optional[str] = None, - step: int = 0) -> None: + out_path: Optional[str] = None, + out_type: str = 'img', + target_resolution: Optional[Tuple[int]] = None, + step: int = 0, + fps: int = 4) -> None: """Draw datasample and save to all backends. - - If ``out_folder`` is specified, all storage backends are ignored - and save the videos to the ``out_folder``. + - If ``out_path`` is specified, all storage backends are ignored + and save the videos to the ``out_path``. - If ``show_frames`` is True, plot the frames in a window sequentially, please confirm you are able to access the graphical interface. Args: name (str): The frame identifier. - video (np.ndarray): The video to draw. + video (np.ndarray, str): The video to draw. supports decoded + np.ndarray, video file path, rawframes folder path. data_sample (:obj:`ActionDataSample`, optional): The annotation of the frame. Defaults to None. draw_gt (bool): Whether to draw ground truth labels. @@ -185,14 +182,21 @@ def add_datasample(self, Defaults to an empty dict. wait_time (float): Delay in seconds. 0 is the special value that means "forever". Defaults to 0.1. - out_folder (str, optional): Extra folder to save the visualization + out_path (str, optional): Extra folder to save the visualization result. If specified, the visualizer will only save the result - frame to the out_folder and ignore its storage backends. 
+ frame to the out_path and ignore its storage backends. Defaults to None. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. + target_resolution (Tuple[int], optional): Set to + (desired_width desired_height) to have resized frames. If + either dimension is None, the frames are resized by keeping + the existing aspect ratio. Defaults to None. step (int): Global step value to record. Defaults to 0. + fps (int): Frames per second for saving video. Defaults to 4. """ classes = None - wait_time_in_milliseconds = wait_time * 10**6 + video = self._load_video(video, target_resolution) tol_video = len(video) if self.dataset_meta is not None: @@ -256,24 +260,40 @@ def add_datasample(self, drawn_img = self.get_image() resulted_video.append(drawn_img) - if show_frames: - self.show( - drawn_img, - win_name=frame_name, - wait_time=wait_time_in_milliseconds) + if show_frames: + frame_wait_time = 1. / fps + for frame_idx, drawn_img in enumerate(resulted_video): + frame_name = 'frame %d of %s' % (frame_idx + 1, name) + if frame_idx < len(resulted_video) - 1: + wait_time = frame_wait_time + else: + wait_time = wait_time + self.show(drawn_img, win_name=frame_name, wait_time=wait_time) resulted_video = np.array(resulted_video) - if out_folder is not None: - resulted_video = resulted_video[..., ::-1] - os.makedirs(out_folder, exist_ok=True) - # save the frame to the target file instead of vis_backends - for frame_idx, frame in enumerate(resulted_video): - mmcv.imwrite(frame, out_folder + '/%d.png' % frame_idx) + if out_path is not None: + save_dir, save_name = osp.split(out_path) + vis_backend_cfg = dict(type='LocalVisBackend', save_dir=save_dir) + tmp_local_vis_backend = VISBACKENDS.build(vis_backend_cfg) + tmp_local_vis_backend.add_video( + save_name, + resulted_video, + step=step, + fps=fps, + out_type=out_type) else: - self.add_video(name, resulted_video, step=step) + self.add_video( + name, resulted_video, step=step, fps=fps, out_type=out_type) @master_only - def add_video(self, name: str, image: np.ndarray, step: int = 0) -> None: + def add_video( + self, + name: str, + image: np.ndarray, + step: int = 0, + fps: int = 4, + out_type: str = 'img', + ) -> None: """Record the image. Args: @@ -281,6 +301,11 @@ def add_video(self, name: str, image: np.ndarray, step: int = 0) -> None: image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. step (int): Global step value to record. Default to 0. + fps (int): Frames per second for saving video. Defaults to 4. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. """ for vis_backend in self._vis_backends.values(): - vis_backend.add_video(name, image, step) # type: ignore + vis_backend.add_video( + name, image, step=step, fps=fps, + out_type=out_type) # type: ignore diff --git a/mmaction/visualization/video_backend.py b/mmaction/visualization/video_backend.py index 9b6366650e..c32ee8988e 100644 --- a/mmaction/visualization/video_backend.py +++ b/mmaction/visualization/video_backend.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp +from typing import Optional import cv2 import numpy as np @@ -21,41 +22,15 @@ class LocalVisBackend(LocalVisBackend): """Local visualization backend class with video support. See mmengine.visualization.LocalVisBackend for more details. - - Args: - save_dir (str, optional): The root directory to save the files - produced by the visualizer. 
If it is none, it means no data - is stored. - img_save_dir (str): The directory to save images. - Defaults to ``'vis_image'``. - config_save_file (str): The file name to save config. - Defaults to ``'config.py'``. - scalar_save_file (str): The file name to save scalar values. - Defaults to ``'scalars.json'``. - out_type (str): Output format type, choose from 'img', 'gif', - 'video'. Defaults to ``'img'``. - fps (int): Frames per second for saving video. Defaults to 5. """ - def __init__( - self, - save_dir: str, - img_save_dir: str = 'vis_image', - config_save_file: str = 'config.py', - scalar_save_file: str = 'scalars.json', - out_type: str = 'img', - fps: int = 5, - ): - super().__init__(save_dir, img_save_dir, config_save_file, - scalar_save_file) - self.out_type = out_type - self.fps = fps - @force_init_env def add_video(self, name: str, frames: np.ndarray, step: int = 0, + fps: Optional[int] = 4, + out_type: Optional[int] = 'img', **kwargs) -> None: """Record the frames of a video to disk. @@ -64,10 +39,13 @@ def add_video(self, frames (np.ndarray): The frames to be saved. The format should be RGB. The shape should be (T, H, W, C). step (int): Global step value to record. Defaults to 0. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. + fps (int): Frames per second for saving video. Defaults to 4. """ assert frames.dtype == np.uint8 - if self.out_type == 'img': + if out_type == 'img': frames_dir = osp.join(self._save_dir, name, f'frames_{step}') os.makedirs(frames_dir, exist_ok=True) for idx, frame in enumerate(frames): @@ -82,12 +60,12 @@ def add_video(self, 'output file.') frames = [x[..., ::-1] for x in frames] - video_clips = ImageSequenceClip(frames, fps=self.fps) + video_clips = ImageSequenceClip(frames, fps=fps) name = osp.splitext(name)[0] - if self.out_type == 'gif': + if out_type == 'gif': out_path = osp.join(self._save_dir, name + '.gif') video_clips.write_gif(out_path, logger=None) - elif self.out_type == 'video': + elif out_type == 'video': out_path = osp.join(self._save_dir, name + '.mp4') video_clips.write_videofile( out_path, remove_temp=True, logger=None) diff --git a/tests/apis/test_inferencer.py b/tests/apis/test_inferencer.py new file mode 100644 index 0000000000..a16e941c8d --- /dev/null +++ b/tests/apis/test_inferencer.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +from tempfile import TemporaryDirectory +from unittest import TestCase + +import torch +from parameterized import parameterized + +from mmaction.apis import ActionRecInferencer +from mmaction.structures import ActionDataSample +from mmaction.utils import register_all_modules + + +class TestInferencer(TestCase): + + def setUp(self): + register_all_modules() + + @parameterized.expand([ + (('tsn'), ('tools/data/kinetics/label_map_k400.txt'), ('cpu', 'cuda')) + ]) + def test_init_recognizer(self, config, lable_file, devices): + + for device in devices: + if device == 'cuda' and not torch.cuda.is_available(): + # Skip the test if cuda is required but unavailable + continue + + _ = ActionRecInferencer(config, label=lable_file, device=device) + + # test `init_recognizer` with invalid config + with self.assertRaisesRegex(ValueError, 'Cannot find model'): + _ = ActionRecInferencer( + 'slowfast_config', label=lable_file, device=device) + + @parameterized.expand([ + (('tsn'), ('tools/data/kinetics/label_map_k400.txt'), + ('demo/demo.mp4'), ('cpu', 'cuda')) + ]) + def test_inference_recognizer(self, config, label_file, video_path, + devices): + + with TemporaryDirectory() as tmp_dir: + for device in devices: + if device == 'cuda' and not torch.cuda.is_available(): + # Skip the test if cuda is required but unavailable + continue + + # test video file input and return datasample + inferencer = ActionRecInferencer( + config, label=label_file, device=device) + results = inferencer( + video_path, vid_out_dir=tmp_dir, return_datasamples=True) + self.assertIn('predictions', results) + self.assertIn('visualization', results) + self.assertIsInstance(results['predictions'][0], + ActionDataSample) + assert osp.exists(osp.join(tmp_dir, osp.basename(video_path))) + + results = inferencer( + video_path, vid_out_dir=tmp_dir, out_type='gif') + self.assertIsInstance(results['predictions'][0], dict) + assert osp.exists( + osp.join(tmp_dir, + osp.basename(video_path).replace('mp4', 'gif'))) + + # test np.ndarray input + inferencer = ActionRecInferencer( + config, + label=label_file, + device=device, + input_format='array') + import decord + import numpy as np + video = decord.VideoReader(video_path) + frames = [x.asnumpy()[..., ::-1] for x in video] + frames = np.stack(frames) + inferencer(frames, vid_out_dir=tmp_dir) + assert osp.exists(osp.join(tmp_dir, '00000000.mp4')) diff --git a/tests/visualization/test_video_backend.py b/tests/visualization/test_video_backend.py index 5f75377b83..0de82465ee 100644 --- a/tests/visualization/test_video_backend.py +++ b/tests/visualization/test_video_backend.py @@ -1,14 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os +import os.path as osp +import time from pathlib import Path +from tempfile import TemporaryDirectory import decord import torch from mmengine.structures import LabelData from mmaction.structures import ActionDataSample +from mmaction.utils import register_all_modules from mmaction.visualization import ActionVisualizer +register_all_modules() + def test_local_visbackend(): video = decord.VideoReader('./demo/demo.mp4') @@ -16,18 +22,18 @@ def test_local_visbackend(): data_sample = ActionDataSample() data_sample.gt_labels = LabelData(item=torch.tensor([2])) - - vis = ActionVisualizer( - save_dir='./outputs', vis_backends=[dict(type='LocalVisBackend')]) - vis.add_datasample('demo', video, data_sample) - for k in range(32): - frame_path = 'outputs/vis_data/demo/frames_0/%d.png' % k - assert Path(frame_path).exists() - - vis.add_datasample('demo', video, data_sample, step=1) - for k in range(32): - frame_path = 'outputs/vis_data/demo/frames_1/%d.png' % k - assert Path(frame_path).exists() + with TemporaryDirectory() as tmp_dir: + vis = ActionVisualizer( + save_dir=tmp_dir, vis_backends=[dict(type='LocalVisBackend')]) + vis.add_datasample('demo', video, data_sample) + for k in range(32): + frame_path = osp.join(tmp_dir, 'vis_data/demo/frames_0/%d.png' % k) + assert Path(frame_path).exists() + + vis.add_datasample('demo', video, data_sample, step=1) + for k in range(32): + frame_path = osp.join(tmp_dir, 'vis_data/demo/frames_1/%d.png' % k) + assert Path(frame_path).exists() return @@ -37,19 +43,21 @@ def test_tensorboard_visbackend(): data_sample = ActionDataSample() data_sample.gt_labels = LabelData(item=torch.tensor([2])) - - vis = ActionVisualizer( - save_dir='./outputs', - vis_backends=[dict(type='TensorboardVisBackend')]) - vis.add_datasample('demo', video, data_sample, step=1) - - assert Path('outputs/vis_data/').exists() - flag = False - for item in os.listdir('outputs/vis_data/'): - if item.startswith('events.out.tfevents.'): - flag = True - break - assert flag, 'Cannot find tensorboard file!' + with TemporaryDirectory() as tmp_dir: + vis = ActionVisualizer( + save_dir=tmp_dir, + vis_backends=[dict(type='TensorboardVisBackend')]) + vis.add_datasample('demo', video, data_sample, step=1) + + assert Path(osp.join(tmp_dir, 'vis_data')).exists() + flag = False + for item in os.listdir(osp.join(tmp_dir, 'vis_data')): + if item.startswith('events.out.tfevents.'): + flag = True + break + assert flag, 'Cannot find tensorboard file!' 
+ # wait tensorboard store asynchronously + time.sleep(1) return diff --git a/tools/visualizations/browse_dataset.py b/tools/visualizations/browse_dataset.py index 5247db19c2..ab9b94a798 100644 --- a/tools/visualizations/browse_dataset.py +++ b/tools/visualizations/browse_dataset.py @@ -190,13 +190,10 @@ def main(): intermediate_imgs) # init visualizer - vis_backends = [ - dict( - type='LocalVisBackend', - out_type='video', - save_dir=args.output_dir, - fps=args.fps) - ] + vis_backends = [dict( + type='LocalVisBackend', + save_dir=args.output_dir, + )] visualizer = ActionVisualizer( vis_backends=vis_backends, save_dir='place_holder') @@ -233,7 +230,8 @@ def main(): file_id = f'video_{i}' video = [x[..., ::-1] for x in video] - visualizer.add_datasample(file_id, video, data_sample) + visualizer.add_datasample( + file_id, video, data_sample, fps=args.fps, out_type='video') progress_bar.update() From b057ac853c46c892628885895a573ac5baca22ce Mon Sep 17 00:00:00 2001 From: lilin Date: Wed, 11 Jan 2023 20:17:34 +0800 Subject: [PATCH 02/12] add inferencer demo --- demo/demo_inferencer.py | 68 +++++++++++++++++++ mmaction/apis/inferencers/__init__.py | 2 +- ...nferencer.py => actionrecog_inferencer.py} | 4 +- .../inferencers/base_mmaction_inferencer.py | 2 +- 4 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 demo/demo_inferencer.py rename mmaction/apis/inferencers/{actionrec_inferencer.py => actionrecog_inferencer.py} (86%) diff --git a/demo/demo_inferencer.py b/demo/demo_inferencer.py new file mode 100644 index 0000000000..6cc99a7527 --- /dev/null +++ b/demo/demo_inferencer.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmaction.apis.inferencers import ActionRecInferencer + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + 'inputs', type=str, help='Input image file or folder path.') + parser.add_argument( + '--vid-out-dir', + type=str, + default='', + help='Output directory of videos.') + parser.add_argument( + '--rec', + type=str, + default=None, + help='Pretrained text recognition algorithm. It\'s the path to the ' + 'config file or the model name defined in metafile.') + parser.add_argument( + '--rec-weights', + type=str, + default=None, + help='Path to the custom checkpoint file of the selected recog model. ' + 'If it is not specified and "rec" is a model name of metafile, the ' + 'weights will be loaded from metafile.') + parser.add_argument( + '--device', + type=str, + default=None, + help='Device used for inference. 
' + 'If not specified, the available device will be automatically used.') + parser.add_argument( + '--batch-size', type=int, default=1, help='Inference batch size.') + parser.add_argument( + '--show', + action='store_true', + help='Display the image in a popup window.') + parser.add_argument( + '--print-result', + action='store_true', + help='Whether to print the results.') + parser.add_argument( + '--pred-out-file', + type=str, + default='', + help='File to save the inference results.') + + call_args = vars(parser.parse_args()) + + init_kws = ['rec', 'rec_weights', 'device'] + init_args = {} + for init_kw in init_kws: + init_args[init_kw] = call_args.pop(init_kw) + + return init_args, call_args + + +def main(): + init_args, call_args = parse_args() + ocr = ActionRecInferencer(**init_args) + ocr(**call_args) + + +if __name__ == '__main__': + main() diff --git a/mmaction/apis/inferencers/__init__.py b/mmaction/apis/inferencers/__init__.py index ed6bacbe1b..8cfa627598 100644 --- a/mmaction/apis/inferencers/__init__.py +++ b/mmaction/apis/inferencers/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .actionrec_inferencer import ActionRecInferencer +from .actionrecog_inferencer import ActionRecInferencer from .base_mmaction_inferencer import BaseMMActionInferencer __all__ = ['ActionRecInferencer', 'BaseMMActionInferencer'] diff --git a/mmaction/apis/inferencers/actionrec_inferencer.py b/mmaction/apis/inferencers/actionrecog_inferencer.py similarity index 86% rename from mmaction/apis/inferencers/actionrec_inferencer.py rename to mmaction/apis/inferencers/actionrecog_inferencer.py index 744b793f3d..084095266e 100644 --- a/mmaction/apis/inferencers/actionrec_inferencer.py +++ b/mmaction/apis/inferencers/actionrecog_inferencer.py @@ -2,10 +2,10 @@ from typing import Dict from mmaction.structures import ActionDataSample -from .base_mmaction_inferencer import BaseMMActionInferencer +from .base_mmaction_inferencer import BaseMMAction2Inferencer -class ActionRecInferencer(BaseMMActionInferencer): +class ActionRecInferencer(BaseMMAction2Inferencer): def pred2dict(self, data_sample: ActionDataSample) -> Dict: """Extract elements necessary to represent a prediction into a diff --git a/mmaction/apis/inferencers/base_mmaction_inferencer.py b/mmaction/apis/inferencers/base_mmaction_inferencer.py index c642c0de18..691c9430b4 100644 --- a/mmaction/apis/inferencers/base_mmaction_inferencer.py +++ b/mmaction/apis/inferencers/base_mmaction_inferencer.py @@ -19,7 +19,7 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] -class BaseMMActionInferencer(BaseInferencer): +class BaseMMAction2Inferencer(BaseInferencer): """Base inferencer. 
Args: From 09cf59a37563b09b78f42d0c68abab65bc2ede6f Mon Sep 17 00:00:00 2001 From: lilin Date: Thu, 12 Jan 2023 21:59:56 +0800 Subject: [PATCH 03/12] [Enhance] add mmaction_inferencer and demo_inferencer --- demo/demo_inferencer.py | 14 +- mmaction/apis/inferencers/__init__.py | 9 +- .../inferencers/actionrecog_inferencer.py | 4 +- .../inferencers/base_mmaction_inferencer.py | 17 +- .../apis/inferencers/mmaction_inferencer.py | 187 ++++++++++++++++++ mmaction/visualization/action_visualizer.py | 1 + tests/apis/test_inferencer.py | 10 +- 7 files changed, 218 insertions(+), 24 deletions(-) create mode 100644 mmaction/apis/inferencers/mmaction_inferencer.py diff --git a/demo/demo_inferencer.py b/demo/demo_inferencer.py index 6cc99a7527..508baca511 100644 --- a/demo/demo_inferencer.py +++ b/demo/demo_inferencer.py @@ -1,13 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from argparse import ArgumentParser -from mmaction.apis.inferencers import ActionRecInferencer +from mmaction.apis.inferencers import MMAction2Inferencer def parse_args(): parser = ArgumentParser() parser.add_argument( - 'inputs', type=str, help='Input image file or folder path.') + 'inputs', type=str, help='Input video file or rawframes folder path.') parser.add_argument( '--vid-out-dir', type=str, @@ -17,7 +17,7 @@ def parse_args(): '--rec', type=str, default=None, - help='Pretrained text recognition algorithm. It\'s the path to the ' + help='Pretrained action recognition algorithm. It\'s the path to the ' 'config file or the model name defined in metafile.') parser.add_argument( '--rec-weights', @@ -26,6 +26,8 @@ def parse_args(): help='Path to the custom checkpoint file of the selected recog model. ' 'If it is not specified and "rec" is a model name of metafile, the ' 'weights will be loaded from metafile.') + parser.add_argument( + '--label-file', type=str, default=None, help='label file for dataset.') parser.add_argument( '--device', type=str, @@ -50,7 +52,7 @@ def parse_args(): call_args = vars(parser.parse_args()) - init_kws = ['rec', 'rec_weights', 'device'] + init_kws = ['rec', 'rec_weights', 'device', 'label_file'] init_args = {} for init_kw in init_kws: init_args[init_kw] = call_args.pop(init_kw) @@ -60,8 +62,8 @@ def parse_args(): def main(): init_args, call_args = parse_args() - ocr = ActionRecInferencer(**init_args) - ocr(**call_args) + mmaction2 = MMAction2Inferencer(**init_args) + mmaction2(**call_args) if __name__ == '__main__': diff --git a/mmaction/apis/inferencers/__init__.py b/mmaction/apis/inferencers/__init__.py index 8cfa627598..06677bdddc 100644 --- a/mmaction/apis/inferencers/__init__.py +++ b/mmaction/apis/inferencers/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .actionrecog_inferencer import ActionRecInferencer -from .base_mmaction_inferencer import BaseMMActionInferencer +from .actionrecog_inferencer import ActionRecogInferencer +from .base_mmaction_inferencer import BaseMMAction2Inferencer +from .mmaction_inferencer import MMAction2Inferencer -__all__ = ['ActionRecInferencer', 'BaseMMActionInferencer'] +__all__ = [ + 'ActionRecogInferencer', 'BaseMMAction2Inferencer', 'MMAction2Inferencer' +] diff --git a/mmaction/apis/inferencers/actionrecog_inferencer.py b/mmaction/apis/inferencers/actionrecog_inferencer.py index 084095266e..ab552b1be4 100644 --- a/mmaction/apis/inferencers/actionrecog_inferencer.py +++ b/mmaction/apis/inferencers/actionrecog_inferencer.py @@ -5,7 +5,7 @@ from .base_mmaction_inferencer import BaseMMAction2Inferencer -class ActionRecInferencer(BaseMMAction2Inferencer): +class ActionRecogInferencer(BaseMMAction2Inferencer): def pred2dict(self, data_sample: ActionDataSample) -> Dict: """Extract elements necessary to represent a prediction into a @@ -19,6 +19,6 @@ def pred2dict(self, data_sample: ActionDataSample) -> Dict: dict: The output dictionary. """ result = {} - result['pred_labels'] = data_sample.pred_labels.item + result['pred_labels'] = data_sample.pred_labels.item.tolist() result['pred_scores'] = data_sample.pred_scores.item.tolist() return result diff --git a/mmaction/apis/inferencers/base_mmaction_inferencer.py b/mmaction/apis/inferencers/base_mmaction_inferencer.py index 691c9430b4..b216c00039 100644 --- a/mmaction/apis/inferencers/base_mmaction_inferencer.py +++ b/mmaction/apis/inferencers/base_mmaction_inferencer.py @@ -32,7 +32,7 @@ class BaseMMAction2Inferencer(BaseInferencer): from metafile. Defaults to None. device (str, optional): Device to run inference. If None, the available device will be automatically used. Defaults to None. - label (str): label file for dataset. + label_file (str, optional): label file for dataset. input_format (str): Input video format, Choices are 'video', 'rawframes', 'array'. 
'video' means input data is a video file, 'rawframes' means input data is a video frame folder, and 'array' @@ -55,10 +55,10 @@ def __init__(self, model: Union[ModelType, str], weights: Optional[str] = None, device: Optional[str] = None, - scope: Optional[str] = 'mmaction2', - label: Optional[str] = None, + label_file: Optional[str] = None, input_format: str = 'video', - pack_cfg: dict = {}) -> None: + pack_cfg: dict = {}, + scope: Optional[str] = 'mmaction2') -> None: # A global counter tracking the number of videos processed, for # naming of the output videos self.num_visualized_vids = 0 @@ -68,8 +68,9 @@ def __init__(self, super().__init__( model=model, weights=weights, device=device, scope=scope) - if label is not None: - self.visualizer.dataset_meta = dict(classes=list_from_file(label)) + if label_file is not None: + self.visualizer.dataset_meta = dict( + classes=list_from_file(label_file)) def __call__(self, inputs: InputsType, @@ -263,7 +264,7 @@ def visualize( out_path = osp.join(vid_out_dir, video_name) if vid_out_dir != '' \ else None - self.visualizer.add_datasample( + visualization = self.visualizer.add_datasample( video_name, frames, pred, @@ -276,7 +277,7 @@ def visualize( out_path=out_path, target_resolution=target_resolution, ) - results.append(frames) + results.append(visualization) self.num_visualized_vids += 1 return results diff --git a/mmaction/apis/inferencers/mmaction_inferencer.py b/mmaction/apis/inferencers/mmaction_inferencer.py new file mode 100644 index 0000000000..f23341ea4f --- /dev/null +++ b/mmaction/apis/inferencers/mmaction_inferencer.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmengine +import numpy as np +from mmengine.structures import InstanceData + +from .actionrecog_inferencer import ActionRecogInferencer +from .base_mmaction_inferencer import BaseMMAction2Inferencer + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +class MMAction2Inferencer(BaseMMAction2Inferencer): + """MMAction2 Inferencer. It's a wrapper around base task inferenecers: + ActionRecog, and it can be used to perform end-to-end action recognition + inference. + + Args: + rec (str, optional): Pretrained action recognition + algorithm. It's the path to the config file or the model name + defined in metafile. Defaults to None. + rec_weights (str, optional): Path to the custom checkpoint file of + the selected rec model. If it is not specified and "rec" is a model + name of metafile, the weights will be loaded from metafile. + Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + label_file (str, optional): label file for dataset. + input_format (str): Input video format, Choices are 'video', + 'rawframes', 'array'. 'video' means input data is a video file, + 'rawframes' means input data is a video frame folder, and 'array' + means input data is a np.ndarray. Defaults to 'video'. 
+ """ + + def __init__(self, + rec: Optional[str] = None, + rec_weights: Optional[str] = None, + device: Optional[str] = None, + label_file: Optional[str] = None, + input_format: str = 'video') -> None: + + if rec is None: + raise ValueError('rec algorithm should provided.') + + self.visualizer = None + self.num_visualized_imgs = 0 + + if rec is not None: + self.actionrecog_inferencer = ActionRecogInferencer( + rec, rec_weights, device, label_file, input_format) + self.mode = 'rec' + + def forward(self, inputs: InputType, batch_size: int, + **forward_kwargs) -> PredType: + """Forward the inputs to the model. + + Args: + inputs (InputsType): The inputs to be forwarded. + batch_size (int): Batch size. Defaults to 1. + + Returns: + Dict: The prediction results. Possibly with keys "rec". + """ + result = {} + if self.mode == 'rec': + predictions = self.actionrecog_inferencer( + inputs, + return_datasamples=True, + batch_size=batch_size, + **forward_kwargs)['predictions'] + result['rec'] = [[p] for p in predictions] + + return result + + def visualize(self, inputs: InputsType, preds: PredType, + **kwargs) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer. + preds (List[Dict]): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. + fps (int): Frames per second for saving video. Defaults to 4. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. + target_resolution (Tuple[int], optional): Set to + (desired_width desired_height) to have resized frames. If + either dimension is None, the frames are resized by keeping + the existing aspect ratio. Defaults to None. + vid_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + """ + + if 'rec' in self.mode: + return self.actionrecog_inferencer.visualize( + inputs, preds['rec'][0], **kwargs) + + def __call__( + self, + inputs: InputsType, + batch_size: int = 1, + **kwargs, + ) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. It can be a path + to image / image directory, or an array, or a list of these. + return_datasamples (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + **kwargs: Key words arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. 
+ """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + + preds = self.forward(ori_inputs, batch_size, **forward_kwargs) + + visualization = self.visualize( + ori_inputs, preds, + **visualize_kwargs) # type: ignore # noqa: E501 + results = self.postprocess(preds, visualization, **postprocess_kwargs) + return results + + def postprocess(self, + preds: PredType, + visualization: Optional[List[np.ndarray]] = None, + print_result: bool = False, + pred_out_file: str = '' + ) -> Union[ResType, Tuple[ResType, np.ndarray]]: + """Postprocess predictions. + + Args: + preds (Dict): Predictions of the model. + visualization (Optional[np.ndarray]): Visualized predictions. + print_result (bool): Whether to print the result. + Defaults to False. + pred_out_file (str): Output file name to store predictions + without images. Supported file formats are “json”, “yaml/yml” + and “pickle/pkl”. Defaults to ''. + + Returns: + Dict or List[Dict]: Each dict contains the inference result of + each image. Possible keys are "rec_labels", "rec_scores" + """ + + result_dict = {} + pred_results = [{} for _ in range(len(next(iter(preds.values()))))] + if 'rec' in self.mode: + for i, rec_pred in enumerate(preds['rec']): + result = dict(rec_labels=[], rec_scores=[]) + for rec_pred_instance in rec_pred: + rec_dict_res = self.actionrecog_inferencer.pred2dict( + rec_pred_instance) + result['rec_labels'].append(rec_dict_res['pred_labels']) + result['rec_scores'].append(rec_dict_res['pred_scores']) + pred_results[i].update(result) + + result_dict['predictions'] = pred_results + if print_result: + print(result_dict) + if pred_out_file != '': + mmengine.dump(result_dict, pred_out_file) + result_dict['visualization'] = visualization + return result_dict diff --git a/mmaction/visualization/action_visualizer.py b/mmaction/visualization/action_visualizer.py index 8d8781ed9f..48c595fd5b 100644 --- a/mmaction/visualization/action_visualizer.py +++ b/mmaction/visualization/action_visualizer.py @@ -284,6 +284,7 @@ def add_datasample(self, else: self.add_video( name, resulted_video, step=step, fps=fps, out_type=out_type) + return resulted_video @master_only def add_video( diff --git a/tests/apis/test_inferencer.py b/tests/apis/test_inferencer.py index a16e941c8d..b35551c290 100644 --- a/tests/apis/test_inferencer.py +++ b/tests/apis/test_inferencer.py @@ -6,7 +6,7 @@ import torch from parameterized import parameterized -from mmaction.apis import ActionRecInferencer +from mmaction.apis import ActionRecogInferencer from mmaction.structures import ActionDataSample from mmaction.utils import register_all_modules @@ -26,11 +26,11 @@ def test_init_recognizer(self, config, lable_file, devices): # Skip the test if cuda is required but unavailable continue - _ = ActionRecInferencer(config, label=lable_file, device=device) + _ = ActionRecogInferencer(config, label=lable_file, device=device) # test `init_recognizer` with invalid config with self.assertRaisesRegex(ValueError, 'Cannot find model'): - _ = ActionRecInferencer( + _ = ActionRecogInferencer( 'slowfast_config', label=lable_file, device=device) @parameterized.expand([ @@ -47,7 +47,7 @@ def test_inference_recognizer(self, config, label_file, video_path, continue # test video file input and return datasample - inferencer = ActionRecInferencer( + inferencer = ActionRecogInferencer( config, label=label_file, device=device) results = inferencer( video_path, vid_out_dir=tmp_dir, 
return_datasamples=True) @@ -65,7 +65,7 @@ def test_inference_recognizer(self, config, label_file, video_path, osp.basename(video_path).replace('mp4', 'gif'))) # test np.ndarray input - inferencer = ActionRecInferencer( + inferencer = ActionRecogInferencer( config, label=label_file, device=device, From 5be7f1d5402a1056c1da74539e868ade91f903e4 Mon Sep 17 00:00:00 2001 From: lilin Date: Fri, 13 Jan 2023 14:55:11 +0800 Subject: [PATCH 04/12] add INFERENCER registry --- mmaction/apis/inferencers/actionrecog_inferencer.py | 3 +++ mmaction/registry.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/mmaction/apis/inferencers/actionrecog_inferencer.py b/mmaction/apis/inferencers/actionrecog_inferencer.py index ab552b1be4..dd68f8935b 100644 --- a/mmaction/apis/inferencers/actionrecog_inferencer.py +++ b/mmaction/apis/inferencers/actionrecog_inferencer.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict +from mmaction.registry import INFERENCERS from mmaction.structures import ActionDataSample from .base_mmaction_inferencer import BaseMMAction2Inferencer +@INFERENCERS.register_module(name='action-recognition') +@INFERENCERS.register_module() class ActionRecogInferencer(BaseMMAction2Inferencer): def pred2dict(self, data_sample: ActionDataSample) -> Dict: diff --git a/mmaction/registry.py b/mmaction/registry.py index db56340ed9..4efd3b4a16 100644 --- a/mmaction/registry.py +++ b/mmaction/registry.py @@ -10,6 +10,7 @@ from mmengine.registry import DATASETS as MMENGINE_DATASETS from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS from mmengine.registry import LOOPS as MMENGINE_LOOPS from mmengine.registry import METRICS as MMENGINE_METRICS @@ -81,3 +82,6 @@ # manage logprocessor LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS) + +# manage inferencer +INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS) From eebd18ab049e379f1316dbd0b599cd8070a83097 Mon Sep 17 00:00:00 2001 From: lilin Date: Fri, 13 Jan 2023 19:40:17 +0800 Subject: [PATCH 05/12] remove base_inferencer --- mmaction/apis/inferencers/__init__.py | 5 +- .../inferencers/actionrecog_inferencer.py | 337 ++++++++++++++++- .../inferencers/base_mmaction_inferencer.py | 347 ------------------ .../apis/inferencers/mmaction_inferencer.py | 47 ++- 4 files changed, 377 insertions(+), 359 deletions(-) delete mode 100644 mmaction/apis/inferencers/base_mmaction_inferencer.py diff --git a/mmaction/apis/inferencers/__init__.py b/mmaction/apis/inferencers/__init__.py index 06677bdddc..16c93e0ab0 100644 --- a/mmaction/apis/inferencers/__init__.py +++ b/mmaction/apis/inferencers/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .actionrecog_inferencer import ActionRecogInferencer -from .base_mmaction_inferencer import BaseMMAction2Inferencer from .mmaction_inferencer import MMAction2Inferencer -__all__ = [ - 'ActionRecogInferencer', 'BaseMMAction2Inferencer', 'MMAction2Inferencer' -] +__all__ = ['ActionRecogInferencer', 'MMAction2Inferencer'] diff --git a/mmaction/apis/inferencers/actionrecog_inferencer.py b/mmaction/apis/inferencers/actionrecog_inferencer.py index dd68f8935b..3d9389c6bd 100644 --- a/mmaction/apis/inferencers/actionrecog_inferencer.py +++ b/mmaction/apis/inferencers/actionrecog_inferencer.py @@ -1,14 +1,345 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict +import os.path as osp +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import mmengine +import numpy as np +from mmengine.dataset import Compose +from mmengine.fileio import list_from_file +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.structures import InstanceData from mmaction.registry import INFERENCERS from mmaction.structures import ActionDataSample -from .base_mmaction_inferencer import BaseMMAction2Inferencer +from mmaction.utils import ConfigType, register_all_modules + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] @INFERENCERS.register_module(name='action-recognition') @INFERENCERS.register_module() -class ActionRecogInferencer(BaseMMAction2Inferencer): +class ActionRecogInferencer(BaseInferencer): + """The inferencer for action recognition. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. For example, it could be + "slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb" or + "configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py". + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + label_file (str, optional): label file for dataset. + input_format (str): Input video format, Choices are 'video', + 'rawframes', 'array'. 'video' means input data is a video file, + 'rawframes' means input data is a video frame folder, and 'array' + means input data is a np.ndarray. Defaults to 'video'. + pack_cfg (dict, optional): Config for `InferencerPackInput` to load + input. Defaults to empty dict. 
+ """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', 'show', 'wait_time', 'vid_out_dir', 'draw_pred', 'fps', + 'out_type', 'target_resolution' + } + postprocess_kwargs: set = { + 'print_result', 'pred_out_file', 'return_datasample' + } + + def __init__(self, + model: Union[ModelType, str], + weights: Optional[str] = None, + device: Optional[str] = None, + label_file: Optional[str] = None, + input_format: str = 'video', + pack_cfg: dict = {}, + scope: Optional[str] = 'mmaction2') -> None: + # A global counter tracking the number of videos processed, for + # naming of the output videos + self.num_visualized_vids = 0 + self.input_format = input_format + self.pack_cfg = pack_cfg.copy() + register_all_modules() + super().__init__( + model=model, weights=weights, device=device, scope=scope) + + if label_file is not None: + self.visualizer.dataset_meta = dict( + classes=list_from_file(label_file)) + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + draw_pred: bool = True, + vid_out_dir: str = '', + out_type: str = 'video', + print_result: bool = False, + pred_out_file: str = '', + target_resolution: Optional[Tuple[int]] = None, + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`BaseDataElement`. Defaults to False. + batch_size (int): Inference batch size. Defaults to 1. + show (bool): Whether to display the visualization results in a + popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. + vid_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + out_type (str): Output type of visualization results. + Defaults to 'video'. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + pred_out_file: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + return super().__call__( + inputs, + return_datasamples, + batch_size, + return_vis=return_vis, + show=show, + wait_time=wait_time, + draw_pred=draw_pred, + vid_out_dir=vid_out_dir, + print_result=print_result, + pred_out_file=pred_out_file, + out_type=out_type, + target_resolution=target_resolution, + **kwargs) + + def _inputs_to_list(self, inputs: InputsType) -> list: + """Preprocess the inputs to a list. The main difference from mmengine + version is that we don't list a directory cause input could be a frame + folder. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + + Args: + inputs (InputsType): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. 
+ """ + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + return list(inputs) + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline.""" + test_pipeline = cfg.test_dataloader.dataset.pipeline + # Alter data pipelines for decode + if self.input_format == 'array': + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='ArrayDecode') + test_pipeline = [ + x for x in test_pipeline if 'Init' not in x['type'] + ] + elif self.input_format == 'video': + if 'Init' not in test_pipeline[0]['type']: + test_pipeline = [dict(type='DecordInit')] + test_pipeline + else: + test_pipeline[0] = dict(type='DecordInit') + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='DecordDecode') + elif self.input_format == 'rawframes': + if 'Init' in test_pipeline[0]['type']: + test_pipeline = test_pipeline[1:] + for i in range(len(test_pipeline)): + if 'Decode' in test_pipeline[i]['type']: + test_pipeline[i] = dict(type='RawFrameDecode') + # Alter data pipelines to close TTA, avoid OOM + # Use center crop instead of multiple crop + for i in range(len(test_pipeline)): + if test_pipeline[i]['type'] in ['ThreeCrop', 'TenCrop']: + test_pipeline[i]['type'] = 'CenterCrop' + # Use single clip for `Recognizer3D` + if cfg.model.type == 'Recognizer3D': + for i in range(len(test_pipeline)): + if test_pipeline[i]['type'] == 'SampleFrames': + test_pipeline[i]['num_clips'] = 1 + # Pack multiple types of input format + test_pipeline.insert( + 0, + dict( + type='InferencerPackInput', + input_format=self.input_format, + **self.pack_cfg)) + + return Compose(test_pipeline) + + def visualize( + self, + inputs: InputsType, + preds: PredType, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + draw_pred: bool = True, + fps: int = 30, + out_type: str = 'video', + target_resolution: Optional[Tuple[int]] = None, + vid_out_dir: str = '', + ) -> Union[List[np.ndarray], None]: + """Visualize predictions. + + Args: + inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer. + preds (List[Dict]): Predictions of the model. + return_vis (bool): Whether to return the visualization result. + Defaults to False. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + fps (int): Frames per second for saving video. Defaults to 4. + out_type (str): Output format type, choose from 'img', 'gif', + 'video'. Defaults to ``'img'``. + target_resolution (Tuple[int], optional): Set to + (desired_width desired_height) to have resized frames. If + either dimension is None, the frames are resized by keeping + the existing aspect ratio. Defaults to None. + vid_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + + Returns: + List[np.ndarray] or None: Returns visualization results only if + applicable. 
+ """ + if self.visualizer is None or (not show and vid_out_dir == '' + and not return_vis): + return None + + if getattr(self, 'visualizer') is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + results = [] + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + frames = single_input + video_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + frames = single_input.copy() + video_num = str(self.num_visualized_vids).zfill(8) + video_name = f'{video_num}.mp4' + else: + raise ValueError('Unsupported input type: ' + f'{type(single_input)}') + + out_path = osp.join(vid_out_dir, video_name) if vid_out_dir != '' \ + else None + + visualization = self.visualizer.add_datasample( + video_name, + frames, + pred, + show_frames=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=draw_pred, + fps=fps, + out_type=out_type, + out_path=out_path, + target_resolution=target_resolution, + ) + results.append(visualization) + self.num_visualized_vids += 1 + + return results + + def postprocess( + self, + preds: PredType, + visualization: Optional[List[np.ndarray]] = None, + return_datasample: bool = False, + print_result: bool = False, + pred_out_file: str = '', + ) -> Union[ResType, Tuple[ResType, np.ndarray]]: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (Optional[np.ndarray]): Visualized predictions. + return_datasample (bool): Whether to use Datasample to store + inference results. If False, dict will be used. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + pred_out_file: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization``. + + - ``visualization`` (Any): Returned by :meth:`visualize`. + - ``predictions`` (dict or DataSample): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it usually should be a + json-serializable dict containing only basic data elements such + as strings and numbers. + """ + result_dict = {} + results = preds + if not return_datasample: + results = [] + for pred in preds: + result = self.pred2dict(pred) + results.append(result) + # Add video to the results after printing and dumping + result_dict['predictions'] = results + if print_result: + print(result_dict) + if pred_out_file != '': + mmengine.dump(result_dict, pred_out_file) + result_dict['visualization'] = visualization + return result_dict def pred2dict(self, data_sample: ActionDataSample) -> Dict: """Extract elements necessary to represent a prediction into a diff --git a/mmaction/apis/inferencers/base_mmaction_inferencer.py b/mmaction/apis/inferencers/base_mmaction_inferencer.py deleted file mode 100644 index b216c00039..0000000000 --- a/mmaction/apis/inferencers/base_mmaction_inferencer.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import os.path as osp -from typing import Dict, List, Optional, Sequence, Tuple, Union - -import mmengine -import numpy as np -from mmengine.dataset import Compose -from mmengine.fileio import list_from_file -from mmengine.infer.infer import BaseInferencer, ModelType -from mmengine.structures import InstanceData - -from mmaction.utils import ConfigType, register_all_modules - -InstanceList = List[InstanceData] -InputType = Union[str, np.ndarray] -InputsType = Union[InputType, Sequence[InputType]] -PredType = Union[InstanceData, InstanceList] -ImgType = Union[np.ndarray, Sequence[np.ndarray]] -ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] - - -class BaseMMAction2Inferencer(BaseInferencer): - """Base inferencer. - - Args: - model (str, optional): Path to the config file or the model name - defined in metafile. For example, it could be - "slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb" or - "configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py". - weights (str, optional): Path to the checkpoint. If it is not specified - and model is a model name of metafile, the weights will be loaded - from metafile. Defaults to None. - device (str, optional): Device to run inference. If None, the available - device will be automatically used. Defaults to None. - label_file (str, optional): label file for dataset. - input_format (str): Input video format, Choices are 'video', - 'rawframes', 'array'. 'video' means input data is a video file, - 'rawframes' means input data is a video frame folder, and 'array' - means input data is a np.ndarray. Defaults to 'video'. - pack_cfg (dict, optional): Config for `InferencerPackInput` to load - input. Defaults to empty dict. - """ - - preprocess_kwargs: set = set() - forward_kwargs: set = set() - visualize_kwargs: set = { - 'return_vis', 'show', 'wait_time', 'vid_out_dir', 'draw_pred', 'fps', - 'out_type', 'target_resolution' - } - postprocess_kwargs: set = { - 'print_result', 'pred_out_file', 'return_datasample' - } - - def __init__(self, - model: Union[ModelType, str], - weights: Optional[str] = None, - device: Optional[str] = None, - label_file: Optional[str] = None, - input_format: str = 'video', - pack_cfg: dict = {}, - scope: Optional[str] = 'mmaction2') -> None: - # A global counter tracking the number of videos processed, for - # naming of the output videos - self.num_visualized_vids = 0 - self.input_format = input_format - self.pack_cfg = pack_cfg.copy() - register_all_modules() - super().__init__( - model=model, weights=weights, device=device, scope=scope) - - if label_file is not None: - self.visualizer.dataset_meta = dict( - classes=list_from_file(label_file)) - - def __call__(self, - inputs: InputsType, - return_datasamples: bool = False, - batch_size: int = 1, - return_vis: bool = False, - show: bool = False, - wait_time: int = 0, - draw_pred: bool = True, - vid_out_dir: str = '', - out_type: str = 'video', - print_result: bool = False, - pred_out_file: str = '', - target_resolution: Optional[Tuple[int]] = None, - **kwargs) -> dict: - """Call the inferencer. - - Args: - inputs (InputsType): Inputs for the inferencer. - return_datasamples (bool): Whether to return results as - :obj:`BaseDataElement`. Defaults to False. - batch_size (int): Inference batch size. Defaults to 1. - show (bool): Whether to display the visualization results in a - popup window. Defaults to False. - wait_time (float): The interval of show (s). Defaults to 0. - draw_pred (bool): Whether to draw predicted bounding boxes. 
- Defaults to True. - vid_out_dir (str): Output directory of visualization results. - If left as empty, no file will be saved. Defaults to ''. - out_type (str): Output type of visualization results. - Defaults to 'video'. - print_result (bool): Whether to print the inference result w/o - visualization to the console. Defaults to False. - pred_out_file: File to save the inference results w/o - visualization. If left as empty, no file will be saved. - Defaults to ''. - - **kwargs: Other keyword arguments passed to :meth:`preprocess`, - :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. - Each key in kwargs should be in the corresponding set of - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` - and ``postprocess_kwargs``. - - Returns: - dict: Inference and visualization results. - """ - return super().__call__( - inputs, - return_datasamples, - batch_size, - return_vis=return_vis, - show=show, - wait_time=wait_time, - draw_pred=draw_pred, - vid_out_dir=vid_out_dir, - print_result=print_result, - pred_out_file=pred_out_file, - out_type=out_type, - target_resolution=target_resolution, - **kwargs) - - def _inputs_to_list(self, inputs: InputsType) -> list: - """Preprocess the inputs to a list. The main difference from mmengine - version is that we don't list a directory cause input could be a frame - folder. - - Preprocess inputs to a list according to its type: - - - list or tuple: return inputs - - str: return a list containing the string. The string - could be a path to file, a url or other types of string according - to the task. - - Args: - inputs (InputsType): Inputs for the inferencer. - - Returns: - list: List of input for the :meth:`preprocess`. - """ - if not isinstance(inputs, (list, tuple)): - inputs = [inputs] - - return list(inputs) - - def _init_pipeline(self, cfg: ConfigType) -> Compose: - """Initialize the test pipeline.""" - test_pipeline = cfg.test_dataloader.dataset.pipeline - # Alter data pipelines for decode - if self.input_format == 'array': - for i in range(len(test_pipeline)): - if 'Decode' in test_pipeline[i]['type']: - test_pipeline[i] = dict(type='ArrayDecode') - test_pipeline = [ - x for x in test_pipeline if 'Init' not in x['type'] - ] - elif self.input_format == 'video': - if 'Init' not in test_pipeline[0]['type']: - test_pipeline = [dict(type='DecordInit')] + test_pipeline - else: - test_pipeline[0] = dict(type='DecordInit') - for i in range(len(test_pipeline)): - if 'Decode' in test_pipeline[i]['type']: - test_pipeline[i] = dict(type='DecordDecode') - elif self.input_format == 'rawframes': - if 'Init' in test_pipeline[0]['type']: - test_pipeline = test_pipeline[1:] - for i in range(len(test_pipeline)): - if 'Decode' in test_pipeline[i]['type']: - test_pipeline[i] = dict(type='RawFrameDecode') - # Alter data pipelines to close TTA, avoid OOM - # Use center crop instead of multiple crop - for i in range(len(test_pipeline)): - if test_pipeline[i]['type'] in ['ThreeCrop', 'TenCrop']: - test_pipeline[i]['type'] = 'CenterCrop' - # Use single clip for `Recognizer3D` - if cfg.model.type == 'Recognizer3D': - for i in range(len(test_pipeline)): - if test_pipeline[i]['type'] == 'SampleFrames': - test_pipeline[i]['num_clips'] = 1 - # Pack multiple types of input format - test_pipeline.insert( - 0, - dict( - type='InferencerPackInput', - input_format=self.input_format, - **self.pack_cfg)) - - return Compose(test_pipeline) - - def visualize( - self, - inputs: InputsType, - preds: PredType, - return_vis: bool = False, - show: bool = False, - wait_time: 
int = 0, - draw_pred: bool = True, - fps: int = 30, - out_type: str = 'video', - target_resolution: Optional[Tuple[int]] = None, - vid_out_dir: str = '', - ) -> Union[List[np.ndarray], None]: - """Visualize predictions. - - Args: - inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer. - preds (List[Dict]): Predictions of the model. - return_vis (bool): Whether to return the visualization result. - Defaults to False. - show (bool): Whether to display the image in a popup window. - Defaults to False. - wait_time (float): The interval of show (s). Defaults to 0. - draw_pred (bool): Whether to draw prediction labels. - Defaults to True. - fps (int): Frames per second for saving video. Defaults to 4. - out_type (str): Output format type, choose from 'img', 'gif', - 'video'. Defaults to ``'img'``. - target_resolution (Tuple[int], optional): Set to - (desired_width desired_height) to have resized frames. If - either dimension is None, the frames are resized by keeping - the existing aspect ratio. Defaults to None. - vid_out_dir (str): Output directory of visualization results. - If left as empty, no file will be saved. Defaults to ''. - - Returns: - List[np.ndarray] or None: Returns visualization results only if - applicable. - """ - if self.visualizer is None or (not show and vid_out_dir == '' - and not return_vis): - return None - - if getattr(self, 'visualizer') is None: - raise ValueError('Visualization needs the "visualizer" term' - 'defined in the config, but got None.') - - results = [] - - for single_input, pred in zip(inputs, preds): - if isinstance(single_input, str): - frames = single_input - video_name = osp.basename(single_input) - elif isinstance(single_input, np.ndarray): - frames = single_input.copy() - video_num = str(self.num_visualized_vids).zfill(8) - video_name = f'{video_num}.mp4' - else: - raise ValueError('Unsupported input type: ' - f'{type(single_input)}') - - out_path = osp.join(vid_out_dir, video_name) if vid_out_dir != '' \ - else None - - visualization = self.visualizer.add_datasample( - video_name, - frames, - pred, - show_frames=show, - wait_time=wait_time, - draw_gt=False, - draw_pred=draw_pred, - fps=fps, - out_type=out_type, - out_path=out_path, - target_resolution=target_resolution, - ) - results.append(visualization) - self.num_visualized_vids += 1 - - return results - - def postprocess( - self, - preds: PredType, - visualization: Optional[List[np.ndarray]] = None, - return_datasample: bool = False, - print_result: bool = False, - pred_out_file: str = '', - ) -> Union[ResType, Tuple[ResType, np.ndarray]]: - """Process the predictions and visualization results from ``forward`` - and ``visualize``. - - This method should be responsible for the following tasks: - - 1. Convert datasamples into a json-serializable dict if needed. - 2. Pack the predictions and visualization results and return them. - 3. Dump or log the predictions. - - Args: - preds (List[Dict]): Predictions of the model. - visualization (Optional[np.ndarray]): Visualized predictions. - return_datasample (bool): Whether to use Datasample to store - inference results. If False, dict will be used. - print_result (bool): Whether to print the inference result w/o - visualization to the console. Defaults to False. - pred_out_file: File to save the inference results w/o - visualization. If left as empty, no file will be saved. - Defaults to ''. - - Returns: - dict: Inference and visualization results with key ``predictions`` - and ``visualization``. 
- - - ``visualization`` (Any): Returned by :meth:`visualize`. - - ``predictions`` (dict or DataSample): Returned by - :meth:`forward` and processed in :meth:`postprocess`. - If ``return_datasample=False``, it usually should be a - json-serializable dict containing only basic data elements such - as strings and numbers. - """ - result_dict = {} - results = preds - if not return_datasample: - results = [] - for pred in preds: - result = self.pred2dict(pred) - results.append(result) - # Add video to the results after printing and dumping - result_dict['predictions'] = results - if print_result: - print(result_dict) - if pred_out_file != '': - mmengine.dump(result_dict, pred_out_file) - result_dict['visualization'] = visualization - return result_dict - - def pred2dict(self, data_sample: InstanceData) -> Dict: - """Extract elements necessary to represent a prediction into a - dictionary. - - It's better to contain only basic data elements such as strings and - numbers in order to guarantee it's json-serializable. - """ - raise NotImplementedError diff --git a/mmaction/apis/inferencers/mmaction_inferencer.py b/mmaction/apis/inferencers/mmaction_inferencer.py index f23341ea4f..6f25da175b 100644 --- a/mmaction/apis/inferencers/mmaction_inferencer.py +++ b/mmaction/apis/inferencers/mmaction_inferencer.py @@ -3,10 +3,11 @@ import mmengine import numpy as np +from mmengine.infer import BaseInferencer from mmengine.structures import InstanceData +from mmaction.utils import ConfigType from .actionrecog_inferencer import ActionRecogInferencer -from .base_mmaction_inferencer import BaseMMAction2Inferencer InstanceList = List[InstanceData] InputType = Union[str, np.ndarray] @@ -15,10 +16,10 @@ ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] -class MMAction2Inferencer(BaseMMAction2Inferencer): - """MMAction2 Inferencer. It's a wrapper around base task inferenecers: - ActionRecog, and it can be used to perform end-to-end action recognition - inference. +class MMAction2Inferencer(BaseInferencer): + """MMAction2 Inferencer. It's a unified inferencer interface for video + analyse task, including: ActionRecog. and it can be used to perform end-to- + end action recognition inference. Args: rec (str, optional): Pretrained action recognition @@ -37,6 +38,16 @@ class MMAction2Inferencer(BaseMMAction2Inferencer): means input data is a np.ndarray. Defaults to 'video'. """ + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', 'show', 'wait_time', 'vid_out_dir', 'draw_pred', 'fps', + 'out_type', 'target_resolution' + } + postprocess_kwargs: set = { + 'print_result', 'pred_out_file', 'return_datasample' + } + def __init__(self, rec: Optional[str] = None, rec_weights: Optional[str] = None, @@ -55,6 +66,9 @@ def __init__(self, rec, rec_weights, device, label_file, input_format) self.mode = 'rec' + def _init_pipeline(self, cfg: ConfigType) -> None: + pass + def forward(self, inputs: InputType, batch_size: int, **forward_kwargs) -> PredType: """Forward the inputs to the model. @@ -144,6 +158,29 @@ def __call__( results = self.postprocess(preds, visualization, **postprocess_kwargs) return results + def _inputs_to_list(self, inputs: InputsType) -> list: + """Preprocess the inputs to a list. The main difference from mmengine + version is that we don't list a directory cause input could be a frame + folder. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - str: return a list containing the string. 
The string
+            could be a path to file, a url or other types of string according
+            to the task.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+
+        Returns:
+            list: List of input for the :meth:`preprocess`.
+        """
+        if not isinstance(inputs, (list, tuple)):
+            inputs = [inputs]
+
+        return list(inputs)
+
     def postprocess(self,
                     preds: PredType,
                     visualization: Optional[List[np.ndarray]] = None,
From ab85d9da89243fd3a5e904902d2417cbc0e56346 Mon Sep 17 00:00:00 2001
From: lilin
Date: Mon, 30 Jan 2023 16:27:35 +0800
Subject: [PATCH 06/12] add inferencer doc

---
 demo/README.md                                | 65 ++++++++++++++++++-
 demo/demo_inferencer.py                       |  2 +-
 .../inferencers/actionrecog_inferencer.py     |  3 +-
 tests/apis/test_inferencer.py                 |  9 +--
 4 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/demo/README.md b/demo/README.md
index 88e4c96bf8..5810c8f7e1 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -7,6 +7,7 @@
 - [Video GradCAM Demo](#video-gradcam-demo): A demo script to visualize GradCAM results using a single video.
 - [Webcam demo](#webcam-demo): A demo script to implement real-time action recognition from a web camera.
 - [Skeleton-based Action Recognition Demo](#skeleton-based-action-recognition-demo): A demo script to predict the skeleton-based action recognition result using a single video.
+- [Inferencer Demo](#inferencer): A demo script to implement fast prediction for video analysis tasks based on the unified inferencer interface.

 ## Modify configs through script arguments

@@ -52,7 +53,7 @@ Optional arguments:
 Examples:

 Assume that you are located at `$MMACTION2` and have already downloaded the checkpoints to the directory `checkpoints/`,
-or use checkpoint url from to directly load corresponding checkpoint, which will be automatically saved in `$HOME/.cache/torch/checkpoints`.
+or use the checkpoint url from `configs/` to directly load the corresponding checkpoint, which will be automatically saved in `$HOME/.cache/torch/checkpoints`.

 1. Recognize a video file as input by using a TSN model on cuda by default.

@@ -183,7 +184,7 @@ Users can change:

 ## Skeleton-based Action Recognition Demo

-MMAction2 provides an demo script to predict the skeleton-based action recognition result using a single video.
+MMAction2 provides a demo script to predict the skeleton-based action recognition result using a single video.

 ```shell
 python demo/demo_skeleton.py ${VIDEO_FILE} ${OUT_FILENAME} \
@@ -247,3 +248,63 @@ python demo/demo_skeleton.py demo/demo_skeleton.mp4 demo/demo_skeleton_out.mp4 \
     --pose-checkpoint https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth \
     --label-map tools/data/skeleton/label_map_ntu60.txt
 ```
+
+## Inferencer
+
+MMAction2 provides a demo script to implement fast predict for video analysis tasks based on unified inferencer interface, currently only supports action recognition task.
+
+```shell
+python demo/demo.py ${INPUTS} \
+    [--vid-out-dir ${VID_OUT_DIR}] \
+    [--rec ${RECOG_TASK}] \
+    [--rec-weights ${RECOG_WEIGHTS}] \
+    [--label-file ${LABEL_FILE}] \
+    [--device ${DEVICE_TYPE}] \
+    [--batch-size ${BATCH_SIZE}] \
+    [--print-result ${PRINT_RESULT}] \
+    [--pred-out-file ${PRED_OUT_FILE}]
+```
+
+Optional arguments:
+
+- `--show`: If specified, the demo will display the video in a popup window.
+- `--print-result`: If specified, the demo will print the inference results to the console.
+- `VID_OUT_DIR`: Output directory of saved videos. Defaults to None, which means videos will not be saved.
+- `RECOG_TASK`: Type of Action Recognition algorithm. It could be the path to the config file, the model name, or the alias defined in the metafile.
+- `RECOG_WEIGHTS`: Path to the custom checkpoint file of the selected recog model. If it is not specified and "rec" is a model name defined in the metafile, the weights will be loaded from the metafile.
+- `LABEL_FILE`: Label file for the dataset the algorithm was pretrained on. Defaults to None, which means labels will not be shown in the result.
+- `DEVICE_TYPE`: Type of device to run the demo. Allowed values are cuda devices like `cuda:0` or `cpu`. Defaults to `cuda:0`.
+- `BATCH_SIZE`: The batch size used in inference. Defaults to 1.
+- `PRED_OUT_FILE`: File path to save the inference results. Defaults to None, which means prediction results will not be saved.
+
+Examples:
+
+Assume that you are located at `$MMACTION2`.
+
+1. Recognize a video file as input by using a TSN model, loading the checkpoint from the metafile.
+
+   ```shell
+   # The demo.mp4 and label_map_k400.txt are both from Kinetics-400
+   python demo/demo_inferencer.py demo/demo.mp4 \
+       --rec configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x8-100e_kinetics400-rgb.py \
+       --label-file tools/data/kinetics/label_map_k400.txt
+   ```
+
+2. Recognize a video file as input by using a TSN model, using the model alias defined in the metafile.
+
+   ```shell
+   # The demo.mp4 and label_map_k400.txt are both from Kinetics-400
+   python demo/demo_inferencer.py demo/demo.mp4 \
+       --rec tsn \
+       --label-file tools/data/kinetics/label_map_k400.txt
+   ```
+
+3. Recognize a video file as input by using a TSN model, and then save the visualization video.
+
+   ```shell
+   # The demo.mp4 and label_map_k400.txt are both from Kinetics-400
+   python demo/demo_inferencer.py demo/demo.mp4 \
+       --vid-out-dir demo_out \
+       --rec tsn \
+       --label-file tools/data/kinetics/label_map_k400.txt
+   ```
diff --git a/demo/demo_inferencer.py b/demo/demo_inferencer.py
index 508baca511..f7a7f365e9 100644
--- a/demo/demo_inferencer.py
+++ b/demo/demo_inferencer.py
@@ -39,7 +39,7 @@ def parse_args():
     parser.add_argument(
         '--show',
         action='store_true',
-        help='Display the image in a popup window.')
+        help='Display the video in a popup window.')
     parser.add_argument(
         '--print-result',
         action='store_true',
diff --git a/mmaction/apis/inferencers/actionrecog_inferencer.py b/mmaction/apis/inferencers/actionrecog_inferencer.py
index 3d9389c6bd..98be51bfa9 100644
--- a/mmaction/apis/inferencers/actionrecog_inferencer.py
+++ b/mmaction/apis/inferencers/actionrecog_inferencer.py
@@ -43,6 +43,7 @@ class ActionRecogInferencer(BaseInferencer):
         means input data is a np.ndarray. Defaults to 'video'.
         pack_cfg (dict, optional): Config for `InferencerPackInput` to load
             input. Defaults to empty dict.
+        scope (str, optional): The scope of the model. Defaults to "mmaction".
""" preprocess_kwargs: set = set() @@ -62,7 +63,7 @@ def __init__(self, label_file: Optional[str] = None, input_format: str = 'video', pack_cfg: dict = {}, - scope: Optional[str] = 'mmaction2') -> None: + scope: Optional[str] = 'mmaction') -> None: # A global counter tracking the number of videos processed, for # naming of the output videos self.num_visualized_vids = 0 diff --git a/tests/apis/test_inferencer.py b/tests/apis/test_inferencer.py index b35551c290..2d49f249d8 100644 --- a/tests/apis/test_inferencer.py +++ b/tests/apis/test_inferencer.py @@ -26,12 +26,13 @@ def test_init_recognizer(self, config, lable_file, devices): # Skip the test if cuda is required but unavailable continue - _ = ActionRecogInferencer(config, label=lable_file, device=device) + _ = ActionRecogInferencer( + config, label_file=lable_file, device=device) # test `init_recognizer` with invalid config with self.assertRaisesRegex(ValueError, 'Cannot find model'): _ = ActionRecogInferencer( - 'slowfast_config', label=lable_file, device=device) + 'slowfast_config', label_file=lable_file, device=device) @parameterized.expand([ (('tsn'), ('tools/data/kinetics/label_map_k400.txt'), @@ -48,7 +49,7 @@ def test_inference_recognizer(self, config, label_file, video_path, # test video file input and return datasample inferencer = ActionRecogInferencer( - config, label=label_file, device=device) + config, label_file=label_file, device=device) results = inferencer( video_path, vid_out_dir=tmp_dir, return_datasamples=True) self.assertIn('predictions', results) @@ -67,7 +68,7 @@ def test_inference_recognizer(self, config, label_file, video_path, # test np.ndarray input inferencer = ActionRecogInferencer( config, - label=label_file, + label_file=label_file, device=device, input_format='array') import decord From ba37121ea76576898f6a3788b1c7e31ac996a078 Mon Sep 17 00:00:00 2001 From: lilin Date: Mon, 6 Feb 2023 14:48:43 +0800 Subject: [PATCH 07/12] add ut --- .../apis/inferencers/mmaction_inferencer.py | 4 +- tests/apis/test_inferencer.py | 43 +++++++------------ 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/mmaction/apis/inferencers/mmaction_inferencer.py b/mmaction/apis/inferencers/mmaction_inferencer.py index 6f25da175b..fe71097af9 100644 --- a/mmaction/apis/inferencers/mmaction_inferencer.py +++ b/mmaction/apis/inferencers/mmaction_inferencer.py @@ -18,8 +18,8 @@ class MMAction2Inferencer(BaseInferencer): """MMAction2 Inferencer. It's a unified inferencer interface for video - analyse task, including: ActionRecog. and it can be used to perform end-to- - end action recognition inference. + analyse task, currently including: ActionRecog. and it can be used to + perform end-to-end action recognition inference. 
Args: rec (str, optional): Pretrained action recognition diff --git a/tests/apis/test_inferencer.py b/tests/apis/test_inferencer.py index 2d49f249d8..855b2984ed 100644 --- a/tests/apis/test_inferencer.py +++ b/tests/apis/test_inferencer.py @@ -6,41 +6,33 @@ import torch from parameterized import parameterized -from mmaction.apis import ActionRecogInferencer -from mmaction.structures import ActionDataSample +from mmaction.apis import MMAction2Inferencer from mmaction.utils import register_all_modules -class TestInferencer(TestCase): +class TestMMActionInferencer(TestCase): def setUp(self): register_all_modules() - @parameterized.expand([ - (('tsn'), ('tools/data/kinetics/label_map_k400.txt'), ('cpu', 'cuda')) - ]) - def test_init_recognizer(self, config, lable_file, devices): - - for device in devices: - if device == 'cuda' and not torch.cuda.is_available(): - # Skip the test if cuda is required but unavailable - continue + def test_init_recognizer(self): + # Initialzied by alias + _ = MMAction2Inferencer(rec='tsn') - _ = ActionRecogInferencer( - config, label_file=lable_file, device=device) + # Initialzied by config + _ = MMAction2Inferencer( + rec='tsn_imagenet-pretrained-r50_8xb32-1x1x8-100e_kinetics400-rgb' + ) # noqa: E501 - # test `init_recognizer` with invalid config - with self.assertRaisesRegex(ValueError, 'Cannot find model'): - _ = ActionRecogInferencer( - 'slowfast_config', label_file=lable_file, device=device) + with self.assertRaisesRegex(ValueError, + 'rec algorithm should provided.'): + _ = MMAction2Inferencer() @parameterized.expand([ (('tsn'), ('tools/data/kinetics/label_map_k400.txt'), ('demo/demo.mp4'), ('cpu', 'cuda')) ]) - def test_inference_recognizer(self, config, label_file, video_path, - devices): - + def test_infer_recognizer(self, config, label_file, video_path, devices): with TemporaryDirectory() as tmp_dir: for device in devices: if device == 'cuda' and not torch.cuda.is_available(): @@ -48,14 +40,11 @@ def test_inference_recognizer(self, config, label_file, video_path, continue # test video file input and return datasample - inferencer = ActionRecogInferencer( + inferencer = MMAction2Inferencer( config, label_file=label_file, device=device) - results = inferencer( - video_path, vid_out_dir=tmp_dir, return_datasamples=True) + results = inferencer(video_path, vid_out_dir=tmp_dir) self.assertIn('predictions', results) self.assertIn('visualization', results) - self.assertIsInstance(results['predictions'][0], - ActionDataSample) assert osp.exists(osp.join(tmp_dir, osp.basename(video_path))) results = inferencer( @@ -66,7 +55,7 @@ def test_inference_recognizer(self, config, label_file, video_path, osp.basename(video_path).replace('mp4', 'gif'))) # test np.ndarray input - inferencer = ActionRecogInferencer( + inferencer = MMAction2Inferencer( config, label_file=label_file, device=device, From 4612eeeb5a81ab752f185e668e6f76fed8319ec5 Mon Sep 17 00:00:00 2001 From: lilin Date: Mon, 6 Feb 2023 15:53:18 +0800 Subject: [PATCH 08/12] refine docstring --- demo/README.md | 2 +- mmaction/apis/inferencers/mmaction_inferencer.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/demo/README.md b/demo/README.md index 5810c8f7e1..a3ae8197e0 100644 --- a/demo/README.md +++ b/demo/README.md @@ -251,7 +251,7 @@ python demo/demo_skeleton.py demo/demo_skeleton.mp4 demo/demo_skeleton_out.mp4 \ ## Inferencer -MMAction2 provides a demo script to implement fast predict for video analysis tasks based on unified inferencer interface, currently only supports 
action recognition task.
+MMAction2 provides a demo script to implement fast prediction for video analysis tasks based on the unified inferencer interface, which currently only supports the action recognition task.

 ```shell
 python demo/demo.py ${INPUTS} \
diff --git a/mmaction/apis/inferencers/mmaction_inferencer.py b/mmaction/apis/inferencers/mmaction_inferencer.py
index fe71097af9..68fa8b199d 100644
--- a/mmaction/apis/inferencers/mmaction_inferencer.py
+++ b/mmaction/apis/inferencers/mmaction_inferencer.py
@@ -22,14 +22,18 @@ class MMAction2Inferencer(BaseInferencer):
     perform end-to-end action recognition inference.

     Args:
-        rec (str, optional): Pretrained action recognition
-            algorithm. It's the path to the config file or the model name
-            defined in metafile. Defaults to None.
+        rec (str, optional): Pretrained action recognition algorithm.
+            It's the path to the config file or the model name
+            defined in metafile. For example, it could be "slowfast",
+            "slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb" or
+            "configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py".
+            Defaults to None.
         rec_weights (str, optional): Path to the custom checkpoint file of
             the selected rec model. If it is not specified and "rec" is a
             model name of metafile, the weights will be loaded from metafile.
             Defaults to None.
-        device (str, optional): Device to run inference. If None, the available
+        device (str, optional): Device to run inference. For example,
+            it could be 'cuda' or 'cpu'. If None, the available
             device will be automatically used. Defaults to None.
         label_file (str, optional): label file for dataset.
         input_format (str): Input video format, Choices are 'video',
From a9d5b3f4373c1733fdc769b09216d220cb6e9b53 Mon Sep 17 00:00:00 2001
From: lilin
Date: Tue, 7 Feb 2023 10:49:34 +0800
Subject: [PATCH 09/12] rename inferencer

---
 mmaction/apis/inferencers/__init__.py                    | 2 +-
 .../{mmaction_inferencer.py => mmaction2_inferencer.py}  | 0
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename mmaction/apis/inferencers/{mmaction_inferencer.py => mmaction2_inferencer.py} (100%)

diff --git a/mmaction/apis/inferencers/__init__.py b/mmaction/apis/inferencers/__init__.py
index 16c93e0ab0..9f62b667cf 100644
--- a/mmaction/apis/inferencers/__init__.py
+++ b/mmaction/apis/inferencers/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .actionrecog_inferencer import ActionRecogInferencer
-from .mmaction_inferencer import MMAction2Inferencer
+from .mmaction2_inferencer import MMAction2Inferencer

 __all__ = ['ActionRecogInferencer', 'MMAction2Inferencer']
diff --git a/mmaction/apis/inferencers/mmaction_inferencer.py b/mmaction/apis/inferencers/mmaction2_inferencer.py
similarity index 100%
rename from mmaction/apis/inferencers/mmaction_inferencer.py
rename to mmaction/apis/inferencers/mmaction2_inferencer.py
From 5eb0359636440e9b36ab1a511f2f1ff71d3c2c2c Mon Sep 17 00:00:00 2001
From: lilin
Date: Tue, 7 Feb 2023 18:09:50 +0800
Subject: [PATCH 10/12] refine docstring

---
 mmaction/apis/inferencers/mmaction2_inferencer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mmaction/apis/inferencers/mmaction2_inferencer.py b/mmaction/apis/inferencers/mmaction2_inferencer.py
index 68fa8b199d..9aac7ed9b3 100644
--- a/mmaction/apis/inferencers/mmaction2_inferencer.py
+++ b/mmaction/apis/inferencers/mmaction2_inferencer.py
@@ -23,11 +23,11 @@ class MMAction2Inferencer(BaseInferencer):

     Args:
         rec (str, optional): Pretrained action recognition algorithm.
- It's the path to the config file or the model name - defined in metafile. For example, it could be "slowfast", - "slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb" or - "configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py". - Defaults to None. + It's the path to the config file or the model name defined in + metafile. For example, it could be + model alias, e.g. ``'slowfast'``, + config name, e.g. ``'slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb'``, # noqa E501 + or config path. Defaults to None. rec_weights (str, optional): Path to the custom checkpoint file of the selected rec model. If it is not specified and "rec" is a model name of metafile, the weights will be loaded from metafile. From 04a7e29ecb908e24bf16fb47de7dcb28d1cecd83 Mon Sep 17 00:00:00 2001 From: cir7 <33249023+cir7@users.noreply.github.com> Date: Tue, 7 Feb 2023 18:57:34 +0800 Subject: [PATCH 11/12] Update mmaction/apis/inferencers/mmaction2_inferencer.py Co-authored-by: Yining Li --- mmaction/apis/inferencers/mmaction2_inferencer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmaction/apis/inferencers/mmaction2_inferencer.py b/mmaction/apis/inferencers/mmaction2_inferencer.py index 9aac7ed9b3..9cf5f9ce95 100644 --- a/mmaction/apis/inferencers/mmaction2_inferencer.py +++ b/mmaction/apis/inferencers/mmaction2_inferencer.py @@ -24,10 +24,14 @@ class MMAction2Inferencer(BaseInferencer): Args: rec (str, optional): Pretrained action recognition algorithm. It's the path to the config file or the model name defined in - metafile. For example, it could be - model alias, e.g. ``'slowfast'``, - config name, e.g. ``'slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb'``, # noqa E501 - or config path. Defaults to None. + metafile. For example, it could be: + + - model alias, e.g. ``'slowfast'``, + - config name, e.g. ``'slowfast_r50_8xb8-8x8x1-256e_kinetics400 + -rgb'``, + - config path + + Defaults to ``None``. rec_weights (str, optional): Path to the custom checkpoint file of the selected rec model. If it is not specified and "rec" is a model name of metafile, the weights will be loaded from metafile. From 47659fb075f9a89fd1ab118309b70a7b9c7296f1 Mon Sep 17 00:00:00 2001 From: lilin Date: Tue, 7 Feb 2023 19:01:06 +0800 Subject: [PATCH 12/12] refine docstring --- mmaction/apis/inferencers/mmaction2_inferencer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmaction/apis/inferencers/mmaction2_inferencer.py b/mmaction/apis/inferencers/mmaction2_inferencer.py index 9cf5f9ce95..0c1b4590de 100644 --- a/mmaction/apis/inferencers/mmaction2_inferencer.py +++ b/mmaction/apis/inferencers/mmaction2_inferencer.py @@ -25,12 +25,12 @@ class MMAction2Inferencer(BaseInferencer): rec (str, optional): Pretrained action recognition algorithm. It's the path to the config file or the model name defined in metafile. For example, it could be: - + - model alias, e.g. ``'slowfast'``, - config name, e.g. ``'slowfast_r50_8xb8-8x8x1-256e_kinetics400 -rgb'``, - config path - + Defaults to ``None``. rec_weights (str, optional): Path to the custom checkpoint file of the selected rec model. If it is not specified and "rec" is a model
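
A minimal Python usage sketch of the inferencer interface introduced in this series, assuming the `tsn` alias, the demo video, and the Kinetics-400 label map used in the demo examples above:

```python
from mmaction.apis import MMAction2Inferencer

# Illustrative sketch: build the inferencer from the 'tsn' model alias;
# the weights are resolved from the metafile. `label_file` maps class
# indices to human-readable labels in printed and visualized results.
inferencer = MMAction2Inferencer(
    rec='tsn',
    label_file='tools/data/kinetics/label_map_k400.txt',
    device='cpu')

# Run inference on one video, save the visualization to `demo_out`
# and print the result dict to the console.
results = inferencer(
    'demo/demo.mp4', vid_out_dir='demo_out', print_result=True)

# 'predictions' holds one dict per input with 'rec_labels'/'rec_scores'.
print(results['predictions'][0]['rec_labels'])
```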