Skip to content

Commit

Permalink
Merge c5d6ea1 into a15b934
Browse files Browse the repository at this point in the history
  • Loading branch information
Dai-Wenxun authored Sep 15, 2022
2 parents a15b934 + c5d6ea1 commit 881ce57
Show file tree
Hide file tree
Showing 16 changed files with 792 additions and 163 deletions.
9 changes: 7 additions & 2 deletions demo/demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from operator import itemgetter
from typing import Optional, Tuple

import cv2
Expand Down Expand Up @@ -141,12 +142,16 @@ def main():

# Build the recognizer from a config file and checkpoint file/url
model = init_recognizer(cfg, args.checkpoint, device=args.device)
result = inference_recognizer(model, args.video)

results = inference_recognizer(model, args.video)
pred_scores = result.pred_scores.item.tolist()
score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
top5_label = score_sorted[:5]

labels = open(args.label).readlines()
labels = [x.strip() for x in labels]
results = [(labels[k[0]], k[1]) for k in results]
results = [(labels[k[0]], k[1]) for k in top5_label]

print('The top-5 labels with corresponding scores are:')
for result in results:
Expand Down
Binary file added demo/demo_skeleton.avi
Binary file not shown.
192 changes: 192 additions & 0 deletions demo/demo_skeleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import shutil

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmengine import DictAction
from mmengine.utils import track_iter_progress

from mmaction.apis import (detection_inference, inference_recognizer,
init_recognizer, pose_inference)
from mmaction.registry import VISUALIZERS
from mmaction.utils import frame_extract, register_all_modules

try:
    import moviepy.editor as mpy
except ImportError as err:
    # moviepy is used below to assemble and write the output video, so the
    # script cannot proceed without it. Chain the original error explicitly
    # so the underlying import failure stays visible in the traceback.
    raise ImportError(
        'Please install moviepy to enable output file') from err

# Text-overlay settings passed to ``cv2.putText`` when the predicted action
# label is written onto each output frame.
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1


def parse_args():
    """Build the command-line interface of the skeleton demo and parse it.

    Two positional arguments are required (the input video and the output
    filename); every other option has a default pointing at a config file,
    checkpoint URL, label map, device or numeric threshold.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    # Long default locations are named up front so the option table below
    # stays readable.
    skeleton_config = ('configs/skeleton/posec3d/'
                       'slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py')
    skeleton_checkpoint = (
        'https://download.openmmlab.com/mmaction/skeleton/posec3d/'
        'slowonly_r50_u48_240e_ntu60_xsub_keypoint/'
        'slowonly_r50_u48_240e_ntu60_xsub_keypoint-f3adabf1.pth')
    det_config = 'demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py'
    det_checkpoint = (
        'http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
        'faster_rcnn_r50_fpn_2x_coco/'
        'faster_rcnn_r50_fpn_2x_coco_'
        'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth')
    pose_config = ('demo/skeleton_demo_cfg/'
                   'td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py')
    pose_checkpoint = (
        'https://download.openmmlab.com/mmpose/top_down/hrnet/'
        'hrnet_w32_coco_256x192-c78dce93_20200708.pth')

    parser = argparse.ArgumentParser(description='MMAction2 demo')
    add = parser.add_argument

    # Positional arguments.
    add('video', help='video file/url')
    add('out_filename', help='output filename')

    # Skeleton-based action recognizer.
    add('--config',
        default=skeleton_config,
        help='skeleton model config file path')
    add('--checkpoint',
        default=skeleton_checkpoint,
        help='skeleton model checkpoint file/url')

    # Human detector.
    add('--det-config',
        default=det_config,
        help='human detection config file path (from mmdet)')
    add('--det-checkpoint',
        default=det_checkpoint,
        help='human detection checkpoint file/url')
    add('--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    add('--det-cat-id',
        type=int,
        default=0,
        help='the category id for human detection')

    # Pose estimator.
    add('--pose-config',
        default=pose_config,
        help='human pose estimation config file path (from mmpose)')
    add('--pose-checkpoint',
        default=pose_checkpoint,
        help='human pose estimation checkpoint file/url')

    # Miscellaneous options.
    add('--label-map',
        default='tools/data/skeleton/label_map_ntu60.txt',
        help='label map file')
    add('--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    add('--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    add('--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")

    return parser.parse_args()


def visualize(args, frames, data_samples, action_label):
    """Render the pose results and the action label, then write a video.

    Args:
        args: Parsed command-line namespace. ``args.pose_config`` supplies
            the visualizer settings and ``args.out_filename`` is the path
            the video is written to.
        frames: Frames of the input video. Each one is converted from BGR
            to RGB before it is handed to the visualizer.
        data_samples: Per-frame data samples, paired with ``frames`` in
            order. The first sample's ``dataset_meta`` configures the
            visualizer.
        action_label (str): Text drawn at pixel position (10, 30) of every
            rendered frame.
    """
    visualizer = VISUALIZERS.build(
        mmengine.Config.fromfile(args.pose_config).visualizer)
    visualizer.set_dataset_meta(data_samples[0].dataset_meta)

    rendered = []
    print('Drawing skeleton for each frame')
    # The pairs are materialised into a list before being handed to the
    # progress tracker, exactly as the sample/frame zip was before.
    sample_frame_pairs = list(zip(data_samples, frames))
    for sample, frame in track_iter_progress(sample_frame_pairs):
        rgb_frame = mmcv.imconvert(frame, 'bgr', 'rgb')
        visualizer.add_datasample(
            'result',
            rgb_frame,
            data_sample=sample,
            draw_gt=False,
            draw_heatmap=False,
            draw_bbox=False,
            show=False,
            wait_time=0,
            out_file=None,
            kpt_score_thr=0.3)
        canvas = visualizer.get_image()
        # putText draws in place on the image returned by the visualizer.
        cv2.putText(canvas, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)
        rendered.append(canvas)

    clip = mpy.ImageSequenceClip(rendered, fps=24)
    clip.write_videofile(args.out_filename, remove_temp=True)


def _stack_keypoints(pose_results, num_frame, num_keypoint=17):
    """Pack per-frame pose results into dense, zero-padded float16 arrays.

    Args:
        pose_results: One mapping per frame with ``'keypoints'`` and
            ``'keypoint_scores'`` entries, each holding one item per
            detected person.
        num_frame (int): Number of frames (size of the first axis).
        num_keypoint (int): Keypoints per person. Defaults to 17.

    Returns:
        tuple: ``(keypoint, keypoint_score)`` with shapes
        ``(num_frame, num_person, num_keypoint, 2)`` and
        ``(num_frame, num_person, num_keypoint)``, where ``num_person`` is
        the largest person count found in any frame.
    """
    num_person = max(len(x['keypoints']) for x in pose_results)
    keypoint = np.zeros((num_frame, num_person, num_keypoint, 2),
                        dtype=np.float16)
    keypoint_score = np.zeros((num_frame, num_person, num_keypoint),
                              dtype=np.float16)
    for i, poses in enumerate(pose_results):
        count = len(poses['keypoints'])
        if count == 0:
            # Nobody detected in this frame: keep the zero padding.
            continue
        # Fill only the slots that exist in this frame. Assigning the whole
        # frame row would fail when the frame has fewer persons than
        # ``num_person`` and would silently duplicate a single person into
        # every slot through broadcasting.
        keypoint[i, :count] = poses['keypoints']
        keypoint_score[i, :count] = poses['keypoint_scores']
    return keypoint, keypoint_score


def main():
    """Run the skeleton-based action recognition demo end to end.

    Steps: extract frames from the input video, detect humans, estimate
    their poses, feed the stacked keypoints to the action recognizer, draw
    the top-scoring action label on the pose visualisation and write the
    output video. The directory holding the extracted frame files is removed
    on exit, including when one of the steps raises.
    """
    args = parse_args()
    frame_paths, frames = frame_extract(args.video, args.short_side)
    tmp_frame_dir = osp.dirname(frame_paths[0])

    try:
        num_frame = len(frame_paths)
        h, w, _ = frames[0].shape

        # Get Human detection results.
        det_results, _ = detection_inference(args.det_config,
                                             args.det_checkpoint,
                                             frame_paths, args.det_score_thr,
                                             args.det_cat_id, args.device)
        torch.cuda.empty_cache()

        # Get Pose estimation results.
        pose_results, pose_data_samples = pose_inference(
            args.pose_config, args.pose_checkpoint, frame_paths, det_results,
            args.device)
        torch.cuda.empty_cache()

        # Minimal annotation dict consumed by ``inference_recognizer``.
        fake_anno = dict(
            frame_dir='',
            label=-1,
            img_shape=(h, w),
            original_shape=(h, w),
            start_index=0,
            modality='Pose',
            total_frames=num_frame)

        keypoint, keypoint_score = _stack_keypoints(pose_results, num_frame)

        # Reorder from (frame, person, ...) to (person, frame, ...).
        fake_anno['keypoint'] = keypoint.transpose((1, 0, 2, 3))
        fake_anno['keypoint_score'] = keypoint_score.transpose((1, 0, 2))

        register_all_modules()
        config = mmengine.Config.fromfile(args.config)
        config.merge_from_dict(args.cfg_options)
        if 'data_preprocessor' in config.model:
            # NOTE(review): presumably normalises x/y keypoint coordinates by
            # the frame size -- confirm against the recognizer's
            # data preprocessor.
            config.model.data_preprocessor['mean'] = (w // 2, h // 2, .5)
            config.model.data_preprocessor['std'] = (w, h, 1.)

        model = init_recognizer(config, args.checkpoint, args.device)
        result = inference_recognizer(model, fake_anno)

        max_pred_index = result.pred_scores.item.argmax().item()
        # Read the label map through a context manager so the file handle is
        # closed deterministically.
        with open(args.label_map) as label_file:
            label_map = [line.strip() for line in label_file]
        action_label = label_map[max_pred_index]

        visualize(args, frames, pose_data_samples, action_label)
    finally:
        # Always remove the extracted frames, even if a step above failed.
        shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()
Binary file added demo/demo_skeleton_out.mp4
Binary file not shown.
139 changes: 139 additions & 0 deletions demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
# model settings
#
# ``demo/demo_skeleton.py`` uses this file as the default value of
# ``--det-config`` (described there as the human detection config from
# mmdet). Every key/value below is consumed by the framework that loads the
# config, so values must be changed with care.
model = dict(
    type='FasterRCNN',
    # NOTE(review): presumably makes the ``type`` names in this dict resolve
    # against mmdet's registries -- confirm in the mmengine registry docs.
    _scope_='mmdet',
    # Input normalisation: per-channel mean/std, with ``bgr_to_rgb`` enabled
    # and ``pad_size_divisor`` set to 32.
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    # ``in_channels`` has one entry per backbone output index listed above.
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    # NOTE(review): ``train_cfg`` is presumably unused when this file is only
    # loaded for inference -- confirm before removing it.
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    # Test-time settings for the ``rpn`` and ``rcnn`` stages.
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)))

# dataset settings
#
# Dataset type, data root and the file-client backend ('disk') shared by the
# pipeline and dataloader definitions below.
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
file_client_args = dict(backend='disk')
# Test-time pipeline: load the image from file, resize it with
# ``keep_ratio=True`` to scale (1333, 800), then pack the inputs together
# with the listed meta keys. Transform names carry an explicit ``mmdet.``
# prefix.
test_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    dict(type='mmdet.Resize', scale=(1333, 800), keep_ratio=True),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
# Test dataloader over the val2017 annotations/images under ``data_root``,
# one image per batch and without shuffling; it reuses ``test_pipeline``.
test_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))
Loading

0 comments on commit 881ce57

Please sign in to comment.