diff --git a/gluoncv/torch/data/__init__.py b/gluoncv/torch/data/__init__.py
index 33c8cad4ee..a4abe3d546 100644
--- a/gluoncv/torch/data/__init__.py
+++ b/gluoncv/torch/data/__init__.py
@@ -4,5 +4,5 @@
 from .video_cls.dataset_classification import VideoClsDataset
 from .video_cls.dataset_classification import build_dataloader, build_dataloader_test
-from .video_cls.multigrid_helper import multiGridSampler, MultiGridBatchSampler
+from .video_cls.multigrid_helper import multiGridHelper, MultiGridBatchSampler
 from .coot.dataloader import create_datasets, create_loaders
diff --git a/gluoncv/torch/data/video_cls/dataset_classification.py b/gluoncv/torch/data/video_cls/dataset_classification.py
index b2f8473230..97351b72c3 100644
--- a/gluoncv/torch/data/video_cls/dataset_classification.py
+++ b/gluoncv/torch/data/video_cls/dataset_classification.py
@@ -8,7 +8,7 @@
 from torch.utils.data import Dataset

 from ..transforms.videotransforms import video_transforms, volume_transforms
-from .multigrid_helper import multiGridSampler, MultiGridBatchSampler
+from .multigrid_helper import multiGridHelper, MultiGridBatchSampler

 __all__ = ['VideoClsDataset', 'build_dataloader', 'build_dataloader_test']
@@ -45,12 +45,12 @@ def __init__(self, anno_path, data_path, mode='train', clip_len=8,

         if (mode == 'train'):
             if self.use_multigrid:
-                self.MG_sampler = multiGridSampler()
+                self.mg_helper = multiGridHelper()
                 self.data_transform = []
-                for alpha in range(self.MG_sampler.mod_long):
+                for alpha in range(self.mg_helper.mod_long):
                     tmp = []
-                    for beta in range(self.MG_sampler.mod_short):
-                        info = self.MG_sampler.get_resize(alpha, beta)
+                    for beta in range(self.mg_helper.mod_short):
+                        info = self.mg_helper.get_resize(alpha, beta)
                         scale_s = info[1]
                         tmp.append(video_transforms.Compose([
                             video_transforms.Resize(int(self.short_side_size / scale_s),
@@ -108,7 +108,7 @@ def __getitem__(self, index):
         if self.mode == 'train':
             if self.use_multigrid is True:
                 index, alpha, beta = index
-                info = self.MG_sampler.get_resize(alpha, beta)
+                info = self.mg_helper.get_resize(alpha, beta)
                 scale_t = info[0]
                 data_transform_func = self.data_transform[alpha][beta]
             else:
@@ -241,7 +241,8 @@ def build_dataloader(cfg):
     train_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.TRAIN_ANNO_PATH,
                                     data_path=cfg.CONFIG.DATA.TRAIN_DATA_PATH,
                                     mode='train',
-                                    use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
+                                    use_multigrid=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE \
+                                        or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
                                     clip_len=cfg.CONFIG.DATA.CLIP_LEN,
                                     frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
                                     num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -254,7 +255,7 @@
     val_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.VAL_ANNO_PATH,
                                   data_path=cfg.CONFIG.DATA.VAL_DATA_PATH,
                                   mode='validation',
-                                  use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
+                                  use_multigrid=False,
                                   clip_len=cfg.CONFIG.DATA.CLIP_LEN,
                                   frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
                                   num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -273,9 +274,11 @@
         val_sampler = None

     mg_sampler = None
-    if cfg.CONFIG.DATA.MULTIGRID:
+    if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE or cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE:
         mg_sampler = MultiGridBatchSampler(train_sampler, batch_size=cfg.CONFIG.TRAIN.BATCH_SIZE,
-                                           drop_last=True)
+                                           drop_last=True,
+                                           use_long=cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
+                                           use_short=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE)
         train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False,
                                                    num_workers=9, pin_memory=True,
                                                    batch_sampler=mg_sampler)
diff --git a/gluoncv/torch/data/video_cls/multigrid_helper.py b/gluoncv/torch/data/video_cls/multigrid_helper.py
index af6275c38b..05fc2533bf 100644
--- a/gluoncv/torch/data/video_cls/multigrid_helper.py
+++ b/gluoncv/torch/data/video_cls/multigrid_helper.py
@@ -6,17 +6,18 @@
 from torch._six import int_classes as _int_classes

-__all__ = ['multiGridSampler', 'MultiGridBatchSampler']
+__all__ = ['multiGridHelper', 'MultiGridBatchSampler']

 sq2 = np.sqrt(2)


-class multiGridSampler(object):
+class multiGridHelper(object):
     """
     A Multigrid Method for Efficiently Training Video Models
     Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
     CVPR 2020, https://arxiv.org/abs/1912.00998
     """
     def __init__(self):
+        # Scale: [T, H, W]
         self.long_cycle = np.asarray([[1, 1, 1], [2, 1, 1], [2, sq2, sq2]])[::-1]
         self.short_cycle = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, 2, 2]])[::-1]
         self.short_cycle_sp = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, sq2, sq2]])[::-1]
@@ -54,7 +55,15 @@ class MultiGridBatchSampler(Sampler):
     Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
     CVPR 2020, https://arxiv.org/abs/1912.00998
     """
-    def __init__(self, sampler, batch_size, drop_last):
+    def __init__(self, sampler, batch_size, drop_last, use_long=False, use_short=False):
+        '''
+        :param sampler: torch.utils.data.Sample
+        :param batch_size: int
+        :param drop_last: bool
+        :param use_long: bool
+        :param use_short: bool
+        Apply batch collecting function based on multiGridHelper definition
+        '''
         if not isinstance(sampler, Sampler):
             raise ValueError("sampler should be an instance of "
                              "torch.utils.data.Sampler, but got sampler={}"
@@ -66,85 +75,99 @@ def __init__(self, sampler, batch_size, drop_last):
         if not isinstance(drop_last, bool):
             raise ValueError("drop_last should be a boolean value, but got "
                              "drop_last={}".format(drop_last))
+        if not isinstance(use_long, bool):
+            raise ValueError("use_long should be a boolean value, but got "
+                             "use_long={}".format(use_long))
+        if not isinstance(use_short, bool):
+            raise ValueError("use_short should be a boolean value, but got "
+                             "use_short={}".format(use_short))
         self.sampler = sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.MG_sampler = multiGridSampler()
-        self.alpha = self.MG_sampler.mod_long - 1
+        self.mg_helper = multiGridHelper()
+        # single grid setting
+        self.alpha = self.mg_helper.mod_long - 1
+        self.beta = self.mg_helper.mod_short - 1
+        self.short_cycle_label = False
+        if use_long:
+            self.activate_long_cycle()
+        if use_short:
+            self.activate_short_cycle()
+        self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)
+
+    def activate_short_cycle(self):
+        self.short_cycle_label = True
         self.beta = 0
-        self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
-        self.label = True

-    def deactivate(self):
-        self.label = False
-        self.alpha = self.MG_sampler.mod_long - 1
-
-    def activate(self):
-        self.label = True
+    def activate_long_cycle(self):
         self.alpha = 0

+    def deactivate(self):
+        self.alpha = self.mg_helper.mod_long - 1
+        self.beta = self.mg_helper.mod_short - 1
+        self.short_cycle_label = False
+
     def __iter__(self):
         batch = []
-        if self.label:
+        if self.short_cycle_label:
             self.beta = 0
         else:
-            self.beta = self.MG_sampler.mod_short - 1
-        self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
+            self.beta = self.mg_helper.mod_short - 1
+        self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)
         for idx in self.sampler:
             batch.append([idx, self.alpha, self.beta])
             if len(batch) == self.batch_size*self.batch_scale:
                 yield batch
                 batch = []
-                if self.label:
-                    self.beta = (self.beta + 1)%self.MG_sampler.mod_short
-                self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
+                if self.short_cycle_label:
+                    self.beta = (self.beta + 1) % self.mg_helper.mod_short
+                self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)
         if len(batch) > 0 and not self.drop_last:
             yield batch

-    def step_alpha(self):
-        self.alpha = (self.alpha + 1)%self.MG_sampler.mod_long
-
-    def compute_lr_milestone(self, lr_milestone):
-        """
-        long cycle milestones
-        """
-        self.len_long = self.MG_sampler.mod_long
-        self.n_epoch_long = 0
-        for x in range(self.len_long):
-            self.n_epoch_long += self.MG_sampler.get_scale_alpha(x)
-        lr_long_cycle = []
-        for i, _ in enumerate(lr_milestone):
-            if i == 0:
-                pre = 0
-            else:
-                pre = lr_milestone[i-1]
-            cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
-            bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
-            for j in range(self.len_long)[::-1]:
-                pre = pre + cycle_length*(2**j) + bonus
-                if j == 0:
-                    pre = lr_milestone[i]
-                lr_long_cycle.append(pre)
-            lr_long_cycle.append(0)
-        lr_long_cycle = sorted(lr_long_cycle)
-        return lr_long_cycle
+    def step_long_cycle(self):
+        self.alpha = (self.alpha + 1) % self.mg_helper.mod_long
+
+    # def compute_lr_milestone(self, lr_milestone):
+    #     """
+    #     long cycle milestones, deprecated. Define long cycle in config files
+    #     """
+    #     self.len_long = self.mg_helper.mod_long
+    #     self.n_epoch_long = 0
+    #     for x in range(self.len_long):
+    #         self.n_epoch_long += self.mg_helper.get_scale_alpha(x)
+    #     lr_long_cycle = []
+    #     for i, _ in enumerate(lr_milestone):
+    #         if i == 0:
+    #             pre = 0
+    #         else:
+    #             pre = lr_milestone[i-1]
+    #         cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
+    #         bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
+    #         for j in range(self.len_long)[::-1]:
+    #             pre = pre + cycle_length*(2**j) + bonus
+    #             if j == 0:
+    #                 pre = lr_milestone[i]
+    #             lr_long_cycle.append(pre)
+    #         lr_long_cycle.append(0)
+    #     lr_long_cycle = sorted(lr_long_cycle)
+    #     return lr_long_cycle

     def __len__(self):
-        self.len_short = self.MG_sampler.mod_short
-        self.n_epoch_short = 0
-        for x in range(self.len_short):
-            self.n_epoch_short += self.MG_sampler.get_scale_beta(x)
-        short_batch_size = self.batch_size * self.MG_sampler.get_scale_alpha(self.alpha)
-        num_short = len(self.sampler) // short_batch_size
-
-        total = num_short // self.n_epoch_short * self.len_short
-        remain = self.n_epoch_short
-        for x in range(self.len_short):
-            remain = remain - (2**x)
+        scale_per_short_cycle = 0
+        for x in range(self.mg_helper.mod_short):
+            scale_per_short_cycle += self.mg_helper.get_scale(self.alpha, x)
+        num_full_short_cycle = len(self.sampler) // (self.batch_size * scale_per_short_cycle)
+
+        total = num_full_short_cycle * self.mg_helper.mod_short
+        remain = len(self.sampler) % (self.batch_size * scale_per_short_cycle)
+        for x in range(self.mg_helper.mod_short):
+            remain = remain - self.mg_helper.get_scale(self.alpha, x)*self.batch_size
+            if remain >= 0 or (remain < 0 and self.drop_last is False):
+                total += 1
             if remain <= 0:
                 break
-        total = total + int(num_short%self.n_epoch_short >= remain)
-
+        assert remain <= 0
         return total
diff --git a/gluoncv/torch/engine/config.py b/gluoncv/torch/engine/config.py
index ec838cc220..7daf997de0 100644
--- a/gluoncv/torch/engine/config.py
+++ b/gluoncv/torch/engine/config.py
@@ -68,6 +68,11 @@
 # Resume training from a specific epoch. Set to -1 means train from beginning.
 _C.CONFIG.TRAIN.RESUME_EPOCH: -1

+# Whether to use multigrid training to speed up.
+_C.CONFIG.TRAIN.MULTIGRID = CN(new_allowed=True)
+_C.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = False
+_C.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = False
+_C.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]
+
 _C.CONFIG.VAL = CN(new_allowed=True)
 # Evaluate model on test data every eval period epochs.
@@ -91,8 +96,6 @@
 _C.CONFIG.DATA.VAL_DATA_PATH = ''
 # The number of classes to predict for the model.
 _C.CONFIG.DATA.NUM_CLASSES = 400
-# Whether to use multigrid training to speed up.
-_C.CONFIG.DATA.MULTIGRID = False
 # The number of frames of the input clip.
 _C.CONFIG.DATA.CLIP_LEN = 16
 # The video sampling rate of the input clip.
diff --git a/scripts/action-recognition/train_ddp_shortonly_pytorch.py b/scripts/action-recognition/train_ddp_shortonly_pytorch.py
index c6774a280a..b49fb3342e 100644
--- a/scripts/action-recognition/train_ddp_shortonly_pytorch.py
+++ b/scripts/action-recognition/train_ddp_shortonly_pytorch.py
@@ -14,6 +14,7 @@
 from gluoncv.torch.engine.config import get_cfg_defaults
 from gluoncv.torch.engine.launch import spawn_workers
 from gluoncv.torch.utils.utils import build_log_dir
+from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler


 def main_worker(cfg):
@@ -67,6 +68,10 @@ def main_worker(cfg):
         else:
             scheduler.step()

+        if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE:
+            if epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
+                mg_sampler.step_long_cycle()
+
         if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1:
             validation_classification(model, val_loader, epoch, criterion, cfg, writer)
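
Usage sketch (not part of the diff): the snippet below illustrates how the new CONFIG.TRAIN.MULTIGRID group, build_dataloader, and MultiGridBatchSampler.step_long_cycle are meant to fit together. The return order of build_dataloader is an assumption (the patch only shows that it constructs mg_sampler and that the training script later steps it); the config keys and imports come from the patched files.

    from gluoncv.torch.engine.config import get_cfg_defaults
    from gluoncv.torch.data import build_dataloader

    cfg = get_cfg_defaults()
    # Long cycle: step through coarser [T, H, W] sampling grids across training stages.
    cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = True
    # Short cycle: cycle the spatial grid (H, W) from batch to batch within an epoch.
    cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = True
    # Epochs at which the long cycle advances to its next stage.
    cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]

    # Assumed return order; adjust to the actual signature of build_dataloader.
    train_loader, val_loader, train_sampler, val_sampler, mg_sampler = build_dataloader(cfg)

    for epoch in range(cfg.CONFIG.TRAIN.EPOCH_NUM):
        ...  # existing training / validation steps
        if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE:
            if epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
                mg_sampler.step_long_cycle()  # move to the next long-cycle grid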