[Feature] Sparsedemo (open-mmlab#468)
* resolve comments

* update changelog

* + sparse demo

* fix bug

* add doc for the new arg

* update changelog

* reorg code in long demo

* add a comment

* resolve comments

Co-authored-by: Jintao Lin <[email protected]>
kennymckormick and dreamerlin authored Dec 24, 2020
1 parent 3c508b0 commit e9a70b9
Showing 3 changed files with 45 additions and 27 deletions.
1 change: 1 addition & 0 deletions demo/README.md
@@ -225,6 +225,7 @@ Optional arguments:
 - `INPUT_STEP`: Input step for sampling frames, which can help to get more sparse input. If not specified, it will be set to 1.
 - `DEVICE_TYPE`: Type of device to run the demo. Allowed values are CUDA devices like `cuda:0`, or `cpu`. If not specified, it will be set to `cuda:0`.
 - `THRESHOLD`: Threshold of prediction score for action recognition. Only labels with a score higher than the threshold will be shown. If not specified, it will be set to 0.01.
+- `STRIDE`: By default, the demo generates a prediction for every single frame, which can be slow. To speed it up, set `STRIDE`: the demo will then generate a prediction every `STRIDE x sample_length` frames, where `sample_length` is the size of the temporal window from which frames are sampled and equals `clip_len x frame_interval`. For example, if `sample_length` is 64 frames and `STRIDE` is 0.5, predictions are generated every 32 frames. If set to 0, a prediction is generated for each frame. The recommended range for `STRIDE` is (0, 1]; values above 1 also work, but the resulting predictions will be too sparse. Default: 0. (A worked example of this arithmetic follows the diff.)
 Examples:
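The stride arithmetic documented in the new `STRIDE` entry can be checked in a few lines of Python. A minimal sketch; the `clip_len` and `frame_interval` values here are illustrative, not config defaults:

    # Worked example of the STRIDE arithmetic (illustrative values).
    clip_len, frame_interval = 32, 2           # assumed sampling settings
    sample_length = clip_len * frame_interval  # temporal window: 64 frames

    stride = 0.5
    pred_stride = int(sample_length * stride)  # one prediction every 32 frames
    assert pred_stride == 32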
70 changes: 43 additions & 27 deletions demo/long_video_demo.py
@@ -30,7 +30,7 @@ def parse_args():
         description='MMAction2 predict different labels in a long video demo')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file/url')
-    parser.add_argument('video', help='video file/url')
+    parser.add_argument('video_path', help='video file/url')
     parser.add_argument('label', help='label file')
     parser.add_argument('out_file', help='output filename')
     parser.add_argument(
@@ -45,12 +45,24 @@ def parse_args():
         type=float,
         default=0.01,
         help='recognition score threshold')
+    parser.add_argument(
+        '--stride',
+        type=float,
+        default=0,
+        help=('the prediction stride equals stride * sample_length '
+              '(sample_length indicates the size of the temporal window from '
+              'which you sample frames, which equals '
+              'clip_len x frame_interval); if set to 0, the '
+              'prediction stride is 1'))
     args = parser.parse_args()
     return args
 
 
-def show_results():
-    cap = cv2.VideoCapture(video_path)
+def show_results(model, data, label, args):
+    frame_queue = deque(maxlen=args.sample_length)
+    result_queue = deque(maxlen=1)
+
+    cap = cv2.VideoCapture(args.video_path)
     num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -62,7 +74,7 @@ def show_results():
     frame_size = (frame_width, frame_height)
 
     ind = 0
-    video_writer = cv2.VideoWriter(out_file, fourcc, fps, frame_size)
+    video_writer = cv2.VideoWriter(args.out_file, fourcc, fps, frame_size)
     prog_bar = mmcv.ProgressBar(num_frames)
     backup_frames = []
 
@@ -74,19 +86,19 @@ def show_results():
             # drop it when encountering None
             continue
         backup_frames.append(np.array(frame)[:, :, ::-1])
-        if ind == sample_length:
+        if ind == args.sample_length:
             # provide a quick show at the beginning
             frame_queue.extend(backup_frames)
             backup_frames = []
-        elif ((len(backup_frames) == input_step and ind > sample_length)
-              or ind == num_frames):
+        elif ((len(backup_frames) == args.input_step
+               and ind > args.sample_length) or ind == num_frames):
             # pick a frame from the backup
             # when the backup is full or the last frame is reached
             chosen_frame = random.choice(backup_frames)
             backup_frames = []
             frame_queue.append(chosen_frame)
 
-        ret, scores = inference()
+        ret, scores = inference(model, data, args, frame_queue)
 
         if ret:
             num_selected_labels = min(len(label), 5)
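For reference, the frame-selection logic in the hunk above can be exercised on its own: frames accumulate in a backup buffer, and once it holds `input_step` frames one of them is promoted, at random, into the sliding window. A standalone sketch with made-up sizes (the end-of-video branch is omitted):

    import random
    from collections import deque

    input_step, sample_length = 3, 4           # illustrative values
    frame_queue = deque(maxlen=sample_length)  # sliding window of frames
    backup_frames = []

    for ind in range(1, 13):                   # ind stands in for a decoded frame
        backup_frames.append(ind)
        if ind == sample_length:
            # fill the window in one go for a quick first prediction
            frame_queue.extend(backup_frames)
            backup_frames = []
        elif len(backup_frames) == input_step and ind > sample_length:
            # keep one random frame out of every `input_step` frames
            frame_queue.append(random.choice(backup_frames))
            backup_frames = []

    print(list(frame_queue))  # e.g. [3, 4, 6, 9]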
@@ -101,7 +113,7 @@ def show_results():
             results = result_queue.popleft()
             for i, result in enumerate(results):
                 selected_label, score = result
-                if score < threshold:
+                if score < args.threshold:
                     break
                 location = (0, 40 + i * 20)
                 text = selected_label + ': ' + str(round(score, 2))
@@ -120,38 +132,40 @@ def show_results():
     cv2.destroyAllWindows()
 
 
-def inference():
-    if len(frame_queue) != sample_length:
+def inference(model, data, args, frame_queue):
+    if len(frame_queue) != args.sample_length:
         # do no inference when there are not enough frames
         return False, None
 
     cur_windows = list(np.array(frame_queue))
-    img = frame_queue.popleft()
     if data['img_shape'] is None:
-        data['img_shape'] = img.shape[:2]
+        data['img_shape'] = frame_queue[0].shape[:2]
 
     cur_data = data.copy()
     cur_data['imgs'] = cur_windows
-    cur_data = test_pipeline(cur_data)
+    cur_data = args.test_pipeline(cur_data)
     cur_data = collate([cur_data], samples_per_gpu=1)
     if next(model.parameters()).is_cuda:
-        cur_data = scatter(cur_data, [device])[0]
+        cur_data = scatter(cur_data, [args.device])[0]
     with torch.no_grad():
         scores = model(return_loss=False, **cur_data)[0]
 
+    if args.stride > 0:
+        pred_stride = int(args.sample_length * args.stride)
+        for i in range(pred_stride):
+            frame_queue.popleft()
+
+    # for the case ``args.stride == 0``: the deque (created with maxlen)
+    # automatically poplefts one element when the next frame is appended
+
     return True, scores
 
 
 def main():
-    global frame_queue, threshold, sample_length, data, test_pipeline, model, \
-        out_file, video_path, device, input_step, label, result_queue
-
     args = parse_args()
-    input_step = args.input_step
-    threshold = args.threshold
-    video_path = args.video
-    out_file = args.out_file
 
-    device = torch.device(args.device)
-    model = init_recognizer(args.config, args.checkpoint, device=device)
+    args.device = torch.device(args.device)
+    model = init_recognizer(args.config, args.checkpoint, device=args.device)
     data = dict(img_shape=None, modality='RGB', label=-1)
     with open(args.label, 'r') as f:
         label = [line.strip() for line in f]
@@ -171,10 +185,12 @@ def main():
             # remove step to decode frames
             pipeline_.remove(step)
     test_pipeline = Compose(pipeline_)
 
     assert sample_length > 0
-    frame_queue = deque(maxlen=sample_length)
-    result_queue = deque(maxlen=1)
-    show_results()
+    args.sample_length = sample_length
+    args.test_pipeline = test_pipeline
+
+    show_results(model, data, label, args)
 
 
 if __name__ == '__main__':
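One detail of `inference()` above is worth spelling out: when `args.stride` is 0 no frame is popped explicitly, yet the window still advances, because a `deque` created with `maxlen` silently drops its oldest element on each `append`. A minimal sketch of both modes; the window size and stride values are illustrative:

    from collections import deque

    sample_length = 4
    frame_queue = deque(range(sample_length), maxlen=sample_length)  # [0, 1, 2, 3]

    # stride == 0: the deque itself slides the window one frame per append
    frame_queue.append(4)
    print(list(frame_queue))  # [1, 2, 3, 4] -- frame 0 was dropped automatically

    # stride > 0: pop pred_stride frames explicitly, as inference() does
    stride = 0.5
    pred_stride = int(sample_length * stride)  # 2
    for _ in range(pred_stride):
        frame_queue.popleft()
    print(list(frame_queue))  # [3, 4] -- the next prediction waits for two new frames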
1 change: 1 addition & 0 deletions docs/changelog.md
@@ -10,6 +10,7 @@

 **Improvements**
 
+- Add arg `stride` to long_video_demo.py, to make inference faster ([#468](https://github.com/open-mmlab/mmaction2/pull/468))
 - Support training and testing for Spatio-Temporal Action Detection ([#351](https://github.com/open-mmlab/mmaction2/pull/351))
 - Fix CI due to pip upgrade ([#454](https://github.com/open-mmlab/mmaction2/pull/454))
 - Add markdown lint in pre-commit hook ([#255](https://github.com/open-mmlab/mmaction2/pull/225))
