diff --git a/demo/demo.gif b/demo/demo.gif deleted file mode 100644 index 3f9953cdcf..0000000000 Binary files a/demo/demo.gif and /dev/null differ diff --git a/demo/demo.py b/demo/demo.py index 32805d08e0..5112d64707 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import os.path as osp +from operator import itemgetter from typing import Optional, Tuple import cv2 @@ -141,12 +142,16 @@ def main(): # Build the recognizer from a config file and checkpoint file/url model = init_recognizer(cfg, args.checkpoint, device=args.device) + result = inference_recognizer(model, args.video) - results = inference_recognizer(model, args.video) + pred_scores = result.pred_scores.item.tolist() + score_tuples = tuple(zip(range(len(pred_scores)), pred_scores)) + score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True) + top5_label = score_sorted[:5] labels = open(args.label).readlines() labels = [x.strip() for x in labels] - results = [(labels[k[0]], k[1]) for k in results] + results = [(labels[k[0]], k[1]) for k in top5_label] print('The top-5 labels with corresponding scores are:') for result in results: diff --git a/demo/demo_out.mp4 b/demo/demo_out.mp4 deleted file mode 100644 index f689f60f70..0000000000 Binary files a/demo/demo_out.mp4 and /dev/null differ diff --git a/demo/demo_skeleton.mp4 b/demo/demo_skeleton.mp4 new file mode 100644 index 0000000000..fa8a76b8ce Binary files /dev/null and b/demo/demo_skeleton.mp4 differ diff --git a/demo/demo_skeleton.py b/demo/demo_skeleton.py new file mode 100644 index 0000000000..33973b3930 --- /dev/null +++ b/demo/demo_skeleton.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import shutil + +import cv2 +import mmcv +import mmengine +import numpy as np +import torch +from mmengine import DictAction +from mmengine.utils import track_iter_progress + +from mmaction.apis import (detection_inference, inference_recognizer, + init_recognizer, pose_inference) +from mmaction.registry import VISUALIZERS +from mmaction.utils import frame_extract, register_all_modules + +try: + import moviepy.editor as mpy +except ImportError: + raise ImportError('Please install moviepy to enable output file') + +FONTFACE = cv2.FONT_HERSHEY_DUPLEX +FONTSCALE = 0.75 +FONTCOLOR = (255, 255, 255) # BGR, white +THICKNESS = 1 +LINETYPE = 1 + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMAction2 demo') + parser.add_argument('video', help='video file/url') + parser.add_argument('out_filename', help='output filename') + parser.add_argument( + '--config', + default=('configs/skeleton/posec3d/' + 'slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py'), + help='skeleton model config file path') + parser.add_argument( + '--checkpoint', + default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/' + 'slowonly_r50_u48_240e_ntu60_xsub_keypoint/' + 'slowonly_r50_u48_240e_ntu60_xsub_keypoint-f3adabf1.pth'), + help='skeleton model checkpoint file/url') + parser.add_argument( + '--det-config', + default='demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py', + help='human detection config file path (from mmdet)') + parser.add_argument( + '--det-checkpoint', + default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/' + 'faster_rcnn_r50_fpn_2x_coco/' + 'faster_rcnn_r50_fpn_2x_coco_' + 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'), + help='human detection checkpoint file/url') + parser.add_argument( + 
'--det-score-thr', + type=float, + default=0.9, + help='the threshold of human detection score') + parser.add_argument( + '--det-cat-id', + type=int, + default=0, + help='the category id for human detection') + parser.add_argument( + '--pose-config', + default='demo/skeleton_demo_cfg/' + 'td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py', + help='human pose estimation config file path (from mmpose)') + parser.add_argument( + '--pose-checkpoint', + default=('https://download.openmmlab.com/mmpose/top_down/hrnet/' + 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'), + help='human pose estimation checkpoint file/url') + parser.add_argument( + '--label-map', + default='tools/data/skeleton/label_map_ntu60.txt', + help='label map file') + parser.add_argument( + '--device', type=str, default='cuda:0', help='CPU/CUDA device option') + parser.add_argument( + '--short-side', + type=int, + default=480, + help='specify the short-side length of the image') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + args = parser.parse_args() + return args + + +def visualize(args, frames, data_samples, action_label): + pose_config = mmengine.Config.fromfile(args.pose_config) + visualizer = VISUALIZERS.build(pose_config.visualizer) + visualizer.set_dataset_meta(data_samples[0].dataset_meta) + + vis_frames = [] + print('Drawing skeleton for each frame') + for d, f in track_iter_progress(list(zip(data_samples, frames))): + f = mmcv.imconvert(f, 'bgr', 'rgb') + visualizer.add_datasample( + 'result', + f, + data_sample=d, + draw_gt=False, + draw_heatmap=False, + draw_bbox=True, + show=False, + wait_time=0, + out_file=None, + kpt_score_thr=0.3) + vis_frame = visualizer.get_image() + cv2.putText(vis_frame, action_label, (10, 30), FONTFACE, FONTSCALE, + FONTCOLOR, THICKNESS, LINETYPE) + vis_frames.append(vis_frame) + + vid = mpy.ImageSequenceClip(vis_frames, fps=24) + vid.write_videofile(args.out_filename, remove_temp=True) + + +def main(): + args = parse_args() + frame_paths, frames = frame_extract(args.video, args.short_side) + + num_frame = len(frame_paths) + h, w, _ = frames[0].shape + + # Get Human detection results. + det_results, _ = detection_inference(args.det_config, args.det_checkpoint, + frame_paths, args.det_score_thr, + args.det_cat_id, args.device) + torch.cuda.empty_cache() + + # Get Pose estimation results. 
+ pose_results, pose_data_samples = pose_inference(args.pose_config, + args.pose_checkpoint, + frame_paths, det_results, + args.device) + torch.cuda.empty_cache() + + fake_anno = dict( + frame_dir='', + label=-1, + img_shape=(h, w), + original_shape=(h, w), + start_index=0, + modality='Pose', + total_frames=num_frame) + num_person = max([len(x['keypoints']) for x in pose_results]) + + num_keypoint = 17 + keypoint = np.zeros((num_frame, num_person, num_keypoint, 2), + dtype=np.float16) + keypoint_score = np.zeros((num_frame, num_person, num_keypoint), + dtype=np.float16) + for i, poses in enumerate(pose_results): + keypoint[i] = poses['keypoints'] + keypoint_score[i] = poses['keypoint_scores'] + + fake_anno['keypoint'] = keypoint.transpose((1, 0, 2, 3)) + fake_anno['keypoint_score'] = keypoint_score.transpose((1, 0, 2)) + + register_all_modules() + config = mmengine.Config.fromfile(args.config) + config.merge_from_dict(args.cfg_options) + if 'data_preprocessor' in config.model: + config.model.data_preprocessor['mean'] = (w // 2, h // 2, .5) + config.model.data_preprocessor['std'] = (w, h, 1.) + + model = init_recognizer(config, args.checkpoint, args.device) + result = inference_recognizer(model, fake_anno) + + max_pred_index = result.pred_scores.item.argmax().item() + label_map = [x.strip() for x in open(args.label_map).readlines()] + action_label = label_map[max_pred_index] + + visualize(args, frames, pose_data_samples, action_label) + + tmp_frame_dir = osp.dirname(frame_paths[0]) + shutil.rmtree(tmp_frame_dir) + + +if __name__ == '__main__': + main() diff --git a/demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py b/demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py new file mode 100644 index 0000000000..c26a25e048 --- /dev/null +++ b/demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# model settings +model = dict( + type='FasterRCNN', + _scope_='mmdet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100))) + +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +file_client_args = dict(backend='disk') +test_pipeline = [ + dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args), + dict(type='mmdet.Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +test_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline)) diff --git a/demo/skeleton_demo_cfg/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py 
b/demo/skeleton_demo_cfg/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py new file mode 100644 index 0000000000..fb4256ff9b --- /dev/null +++ b/demo/skeleton_demo_cfg/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# codec settings +codec = dict( + type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2) + +# model settings +model = dict( + type='TopdownPoseEstimator', + _scope_='mmpose', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='HRNet', + in_channels=3, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(32, 64)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(32, 64, 128)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(32, 64, 128, 256))), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmpose' + '/pretrain_models/hrnet_w32-36af842e.pth'), + ), + head=dict( + type='HeatmapHead', + in_channels=32, + out_channels=17, + deconv_out_channels=None, + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec), + test_cfg=dict( + flip_test=True, + flip_mode='heatmap', + shift_heatmap=True, + )) + +# dataset settings +dataset_type = 'CocoDataset' +data_mode = 'topdown' +data_root = 'data/coco/' + +file_client_args = dict(backend='disk') +test_pipeline = [ + dict(type='mmpose.LoadImage', file_client_args=file_client_args), + dict(type='mmpose.GetBBoxCenterScale'), + dict(type='mmpose.TopdownAffine', input_size=codec['input_size']), + dict(type='mmpose.PackPoseInputs') +] +test_dataloader = dict( + batch_size=32, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_val2017.json', + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + )) + +# visualizer +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='mmpose.PoseLocalVisualizer', + vis_backends=vis_backends, + name='visualizer') diff --git a/docs/en/user_guides/3_inference.md b/docs/en/user_guides/3_inference.md index 934c539bf6..8a603b66ee 100644 --- a/docs/en/user_guides/3_inference.md +++ b/docs/en/user_guides/3_inference.md @@ -1,8 +1,8 @@ # Tutorial 3: Inference with existing models -## Inference with Action Recognition Models +## Inference with RGB-based Action Recognition Models -MMAction2 provides an inference script to predict the recognition result using a single video. In order to get predict results in range `[0, 1]`, make sure to set `model['cls_head']['average_clips'] = 'prob'` in config file. +MMAction2 provides an inference script to predict the recognition result using a single video. In order to get predict results in range `[0, 1]`, make sure to set `model.cls_head.average_clips = 'prob'` in config file. 
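The same setting can also be applied to a loaded config before building the recognizer. Below is a minimal sketch of that workflow in Python; it assumes you run it from the repository root and it reuses the TSN config and checkpoint URL from the examples that follow:

```python
# Minimal sketch: load a config, request probability outputs, run inference.
from mmengine import Config

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile(
    'configs/recognition/tsn/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb.py')
# Average the per-clip softmax scores so the predictions fall in [0, 1].
cfg.model.cls_head.average_clips = 'prob'

model = init_recognizer(
    cfg,
    'https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/'
    'tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb/'
    'tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth',
    device='cpu')
result = inference_recognizer(model, 'demo/demo.mp4')
pred_scores = result.pred_scores.item.tolist()  # 400 Kinetics-400 class probabilities
```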
```shell python demo/demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${VIDEO_FILE} ${LABEL_FILE} \ @@ -12,10 +12,10 @@ python demo/demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${VIDEO_FILE} ${LABEL_FILE Optional arguments: -- `DEVICE_TYPE`: Type of device to run the demo. Allowed values are cuda device like `cuda:0` or `cpu`. If not specified, it will be set to `cuda:0`. -- `FPS`: FPS value of the output video. If not specified, it will be set to 30. -- `FONT_SCALE`: Font scale of the label added in the video. If not specified, it will be 0.5. -- `FONT_COLOR`: Font color of the label added in the video. If not specified, it will be `white`. +- `DEVICE_TYPE`: Type of device to run the demo. Allowed values are cuda device like `'cuda:0'` or `'cpu'`. Defaults to `'cuda:0'`. +- `FPS`: FPS value of the output video. Defaults to 30. +- `FONT_SCALE`: Font scale of the label added in the video. Defaults to 0.5. +- `FONT_COLOR`: Font color of the label added in the video. Defaults to `'white'`. - `TARGET_RESOLUTION`: Resolution(desired_width, desired_height) for resizing the frames before output when using a video as input. If not specified, it will be None and the frames are resized by keeping the existing aspect ratio. - `OUT_FILE`: Path to the output file which can be a video format or gif format. If not specified, it will be set to `None` and does not generate the output file. @@ -41,3 +41,79 @@ or use checkpoint url from to directly load corresponding checkpoint, which will https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth \ demo/demo.mp4 tools/data/kinetics/label_map_k400.txt ``` + +3. Recognize a video file as input by using a TSN model and then generate an mp4 file. + + ```shell + # The demo.mp4 and label_map_k400.txt are both from Kinetics-400 + python demo/demo.py configs/recognition/tsn/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb.py \ + checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth \ + demo/demo.mp4 tools/data/kinetics/label_map_k400.txt --out-filename demo/demo_out.mp4 + ``` + +## Inference with Skeleton-based Action Recognition Models + +MMAction2 provides an inference script to predict the skeleton-based action recognition result using a single video. + +```shell +python demo/demo_skeleton.py ${VIDEO_FILE} ${OUT_FILENAME} \ + [--config ${SKELETON_BASED_ACTION_RECOGNITION_CONFIG_FILE}] \ + [--checkpoint ${SKELETON_BASED_ACTION_RECOGNITION_CHECKPOINT}] \ + [--det-config ${HUMAN_DETECTION_CONFIG_FILE}] \ + [--det-checkpoint ${HUMAN_DETECTION_CHECKPOINT}] \ + [--det-score-thr ${HUMAN_DETECTION_SCORE_THRESHOLD}] \ + [--det-cat-id ${HUMAN_DETECTION_CATEGORY_ID}] \ + [--pose-config ${HUMAN_POSE_ESTIMATION_CONFIG_FILE}] \ + [--pose-checkpoint ${HUMAN_POSE_ESTIMATION_CHECKPOINT}] \ + [--label-map ${LABEL_MAP}] \ + [--device ${DEVICE}] \ + [--short-side ${SHORT_SIDE}] +``` + +Optional arguments: + +- `SKELETON_BASED_ACTION_RECOGNITION_CONFIG_FILE`: The skeleton-based action recognition config file path. +- `SKELETON_BASED_ACTION_RECOGNITION_CHECKPOINT`: The skeleton-based action recognition checkpoint path or url. +- `HUMAN_DETECTION_CONFIG_FILE`: The human detection config file path. +- `HUMAN_DETECTION_CHECKPOINT`: The human detection checkpoint path or url. +- `HUMAN_DETECTION_SCORE_THRESHOLD`: The score threshold for human detection. Defaults to 0.9. +- `HUMAN_DETECTION_CATEGORY_ID`: The category id for human detection. Defaults to 0.
+- `HUMAN_POSE_ESTIMATION_CONFIG_FILE`: The human pose estimation config file path (trained on COCO-Keypoint). +- `HUMAN_POSE_ESTIMATION_CHECKPOINT`: The human pose estimation checkpoint path or url (trained on COCO-Keypoint). +- `LABEL_MAP`: The label map used. Defaults to `'tools/data/skeleton/label_map_ntu60.txt'`. +- `DEVICE`: Type of device to run the demo. Allowed values are cuda device like `'cuda:0'` or `'cpu'`. Defaults to `'cuda:0'`. +- `SHORT_SIDE`: The short side used for frame extraction. Defaults to 480. + +Examples: + +Assume that you are located at `$MMACTION2` . + +1. Use the Faster-RCNN as the human detector, HRNetw32 as the pose estimator, PoseC3D-NTURGB+D-60-XSub-Keypoint as the skeleton-based action recognizer. + +```shell +python demo/demo_skeleton.py demo/demo_skeleton.mp4 demo/demo_skeleton_out.mp4 \ + --config configs/skeleton/posec3d/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py \ + --checkpoint https://download.openmmlab.com/mmaction/skeleton/posec3d/slowonly_r50_u48_240e_ntu60_xsub_keypoint/slowonly_r50_u48_240e_ntu60_xsub_keypoint-f3adabf1.pth \ + --det-config demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py \ + --det-checkpoint http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth \ + --det-score-thr 0.9 \ + --det-cat-id 0 \ + --pose-config demo/skeleton_demo_cfg/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py \ + --pose-checkpoint https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth \ + --label-map tools/data/skeleton/label_map_ntu60.txt +``` + +2. Use the Faster-RCNN as the human detector, HRNetw32 as the pose estimator, STGCN-NTURGB+D-60-XSub-Keypoint as the skeleton-based action recognizer. + +```shell +python demo/demo_skeleton.py demo/demo_skeleton.mp4 demo/demo_skeleton_out.mp4 \ + --config configs/skeleton/stgcn/stgcn_1xb16-80e_ntu60-xsub-keypoint.py \ + --checkpoint https://download.openmmlab.com/mmaction/skeleton/stgcn/stgcn_80e_ntu60_xsub_keypoint/stgcn_80e_ntu60_xsub_keypoint-e7bb9653.pth \ + --det-config demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py \ + --det-checkpoint http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth \ + --det-score-thr 0.9 \ + --det-cat-id 0 \ + --pose-config demo/skeleton_demo_cfg/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py \ + --pose-checkpoint https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth \ + --label-map tools/data/skeleton/label_map_ntu60.txt +``` diff --git a/mmaction/apis/__init__.py b/mmaction/apis/__init__.py index 5b2c434d00..110cbe9464 100644 --- a/mmaction/apis/__init__.py +++ b/mmaction/apis/__init__.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .inference import inference_recognizer, init_recognizer +from .inference import (detection_inference, inference_recognizer, + init_recognizer, pose_inference) -__all__ = ['init_recognizer', 'inference_recognizer'] +__all__ = [ + 'init_recognizer', 'inference_recognizer', 'detection_inference', + 'pose_inference' +] diff --git a/mmaction/apis/inference.py b/mmaction/apis/inference.py index 5fb58b8a68..0e3ab2ddac 100644 --- a/mmaction/apis/inference.py +++ b/mmaction/apis/inference.py @@ -1,39 +1,41 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from operator import itemgetter -from typing import List, Optional, Tuple, Union +from pathlib import Path +from typing import List, Optional, Union import mmengine +import numpy as np import torch import torch.nn as nn from mmengine.dataset import Compose, pseudo_collate from mmengine.runner import load_checkpoint +from mmengine.utils import track_iter_progress from mmaction.registry import MODELS +from mmaction.structures import ActionDataSample -def init_recognizer(config: Union[str, mmengine.Config], +def init_recognizer(config: Union[str, Path, mmengine.Config], checkpoint: Optional[str] = None, device: Union[str, torch.device] = 'cuda:0') -> nn.Module: """Initialize a recognizer from config file. Args: - config (Union[str, mmengine.Config]): Config file path or the config - object. + config (Union[str, :obj:`Path`, :obj:`mmengine.Config`]): Config file + path, :obj:`Path` or the config object. checkpoint (str, optional): Checkpoint path/url. If set to None, the model will not load any weights. Defaults to None. - device (Union[str, ``torch.device``]): The desired device of returned - tensor. Defaults to ``cuda:0``. + device (Union[str, torch.device]): The desired device of returned + tensor. Defaults to ``'cuda:0'``. Returns: nn.Module: The constructed recognizer. """ - if isinstance(config, str): + if isinstance(config, (str, Path)): config = mmengine.Config.fromfile(config) elif not isinstance(config, mmengine.Config): raise TypeError('config must be a filename or Config object, ' f'but got {type(config)}') - # pretrained model is unnecessary since we directly load checkpoint later config.model.backbone.pretrained = None model = MODELS.build(config.model) @@ -46,39 +48,145 @@ def init_recognizer(config: Union[str, mmengine.Config], def inference_recognizer(model: nn.Module, - video: str) -> List[Tuple[int, float]]: + video: Union[str, dict], + test_pipeline: Optional[Compose] = None + ) -> ActionDataSample: """Inference a video with the recognizer. Args: model (nn.Module): The loaded recognizer. - video (str): The video file path. + video (Union[str, dict]): The video file path or the results + dictionary (the input of pipeline). + test_pipeline (:obj:`Compose`, optional): The test pipeline. + If not specified, the test pipeline in the config will be + used. Defaults to None. Returns: - List[Tuple(int, float)]: Top-5 recognition result dict. + :obj:`ActionDataSample`: The inference results. Specifically, the + predicted scores are saved at ``result.pred_scores.item``. 
""" - cfg = model.cfg - # Build the data pipeline - val_pipeline_cfg = cfg.val_dataloader.dataset.pipeline - if 'Init' not in val_pipeline_cfg[0]['type']: - val_pipeline_cfg = [dict(type='OpenCVInit')] + val_pipeline_cfg + if test_pipeline is None: + cfg = model.cfg + test_pipeline_cfg = cfg.test_pipeline + test_pipeline = Compose(test_pipeline_cfg) + + input_flag = None + if isinstance(video, dict): + input_flag = 'dict' + elif isinstance(video, str): + input_flag = 'video' else: - val_pipeline_cfg[0] = dict(type='OpenCVInit') - for i in range(len(val_pipeline_cfg)): - if 'Decode' in val_pipeline_cfg[i]['type']: - val_pipeline_cfg[i] = dict(type='OpenCVDecode') - val_pipeline = Compose(val_pipeline_cfg) - - # Prepare & process inputs - data = dict(filename=video, label=-1, start_index=0, modality='RGB') - data = val_pipeline(data) + raise RuntimeError(f'The type of argument `video` is not supported: ' + f'{type(video)}') + + if input_flag == 'dict': + data = video + if input_flag == 'video': + data = dict(filename=video, label=-1, start_index=0, modality='RGB') + + data = test_pipeline(data) data = pseudo_collate([data]) # Forward the model with torch.no_grad(): - pred_scores = model.val_step(data)[0].pred_scores.item.tolist() - score_tuples = tuple(zip(range(len(pred_scores)), pred_scores)) - score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True) - top5_label = score_sorted[:5] + result = model.test_step(data)[0] + + return result + + +def detection_inference(det_config: Union[str, Path, mmengine.Config], + det_checkpoint: str, + frame_paths: List[str], + det_score_thr: float = 0.9, + det_cat_id: int = 0, + device: Union[str, torch.device] = 'cuda:0') -> tuple: + """Detect human boxes given frame paths. + + Args: + det_config (Union[str, :obj:`Path`, :obj:`mmengine.Config`]): Config + file path, :obj:`Path` or the config object. + det_checkpoint: Checkpoint path/url. + frame_paths (List[str]): The paths of frames to do detection inference. + det_score_thr (float): The threshold of human detection score. + Defaults to 0.9. + det_cat_id (int): The category id for human detection. Defaults to 0. + device (Union[str, torch.device]): The desired device of returned + tensor. Defaults to ``'cuda:0'``. - return top5_label + Returns: + List[np.ndarray]: List of detected human boxes. + List[:obj:`DetDataSample`]: List of data samples, generally used + to visualize data. + """ + try: + from mmdet.apis import inference_detector, init_detector + from mmdet.structures import DetDataSample + except (ImportError, ModuleNotFoundError): + raise ImportError('Failed to import `inference_detector` and ' + '`init_detector` from `mmdet.apis`. These apis are ' + 'required in this inference api! ') + + model = init_detector(det_config, det_checkpoint, device) + + results = [] + data_samples = [] + print('Performing Human Detection for each frame') + for frame_path in track_iter_progress(frame_paths): + det_data_sample: DetDataSample = inference_detector(model, frame_path) + pred_instance = det_data_sample.pred_instances.cpu().numpy() + bboxes = pred_instance.bboxes + # We only keep human detection bboxs with score larger + # than `det_score_thr` and category id equal to `det_cat_id`. 
+ bboxes = bboxes[np.logical_and(pred_instance.labels == det_cat_id, + pred_instance.scores > det_score_thr)] + results.append(bboxes) + data_samples.append(det_data_sample) + + return results, data_samples + + +def pose_inference(pose_config: Union[str, Path, mmengine.Config], + pose_checkpoint: str, + frame_paths: List[str], + det_results: List[np.ndarray], + device: Union[str, torch.device] = 'cuda:0') -> tuple: + """Perform Top-Down pose estimation. + + Args: + pose_config (Union[str, :obj:`Path`, :obj:`mmengine.Config`]): Config + file path, :obj:`Path` or the config object. + pose_checkpoint: Checkpoint path/url. + frame_paths (List[str]): The paths of frames to do pose inference. + det_results (List[np.ndarray]): List of detected human boxes. + device (Union[str, torch.device]): The desired device of returned + tensor. Defaults to ``'cuda:0'``. + + Returns: + List[List[Dict[str, np.ndarray]]]: List of pose estimation results. + List[:obj:`PoseDataSample`]: List of data samples, generally used + to visualize data. + """ + try: + from mmpose.apis import inference_topdown, init_model + from mmpose.structures import PoseDataSample, merge_data_samples + except (ImportError, ModuleNotFoundError): + raise ImportError('Failed to import `inference_topdown` and ' + '`init_model` from `mmpose.apis`. These apis ' + 'are required in this inference api! ') + + model = init_model(pose_config, pose_checkpoint, device) + + results = [] + data_samples = [] + print('Performing Human Pose Estimation for each frame') + for f, d in track_iter_progress(list(zip(frame_paths, det_results))): + pose_data_samples: List[PoseDataSample] \ + = inference_topdown(model, f, d, bbox_format='xyxy') + pose_data_sample = merge_data_samples(pose_data_samples) + pose_data_sample.dataset_meta = model.dataset_meta + poses = pose_data_sample.pred_instances.to_dict() + results.append(poses) + data_samples.append(pose_data_sample) + + return results, data_samples diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py index 3274eefa16..1608391d45 100644 --- a/mmaction/datasets/transforms/formatting.py +++ b/mmaction/datasets/transforms/formatting.py @@ -357,23 +357,33 @@ def __repr__(self): @TRANSFORMS.register_module() class FormatGCNInput(BaseTransform): - """Format final skeleton shape to the given input_format. + """Format final skeleton shape to the given ``input_format``. - Required keys are "keypoint" and "keypoint_score"(optional), - added or modified keys are "keypoint" and "input_shape". + Required Keys: + + - keypoint + - keypoint_score (optional) + + Modified Key: + + - keypoint + + Added Key: + + - input_shape Args: input_format (str): Define the final skeleton format. """ - def __init__(self, input_format, num_person=2): + def __init__(self, input_format: str, num_person: int = 2) -> None: self.input_format = input_format if self.input_format not in ['NCTVM']: raise ValueError( f'The input format {self.input_format} is invalid.') self.num_person = num_person - def transform(self, results): + def transform(self, results: dict) -> dict: """Performs the FormatShape formatting. Args: diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py index 1518cbec89..427271d6da 100644 --- a/mmaction/datasets/transforms/loading.py +++ b/mmaction/datasets/transforms/loading.py @@ -868,12 +868,12 @@ def __repr__(self): class OpenCVInit(BaseTransform): """Using OpenCV to initialize the video_reader. 
- Required keys are ``filename``, added or modified keys are ``new_path``, - ``video_reader`` and ``total_frames``. + Required keys are ``'filename'``, added or modified keys are ` + `'new_path'``, ``'video_reader'`` and ``'total_frames'``. Args: io_backend (str): io backend where frames are store. - Defaults to ``disk``. + Defaults to ``'disk'``. """ def __init__(self, io_backend: str = 'disk', **kwargs) -> None: @@ -928,8 +928,9 @@ def __repr__(self): class OpenCVDecode(BaseTransform): """Using OpenCV to decode the video. - Required keys are ``video_reader``, ``filename`` and ``frame_inds``, added - or modified keys are ``imgs``, ``img_shape`` and ``original_shape``. + Required keys are ``'video_reader'``, ``'filename'`` and ``'frame_inds'``, + added or modified keys are ``'imgs'``, ``'img_shape'`` and + ``'original_shape'``. """ def transform(self, results: dict) -> dict: @@ -970,23 +971,37 @@ def transform(self, results: dict) -> dict: class RawFrameDecode(BaseTransform): """Load and decode frames with given indices. - Required keys are "frame_dir", "filename_tmpl" and "frame_inds", - added or modified keys are "imgs", "img_shape" and "original_shape". + Required Keys: + + - frame_dir + - filename_tmpl + - frame_inds + - modality + - offset (optional) + + Added Keys: + + - img + - img_shape + - original_shape Args: - io_backend (str): IO backend where frames are stored. Default: 'disk'. + io_backend (str): IO backend where frames are stored. + Defaults to ``'disk'``. decoding_backend (str): Backend used for image decoding. - Default: 'cv2'. - kwargs (dict, optional): Arguments for FileClient. + Defaults to ``'cv2'``. """ - def __init__(self, io_backend='disk', decoding_backend='cv2', **kwargs): + def __init__(self, + io_backend: str = 'disk', + decoding_backend: str = 'cv2', + **kwargs) -> None: self.io_backend = io_backend self.decoding_backend = decoding_backend self.kwargs = kwargs self.file_client = None - def transform(self, results): + def transform(self, results: dict) -> dict: """Perform the ``RawFrameDecode`` to pick frames given indices. Args: diff --git a/mmaction/utils/__init__.py b/mmaction/utils/__init__.py index 5c7b0abd29..02ac88b015 100644 --- a/mmaction/utils/__init__.py +++ b/mmaction/utils/__init__.py @@ -1,16 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .collect_env import collect_env from .gradcam_utils import GradCAM -from .misc import get_random_string, get_shm_dir, get_thread_id +from .misc import frame_extract, get_random_string, get_shm_dir, get_thread_id from .setup_env import register_all_modules -from .typing import (ConfigType, ForwardResults, InstanceList, LabelList, - MultiConfig, OptConfigType, OptInstanceList, OptLabelList, - OptMultiConfig, OptSampleList, SampleList, SamplingResult) +from .typing import * # noqa: F401,F403 __all__ = [ 'collect_env', 'get_random_string', 'get_thread_id', 'get_shm_dir', - 'GradCAM', 'register_all_modules', 'ConfigType', 'OptConfigType', - 'MultiConfig', 'OptMultiConfig', 'InstanceList', 'OptInstanceList', - 'SampleList', 'OptSampleList', 'ForwardResults', 'LabelList', - 'OptLabelList', 'SamplingResult' + 'frame_extract', 'GradCAM', 'register_all_modules' ] diff --git a/mmaction/utils/misc.py b/mmaction/utils/misc.py index e9b5f37feb..374d62dcf3 100644 --- a/mmaction/utils/misc.py +++ b/mmaction/utils/misc.py @@ -1,8 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import ctypes +import os +import os.path as osp import random import string +import cv2 +import mmcv +import numpy as np + def get_random_string(length: int = 15): """Get random string with letters and digits. @@ -25,3 +31,39 @@ def get_thread_id(): def get_shm_dir(): """Get shm dir for temporary usage.""" return '/dev/shm' + + +def frame_extract(video_path: str, short_side: int): + """Extract frames given video_path. + + Args: + video_path (str): The video path. + short_side (int): The short-side of the image. + """ + # Load the video, extract frames into ./tmp/video_name + target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0])) + os.makedirs(target_dir, exist_ok=True) + # Should be able to handle videos up to several hours + frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg') + vid = cv2.VideoCapture(video_path) + frames = [] + frame_paths = [] + flag, frame = vid.read() + cnt = 0 + new_h, new_w = None, None + while flag: + if new_h is None: + h, w, _ = frame.shape + new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.Inf)) + + frame = mmcv.imresize(frame, (new_w, new_h)) + + frames.append(frame) + frame_path = frame_tmpl.format(cnt + 1) + frame_paths.append(frame_path) + + cv2.imwrite(frame_path, frame) + cnt += 1 + flag, frame = vid.read() + + return frame_paths, frames diff --git a/requirements/tests.txt b/requirements/tests.txt index 43f7b16d22..32d3e95cc6 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,6 +2,7 @@ coverage flake8 interrogate isort==4.3.21 +parameterized pytest pytest-runner xdoctest >= 0.10.0 diff --git a/tests/apis/test_inference.py b/tests/apis/test_inference.py index 33f0038b22..883132aec5 100644 --- a/tests/apis/test_inference.py +++ b/tests/apis/test_inference.py @@ -1,54 +1,69 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import pytest +import os.path as osp +from pathlib import Path +from unittest import TestCase + import torch -import torch.nn as nn +from parameterized import parameterized from mmaction.apis import inference_recognizer, init_recognizer +from mmaction.structures import ActionDataSample from mmaction.utils import register_all_modules -video_config_file = 'configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py' # noqa: E501 -video_path = 'demo/demo.mp4' -register_all_modules() +class TestInference(TestCase): + + def setUp(self): + register_all_modules() + @parameterized.expand([(('configs/recognition/tsn/' + 'tsn_imagenet-pretrained-r50_8xb32-' + '1x1x3-100e_kinetics400-rgb.py'), ('cpu', 'cuda')) + ]) + def test_init_recognizer(self, config, devices): + project_dir = osp.abspath(osp.dirname(osp.dirname(__file__))) + project_dir = osp.join(project_dir, '..') + config_file = osp.join(project_dir, config) -def test_init_recognizer(): - with pytest.raises(TypeError): - # config must be a filename or Config object - init_recognizer(dict(config_file=None)) + for device in devices: + if device == 'cuda' and not torch.cuda.is_available(): + # Skip the test if cuda is required but unavailable + continue - if torch.cuda.is_available(): - device = 'cuda:0' - else: - device = 'cpu' + # test `init_recognizer` with str path + _ = init_recognizer(config_file, device=device) - model = init_recognizer(video_config_file, None, device) + # test `init_recognizer` with :obj:`Path` + _ = init_recognizer(Path(config_file), device=device) - isinstance(model, nn.Module) - if torch.cuda.is_available(): - assert next(model.parameters()).is_cuda is True - else: - assert next(model.parameters()).is_cuda is False - assert model.cfg.model.backbone.pretrained is None + # test `init_recognizer` with undesirable type + with self.assertRaisesRegex( + TypeError, 'config must be a filename or Config object'): + config_list = [config_file] + _ = init_recognizer(config_list) + @parameterized.expand([(('configs/recognition/tsn/' + 'tsn_imagenet-pretrained-r50_8xb32-' + '1x1x3-100e_kinetics400-rgb.py'), 'demo/demo.mp4', + ('cpu', 'cuda'))]) + def test_inference_recognizer(self, config, video_path, devices): + project_dir = osp.abspath(osp.dirname(osp.dirname(__file__))) + project_dir = osp.join(project_dir, '..') + config_file = osp.join(project_dir, config) + video_path = osp.join(project_dir, video_path) -def test_video_inference_recognizer(): - if torch.cuda.is_available(): - device = 'cuda:0' - else: - device = 'cpu' - model = init_recognizer(video_config_file, None, device) + for device in devices: + if device == 'cuda' and not torch.cuda.is_available(): + # Skip the test if cuda is required but unavailable + continue + model = init_recognizer(config_file, device=device) - with pytest.raises(FileNotFoundError): - # video path doesn't exist - inference_recognizer(model, 'missing.mp4') + for ops in model.cfg.test_pipeline: + if ops['type'] in ('TenCrop', 'ThreeCrop'): + # Use CenterCrop to reduce memory in order to pass CI + ops['type'] = 'CenterCrop' - for ops in model.cfg.test_pipeline: - if ops['type'] in ('TenCrop', 'ThreeCrop'): - # Use CenterCrop to reduce memory in order to pass CI - ops['type'] = 'CenterCrop' + result = inference_recognizer(model, video_path) - top5_label = inference_recognizer(model, video_path) - scores = [item[1] for item in top5_label] - assert len(top5_label) == 5 - assert scores == sorted(scores, reverse=True) + self.assertIsInstance(result, ActionDataSample) + 
self.assertEqual(result.pred_scores.item.shape, (400, )) diff --git a/tools/data/skeleton/label_map_ntu120.txt b/tools/data/skeleton/label_map_ntu60.txt similarity index 50% rename from tools/data/skeleton/label_map_ntu120.txt rename to tools/data/skeleton/label_map_ntu60.txt index 69826dfebf..d41ce05462 100644 --- a/tools/data/skeleton/label_map_ntu120.txt +++ b/tools/data/skeleton/label_map_ntu60.txt @@ -58,63 +58,3 @@ touch other person's pocket handshaking walking towards each other walking apart from each other -put on headphone -take off headphone -shoot at the basket -bounce ball -tennis bat swing -juggling table tennis balls -hush (quite) -flick hair -thumb up -thumb down -make ok sign -make victory sign -staple book -counting money -cutting nails -cutting paper (using scissors) -snapping fingers -open bottle -sniff (smell) -squat down -toss a coin -fold paper -ball up paper -play magic cube -apply cream on face -apply cream on hand back -put on bag -take off bag -put something into a bag -take something out of a bag -open a box -move heavy objects -shake fist -throw up cap/hat -hands up (both hands) -cross arms -arm circles -arm swings -running on the spot -butt kicks (kick backward) -cross toe touch -side kick -yawn -stretch oneself -blow nose -hit other person with something -wield knife towards other person -knock over other person (hit with body) -grab other person’s stuff -shoot at other person with a gun -step on foot -high-five -cheers and drink -carry something with other person -take a photo of other person -follow other person -whisper in other person’s ear -exchange things with other person -support somebody with hand -finger-guessing game (playing rock-paper-scissors)