pytorch multigrid updates #1620

Merged · 4 commits · Feb 26, 2021
2 changes: 1 addition & 1 deletion gluoncv/torch/data/__init__.py
@@ -4,5 +4,5 @@

 from .video_cls.dataset_classification import VideoClsDataset
 from .video_cls.dataset_classification import build_dataloader, build_dataloader_test
-from .video_cls.multigrid_helper import multiGridSampler, MultiGridBatchSampler
+from .video_cls.multigrid_helper import multiGridHelper, MultiGridBatchSampler
 from .coot.dataloader import create_datasets, create_loaders
23 changes: 13 additions & 10 deletions gluoncv/torch/data/video_cls/dataset_classification.py
@@ -8,7 +8,7 @@
 from torch.utils.data import Dataset

 from ..transforms.videotransforms import video_transforms, volume_transforms
-from .multigrid_helper import multiGridSampler, MultiGridBatchSampler
+from .multigrid_helper import multiGridHelper, MultiGridBatchSampler


 __all__ = ['VideoClsDataset', 'build_dataloader', 'build_dataloader_test']
@@ -45,12 +45,12 @@ def __init__(self, anno_path, data_path, mode='train', clip_len=8,

         if (mode == 'train'):
             if self.use_multigrid:
-                self.MG_sampler = multiGridSampler()
+                self.mg_helper = multiGridHelper()
                 self.data_transform = []
-                for alpha in range(self.MG_sampler.mod_long):
+                for alpha in range(self.mg_helper.mod_long):
                     tmp = []
-                    for beta in range(self.MG_sampler.mod_short):
-                        info = self.MG_sampler.get_resize(alpha, beta)
+                    for beta in range(self.mg_helper.mod_short):
+                        info = self.mg_helper.get_resize(alpha, beta)
                         scale_s = info[1]
                         tmp.append(video_transforms.Compose([
                             video_transforms.Resize(int(self.short_side_size / scale_s),
@@ -108,7 +108,7 @@ def __getitem__(self, index):
         if self.mode == 'train':
             if self.use_multigrid is True:
                 index, alpha, beta = index
-                info = self.MG_sampler.get_resize(alpha, beta)
+                info = self.mg_helper.get_resize(alpha, beta)
                 scale_t = info[0]
                 data_transform_func = self.data_transform[alpha][beta]
             else:
@@ -241,7 +241,8 @@ def build_dataloader(cfg):
     train_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.TRAIN_ANNO_PATH,
                                     data_path=cfg.CONFIG.DATA.TRAIN_DATA_PATH,
                                     mode='train',
-                                    use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
+                                    use_multigrid=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE \
+                                                  or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
                                     clip_len=cfg.CONFIG.DATA.CLIP_LEN,
                                     frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
                                     num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -254,7 +255,7 @@ def build_dataloader(cfg):
     val_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.VAL_ANNO_PATH,
                                   data_path=cfg.CONFIG.DATA.VAL_DATA_PATH,
                                   mode='validation',
-                                  use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
+                                  use_multigrid=False,
                                   clip_len=cfg.CONFIG.DATA.CLIP_LEN,
                                   frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
                                   num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -273,9 +274,11 @@ def build_dataloader(cfg):
     val_sampler = None

     mg_sampler = None
-    if cfg.CONFIG.DATA.MULTIGRID:
+    if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE or cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE:
         mg_sampler = MultiGridBatchSampler(train_sampler, batch_size=cfg.CONFIG.TRAIN.BATCH_SIZE,
-                                           drop_last=True)
+                                           drop_last=True,
+                                           use_long=cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
+                                           use_short=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE)
         train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False,
                                                    num_workers=9, pin_memory=True,
                                                    batch_sampler=mg_sampler)
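Reviewer note: the contract here is that the batch sampler, not the DataLoader, decides batch composition, and every element it yields is an `[idx, alpha, beta]` triple that `VideoClsDataset.__getitem__` unpacks. A minimal runnable sketch of that contract, with a hypothetical `ToyClipDataset` standing in for `VideoClsDataset`:

```python
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from gluoncv.torch.data import MultiGridBatchSampler

class ToyClipDataset(Dataset):
    """Hypothetical stand-in for VideoClsDataset: each index is a triple."""
    def __len__(self):
        return 64
    def __getitem__(self, index):
        index, alpha, beta = index  # same unpacking as VideoClsDataset.__getitem__
        return torch.zeros(3, 8, 112, 112), alpha, beta

dataset = ToyClipDataset()
mg_sampler = MultiGridBatchSampler(SequentialSampler(dataset), batch_size=8,
                                   drop_last=True, use_long=True, use_short=True)
# batch_size/shuffle keep their DataLoader defaults: with a batch_sampler,
# the sampler alone decides how many [idx, alpha, beta] triples form a batch.
loader = DataLoader(dataset, batch_sampler=mg_sampler, num_workers=0)
for clips, alpha, beta in loader:
    print(clips.shape[0], int(alpha[0]), int(beta[0]))  # batch size varies with the grid
```

Because `batch_size` and `shuffle` stay at their defaults, PyTorch accepts the `batch_sampler` argument; the scale-dependent batch sizes come entirely from `MultiGridBatchSampler`.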
141 changes: 82 additions & 59 deletions gluoncv/torch/data/video_cls/multigrid_helper.py
@@ -6,17 +6,18 @@
 from torch._six import int_classes as _int_classes


-__all__ = ['multiGridSampler', 'MultiGridBatchSampler']
+__all__ = ['multiGridHelper', 'MultiGridBatchSampler']


 sq2 = np.sqrt(2)
-class multiGridSampler(object):
+class multiGridHelper(object):
     """
     A Multigrid Method for Efficiently Training Video Models
     Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
     CVPR 2020, https://arxiv.org/abs/1912.00998
     """
     def __init__(self):
+        # Scale: [T, H, W]
         self.long_cycle = np.asarray([[1, 1, 1], [2, 1, 1], [2, sq2, sq2]])[::-1]
         self.short_cycle = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, 2, 2]])[::-1]
         self.short_cycle_sp = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, sq2, sq2]])[::-1]
@@ -54,7 +55,15 @@ class MultiGridBatchSampler(Sampler):
     Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
     CVPR 2020, https://arxiv.org/abs/1912.00998
     """
-    def __init__(self, sampler, batch_size, drop_last):
+    def __init__(self, sampler, batch_size, drop_last, use_long=False, use_short=False):
+        '''
+        :param sampler: torch.utils.data.Sampler
+        :param batch_size: int
+        :param drop_last: bool
+        :param use_long: bool
+        :param use_short: bool
+        Apply batch collecting function based on multiGridHelper definition
+        '''
         if not isinstance(sampler, Sampler):
             raise ValueError("sampler should be an instance of "
                              "torch.utils.data.Sampler, but got sampler={}"
@@ -66,85 +75,99 @@ def __init__(self, sampler, batch_size, drop_last):
         if not isinstance(drop_last, bool):
             raise ValueError("drop_last should be a boolean value, but got "
                              "drop_last={}".format(drop_last))
+        if not isinstance(use_long, bool):
+            raise ValueError("use_long should be a boolean value, but got "
+                             "use_long={}".format(use_long))
+        if not isinstance(use_short, bool):
+            raise ValueError("use_short should be a boolean value, but got "
+                             "use_short={}".format(use_short))
         self.sampler = sampler
         self.batch_size = batch_size
         self.drop_last = drop_last

-        self.MG_sampler = multiGridSampler()
-        self.alpha = self.MG_sampler.mod_long - 1
+        self.mg_helper = multiGridHelper()
+        # single grid setting
+        self.alpha = self.mg_helper.mod_long - 1
+        self.beta = self.mg_helper.mod_short - 1
+        self.short_cycle_label = False
+        if use_long:
+            self.activate_long_cycle()
+        if use_short:
+            self.activate_short_cycle()
+        self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)

+    def activate_short_cycle(self):
+        self.short_cycle_label = True
         self.beta = 0
-        self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
-        self.label = True

-    def deactivate(self):
-        self.label = False
-        self.alpha = self.MG_sampler.mod_long - 1

-    def activate(self):
-        self.label = True
+    def activate_long_cycle(self):
         self.alpha = 0

+    def deactivate(self):
+        self.alpha = self.mg_helper.mod_long - 1
+        self.beta = self.mg_helper.mod_short - 1
+        self.short_cycle_label = False

     def __iter__(self):
         batch = []
-        if self.label:
+        if self.short_cycle_label:
             self.beta = 0
         else:
-            self.beta = self.MG_sampler.mod_short - 1
-        self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
+            self.beta = self.mg_helper.mod_short - 1
+        self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)
         for idx in self.sampler:
             batch.append([idx, self.alpha, self.beta])
             if len(batch) == self.batch_size*self.batch_scale:
                 yield batch
                 batch = []
-                if self.label:
-                    self.beta = (self.beta + 1)%self.MG_sampler.mod_short
-                    self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
+                if self.short_cycle_label:
+                    self.beta = (self.beta + 1) % self.mg_helper.mod_short
+                    self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)

         if len(batch) > 0 and not self.drop_last:
             yield batch

-    def step_alpha(self):
-        self.alpha = (self.alpha + 1)%self.MG_sampler.mod_long
-
-    def compute_lr_milestone(self, lr_milestone):
-        """
-        long cycle milestones
-        """
-        self.len_long = self.MG_sampler.mod_long
-        self.n_epoch_long = 0
-        for x in range(self.len_long):
-            self.n_epoch_long += self.MG_sampler.get_scale_alpha(x)
-        lr_long_cycle = []
-        for i, _ in enumerate(lr_milestone):
-            if i == 0:
-                pre = 0
-            else:
-                pre = lr_milestone[i-1]
-            cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
-            bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
-            for j in range(self.len_long)[::-1]:
-                pre = pre + cycle_length*(2**j) + bonus
-                if j == 0:
-                    pre = lr_milestone[i]
-                lr_long_cycle.append(pre)
-            lr_long_cycle.append(0)
-        lr_long_cycle = sorted(lr_long_cycle)
-        return lr_long_cycle
+    def step_long_cycle(self):
+        self.alpha = (self.alpha + 1) % self.mg_helper.mod_long
+
+    # def compute_lr_milestone(self, lr_milestone):
+    #     """
+    #     long cycle milestones, deprecated. Define long cycle in config files
+    #     """
+    #     self.len_long = self.mg_helper.mod_long
+    #     self.n_epoch_long = 0
+    #     for x in range(self.len_long):
+    #         self.n_epoch_long += self.mg_helper.get_scale_alpha(x)
+    #     lr_long_cycle = []
+    #     for i, _ in enumerate(lr_milestone):
+    #         if i == 0:
+    #             pre = 0
+    #         else:
+    #             pre = lr_milestone[i-1]
+    #         cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
+    #         bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
+    #         for j in range(self.len_long)[::-1]:
+    #             pre = pre + cycle_length*(2**j) + bonus
+    #             if j == 0:
+    #                 pre = lr_milestone[i]
+    #             lr_long_cycle.append(pre)
+    #         lr_long_cycle.append(0)
+    #     lr_long_cycle = sorted(lr_long_cycle)
+    #     return lr_long_cycle

     def __len__(self):
-        self.len_short = self.MG_sampler.mod_short
-        self.n_epoch_short = 0
-        for x in range(self.len_short):
-            self.n_epoch_short += self.MG_sampler.get_scale_beta(x)
-        short_batch_size = self.batch_size * self.MG_sampler.get_scale_alpha(self.alpha)
-        num_short = len(self.sampler) // short_batch_size
-
-        total = num_short // self.n_epoch_short * self.len_short
-        remain = self.n_epoch_short
-        for x in range(self.len_short):
-            remain = remain - (2**x)
+        scale_per_short_cycle = 0
+        for x in range(self.mg_helper.mod_short):
+            scale_per_short_cycle += self.mg_helper.get_scale(self.alpha, x)
+        num_full_short_cycle = len(self.sampler) // (self.batch_size * scale_per_short_cycle)
+
+        total = num_full_short_cycle * self.mg_helper.mod_short
+        remain = len(self.sampler) % (self.batch_size * scale_per_short_cycle)
+        for x in range(self.mg_helper.mod_short):
+            remain = remain - self.mg_helper.get_scale(self.alpha, x)*self.batch_size
             if remain >= 0 or (remain < 0 and self.drop_last is False):
                 total += 1
             if remain <= 0:
                 break
-        total = total + int(num_short%self.n_epoch_short >= remain)

+        assert remain <= 0
         return total
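
For anyone trying the new sampler API in isolation: a short sketch (assuming a build of this branch is installed) that shows the short cycle stepping `beta` after every batch, so `len(batch)` follows `batch_size * get_scale(alpha, beta)`, plus the epoch-level `step_long_cycle()` hook that advances `alpha`:

```python
from torch.utils.data.sampler import SequentialSampler
from gluoncv.torch.data import MultiGridBatchSampler

# 48 toy indices; base batch size 4; short cycle on, long cycle off
mg_sampler = MultiGridBatchSampler(SequentialSampler(range(48)), batch_size=4,
                                   drop_last=True, use_long=False, use_short=True)
print(len(mg_sampler))              # batch count predicted by the new __len__
for batch in mg_sampler:
    _, alpha, beta = batch[0]       # every entry is [idx, alpha, beta]
    print(len(batch), alpha, beta)  # len(batch) == batch_size * get_scale(alpha, beta)

mg_sampler.step_long_cycle()        # what the training script calls at a milestone epoch
```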
7 changes: 5 additions & 2 deletions gluoncv/torch/engine/config.py
@@ -68,6 +68,11 @@
 # Resume training from a specific epoch. Set to -1 means train from beginning.
 _C.CONFIG.TRAIN.RESUME_EPOCH = -1

+# Whether to use multigrid training to speed up.
+_C.CONFIG.TRAIN.MULTIGRID = CN(new_allowed=True)
+_C.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = False
+_C.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = False
+_C.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]

 _C.CONFIG.VAL = CN(new_allowed=True)
 # Evaluate model on test data every eval period epochs.
@@ -91,8 +96,6 @@
 _C.CONFIG.DATA.VAL_DATA_PATH = ''
 # The number of classes to predict for the model.
 _C.CONFIG.DATA.NUM_CLASSES = 400
-# Whether to use multigrid training to speed up.
-_C.CONFIG.DATA.MULTIGRID = False
 # The number of frames of the input clip.
 _C.CONFIG.DATA.CLIP_LEN = 16
 # The video sampling rate of the input clip.
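A usage sketch of the new keys (the names come straight from this diff; the override style is plain yacs, and `get_cfg_defaults` is the same helper the training script imports):

```python
from gluoncv.torch.engine.config import get_cfg_defaults

cfg = get_cfg_defaults()
# replaces the removed cfg.CONFIG.DATA.MULTIGRID boolean
cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = True
cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = True
# epochs at which main_worker calls mg_sampler.step_long_cycle()
cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]
```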
5 changes: 5 additions & 0 deletions scripts/action-recognition/train_ddp_shortonly_pytorch.py
@@ -14,6 +14,7 @@
 from gluoncv.torch.engine.config import get_cfg_defaults
 from gluoncv.torch.engine.launch import spawn_workers
 from gluoncv.torch.utils.utils import build_log_dir
+from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler


 def main_worker(cfg):
@@ -67,6 +68,10 @@ def main_worker(cfg):
         else:
             scheduler.step()

+        if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE:
+            if epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
+                mg_sampler.step_long_cycle()
+
         if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1:
             validation_classification(model, val_loader, epoch, criterion, cfg, writer)