Skip to content

Commit

Permalink
Merge branch 'skeleton_demo' of https://github.com/Dai-Wenxun/mmaction2
Browse files Browse the repository at this point in the history
… into skeleton_demo
  • Loading branch information
Dai-Wenxun committed Sep 15, 2022
2 parents fbe0db9 + 819f209 commit c5d6ea1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
10 changes: 4 additions & 6 deletions mmaction/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner import load_checkpoint
from mmengine.utils import track_iter_progress
Expand Down Expand Up @@ -48,11 +47,10 @@ def init_recognizer(config: Union[str, Path, mmengine.Config],
return model


def inference_recognizer(
model: nn.Module,
video: Union[str, dict],
test_pipeline: Optional[Compose] = None
) -> ActionDataSample:
def inference_recognizer(model: nn.Module,
video: Union[str, dict],
test_pipeline: Optional[Compose] = None
) -> ActionDataSample:
"""Inference a video with the recognizer.
Args:
Expand Down
14 changes: 7 additions & 7 deletions tests/apis/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import os.path as osp
import torch
import torch.nn as nn
from pathlib import Path
from unittest import TestCase

import torch
from parameterized import parameterized

from mmaction.apis import inference_recognizer, init_recognizer
from mmaction.utils import register_all_modules
from mmaction.structures import ActionDataSample
from mmaction.utils import register_all_modules


class TestInference(TestCase):
Expand All @@ -19,7 +18,8 @@ def setUp(self):

@parameterized.expand([(('configs/recognition/tsn/'
'tsn_imagenet-pretrained-r50_8xb32-'
'1x1x3-100e_kinetics400-rgb.py'), ('cpu', 'cuda'))])
'1x1x3-100e_kinetics400-rgb.py'), ('cpu', 'cuda'))
])
def test_init_recognizer(self, config, devices):
project_dir = osp.abspath(osp.dirname(osp.dirname(__file__)))
project_dir = osp.join(project_dir, '..')
Expand All @@ -44,8 +44,8 @@ def test_init_recognizer(self, config, devices):

@parameterized.expand([(('configs/recognition/tsn/'
'tsn_imagenet-pretrained-r50_8xb32-'
'1x1x3-100e_kinetics400-rgb.py'),
'demo/demo.mp4', ('cpu', 'cuda'))])
'1x1x3-100e_kinetics400-rgb.py'), 'demo/demo.mp4',
('cpu', 'cuda'))])
def test_inference_recognizer(self, config, video_path, devices):
project_dir = osp.abspath(osp.dirname(osp.dirname(__file__)))
project_dir = osp.join(project_dir, '..')
Expand Down

0 comments on commit c5d6ea1

Please sign in to comment.