Skip to content

Commit

Permalink
Support temporal pool in tsm (open-mmlab#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin authored Jul 20, 2020
1 parent 08bb3f8 commit d85364d
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# model settings
# Recognizer2D built from a ResNetTSM backbone and a TSMHead; temporal_pool
# is switched on in both the backbone and the head.
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNetTSM',
        pretrained='torchvision://resnet50',
        depth=50,
        norm_eval=False,
        temporal_pool=True,
        shift_div=8,
    ),
    cls_head=dict(
        type='TSMHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.5,
        init_std=0.001,
        temporal_pool=True,
        is_shift=True,
    ),
)
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
# Testing reuses the validation annotation list.
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_bgr=False,
)

# Training pipeline: 8 clips of 1 frame each, multi-scale crop resized to
# 224x224, flip with ratio 0.5.
train_pipeline = [
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
    dict(type='FrameSelector'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.875, 0.75, 0.66),
        random_crop=False,
        max_wh_scale_gap=1,
        num_fixed_crops=13,
    ),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label']),
]

# Validation pipeline: test_mode sampling, 224 center crop, flip_ratio=0.
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=8,
        test_mode=True,
    ),
    dict(type='FrameSelector'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs']),
]

# Test pipeline: same steps as the validation pipeline, kept as a separate
# list object.
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=8,
        test_mode=True,
    ),
    dict(type='FrameSelector'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs']),
]

# Per-split dataset definitions; val and test both read from data_root_val.
data = dict(
    videos_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline,
    ),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline,
    ),
)
# optimizer
optimizer = dict(
    type='SGD',
    constructor='TSMOptimizerConstructor',
    paramwise_cfg=dict(fc_lr5=True),
    lr=0.01,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001,
)
# Gradients are clipped by L2 norm (norm_type=2) with max_norm=20.
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))

# learning policy
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50

# Checkpointing and evaluation share the same interval of 5.
checkpoint_config = dict(interval=5)
evaluation = dict(
    interval=5,
    metrics=['top_k_accuracy', 'mean_class_accuracy'],
    topk=(1, 5),
)
log_config = dict(
    interval=20,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook'),
    ],
)

# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
# NOTE(review): the directory name says "100e" while total_epochs is 50 --
# confirm which one is intended before renaming anything.
work_dir = './work_dirs/tsm_temporal_pool_r50_1x1x8_100e_kinetics400_rgb/'
load_from = None
resume_from = None
workflow = [('train', 1)]
37 changes: 37 additions & 0 deletions mmaction/models/backbones/resnet_tsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,46 @@ def make_block_temporal(stage, num_segments):
else:
raise NotImplementedError

def make_temporal_pool(self):
    """Make temporal pooling between layer1 and layer2, using a 3D max
    pooling layer.

    ``self.layer2`` is replaced in place by a wrapper module that max-pools
    its input along the segment (temporal) axis and then runs the original
    ``layer2`` on the pooled frames.
    """

    class TemporalPool(nn.Module):
        """Temporal pool module.

        Wrap layer2 in ResNet50 with a 3D max pooling layer.

        Args:
            net (nn.Module): Module to make temporal pool.
            num_segments (int): Number of frame segments.
        """

        def __init__(self, net, num_segments):
            super().__init__()
            self.net = net
            self.num_segments = num_segments
            # Pool only along the temporal axis; H and W are left untouched.
            self.max_pool3d = nn.MaxPool3d(
                kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))

        def forward(self, x):
            """Pool ``x`` over the segment axis, then apply ``self.net``.

            Args:
                x (torch.Tensor): Input of shape [N, C, H, W], where N is
                    a multiple of ``num_segments``.

            Returns:
                The output of ``self.net`` on the temporally pooled frames.
            """
            # [N, C, H, W]
            n, c, h, w = x.size()
            # [N // num_segments, C, num_segments, H, W]
            x = x.view(n // self.num_segments, self.num_segments, c, h,
                       w).transpose(1, 2)
            # [N // num_segments, C, T', H, W], T' = (num_segments + 1) // 2
            x = self.max_pool3d(x)
            # [(N // num_segments) * T', C, H, W]
            # The leading size is inferred instead of hard-coded as N // 2:
            # for an odd num_segments the pool yields (num_segments + 1) // 2
            # frames per clip, so N // 2 would not match the element count.
            x = x.transpose(1, 2).contiguous().view(-1, c, h, w)
            return self.net(x)

    self.layer2 = TemporalPool(self.layer2, self.num_segments)

def init_weights(self):
    """Initialize parameters through the parent class, then install the
    optional temporal modules.

    Temporal shift (when ``self.is_shift`` is set) is installed before
    temporal pooling (when ``self.temporal_pool`` is set).
    """
    super().init_weights()

    if self.is_shift:
        self.make_temporal_shift()

    # Nothing more to do unless temporal pooling was requested.
    if not self.temporal_pool:
        return
    self.make_temporal_pool()
17 changes: 17 additions & 0 deletions tests/test_models/test_backbone.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -505,6 +507,12 @@ def test_resnet_tsm_backbone():
for layer_name in resnet_tsm_50_temporal_pool.res_layers:
layer = getattr(resnet_tsm_50_temporal_pool, layer_name)
blocks = list(layer.children())

if layer_name == 'layer2':
assert len(blocks) == 2
assert isinstance(blocks[1], nn.MaxPool3d)
blocks = copy.deepcopy(blocks[0])

for block in blocks:
assert isinstance(block.conv1.conv, TemporalShift)
if layer_name == 'layer1':
Expand All @@ -516,6 +524,15 @@ def test_resnet_tsm_backbone():
assert block.conv1.conv.shift_div == resnet_tsm_50_temporal_pool.shift_div # noqa: E501
assert isinstance(block.conv1.conv.net, nn.Conv2d)

input_shape = (8, 3, 64, 64)
imgs = _demo_inputs(input_shape)

feat = resnet_tsm_50(imgs)
assert feat.shape == torch.Size([8, 2048, 2, 2])

feat = resnet_tsm_50_temporal_pool(imgs)
assert feat.shape == torch.Size([4, 2048, 2, 2])


def test_slowfast_backbone():
"""Test SlowFast backbone."""
Expand Down

0 comments on commit d85364d

Please sign in to comment.