From 005a19cb0ce6d3161d39aa62f25ff3d7df40614e Mon Sep 17 00:00:00 2001 From: Chunhui Liu Date: Fri, 19 Feb 2021 18:12:05 -0800 Subject: [PATCH 1/3] pytorch multigrid updates (1) modify config logic for multigrid usage, now support (a) open long cycle and short cycle independently. (b) define long cycle by users (2) rename multiGridSampler to multiGridHelper (3) re-implement code for computing batch len in multi grid scenario. --- gluoncv/torch/data/__init__.py | 2 +- .../data/video_cls/dataset_classification.py | 23 +-- .../torch/data/video_cls/multigrid_helper.py | 142 +++++++++++------- gluoncv/torch/engine/config.py | 7 +- .../train_ddp_shortonly_pytorch.py | 5 + 5 files changed, 108 insertions(+), 71 deletions(-) diff --git a/gluoncv/torch/data/__init__.py b/gluoncv/torch/data/__init__.py index 33c8cad4ee..a4abe3d546 100644 --- a/gluoncv/torch/data/__init__.py +++ b/gluoncv/torch/data/__init__.py @@ -4,5 +4,5 @@ from .video_cls.dataset_classification import VideoClsDataset from .video_cls.dataset_classification import build_dataloader, build_dataloader_test -from .video_cls.multigrid_helper import multiGridSampler, MultiGridBatchSampler +from .video_cls.multigrid_helper import multiGridHelper, MultiGridBatchSampler from .coot.dataloader import create_datasets, create_loaders diff --git a/gluoncv/torch/data/video_cls/dataset_classification.py b/gluoncv/torch/data/video_cls/dataset_classification.py index b2f8473230..379dff1f95 100644 --- a/gluoncv/torch/data/video_cls/dataset_classification.py +++ b/gluoncv/torch/data/video_cls/dataset_classification.py @@ -8,7 +8,7 @@ from torch.utils.data import Dataset from ..transforms.videotransforms import video_transforms, volume_transforms -from .multigrid_helper import multiGridSampler, MultiGridBatchSampler +from .multigrid_helper import multiGridHelper, MultiGridBatchSampler __all__ = ['VideoClsDataset', 'build_dataloader', 'build_dataloader_test'] @@ -45,12 +45,12 @@ def __init__(self, anno_path, data_path, mode='train', clip_len=8, if (mode == 'train'): if self.use_multigrid: - self.MG_sampler = multiGridSampler() + self.mg_helper = multiGridHelper() self.data_transform = [] - for alpha in range(self.MG_sampler.mod_long): + for alpha in range(self.mg_helper.mod_long): tmp = [] - for beta in range(self.MG_sampler.mod_short): - info = self.MG_sampler.get_resize(alpha, beta) + for beta in range(self.mg_helper.mod_short): + info = self.mg_helper.get_resize(alpha, beta) scale_s = info[1] tmp.append(video_transforms.Compose([ video_transforms.Resize(int(self.short_side_size / scale_s), @@ -108,7 +108,7 @@ def __getitem__(self, index): if self.mode == 'train': if self.use_multigrid is True: index, alpha, beta = index - info = self.MG_sampler.get_resize(alpha, beta) + info = self.mg_helper.get_resize(alpha, beta) scale_t = info[0] data_transform_func = self.data_transform[alpha][beta] else: @@ -241,7 +241,8 @@ def build_dataloader(cfg): train_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.TRAIN_ANNO_PATH, data_path=cfg.CONFIG.DATA.TRAIN_DATA_PATH, mode='train', - use_multigrid=cfg.CONFIG.DATA.MULTIGRID, + use_multigrid=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE \ + or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE , clip_len=cfg.CONFIG.DATA.CLIP_LEN, frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE, num_segment=cfg.CONFIG.DATA.NUM_SEGMENT, @@ -254,7 +255,7 @@ def build_dataloader(cfg): val_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.VAL_ANNO_PATH, data_path=cfg.CONFIG.DATA.VAL_DATA_PATH, mode='validation', - 
use_multigrid=cfg.CONFIG.DATA.MULTIGRID, + use_multigrid=False, clip_len=cfg.CONFIG.DATA.CLIP_LEN, frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE, num_segment=cfg.CONFIG.DATA.NUM_SEGMENT, @@ -273,9 +274,11 @@ def build_dataloader(cfg): val_sampler = None mg_sampler = None - if cfg.CONFIG.DATA.MULTIGRID: + if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE or cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE: mg_sampler = MultiGridBatchSampler(train_sampler, batch_size=cfg.CONFIG.TRAIN.BATCH_SIZE, - drop_last=True) + drop_last=True, + use_long=cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE, + use_short=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=9, pin_memory=True, batch_sampler=mg_sampler) diff --git a/gluoncv/torch/data/video_cls/multigrid_helper.py b/gluoncv/torch/data/video_cls/multigrid_helper.py index af6275c38b..77c0134eeb 100644 --- a/gluoncv/torch/data/video_cls/multigrid_helper.py +++ b/gluoncv/torch/data/video_cls/multigrid_helper.py @@ -6,17 +6,18 @@ from torch._six import int_classes as _int_classes -__all__ = ['multiGridSampler', 'MultiGridBatchSampler'] +__all__ = ['multiGridHelper', 'MultiGridBatchSampler'] sq2 = np.sqrt(2) -class multiGridSampler(object): +class multiGridHelper(object): """ A Multigrid Method for Efficiently Training Video Models Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl CVPR 2020, https://arxiv.org/abs/1912.00998 """ def __init__(self): + # Scale: [T, H, W] self.long_cycle = np.asarray([[1, 1, 1], [2, 1, 1], [2, sq2, sq2]])[::-1] self.short_cycle = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, 2, 2]])[::-1] self.short_cycle_sp = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, sq2, sq2]])[::-1] @@ -54,7 +55,15 @@ class MultiGridBatchSampler(Sampler): Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl CVPR 2020, https://arxiv.org/abs/1912.00998 """ - def __init__(self, sampler, batch_size, drop_last): + def __init__(self, sampler, batch_size, drop_last, use_long=False, use_short=False): + ''' + :param sampler: torch.utils.data.Sample + :param batch_size: int + :param drop_last: bool + :param use_long: bool + :param use_short: bool + Apply batch collecting function based on multiGridHelper definition + ''' if not isinstance(sampler, Sampler): raise ValueError("sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}" @@ -66,85 +75,102 @@ def __init__(self, sampler, batch_size, drop_last): if not isinstance(drop_last, bool): raise ValueError("drop_last should be a boolean value, but got " "drop_last={}".format(drop_last)) + if not isinstance(use_long, bool): + raise ValueError("use_long should be a boolean value, but got " + "use_long={}".format(use_long)) + if not isinstance(use_short, bool): + raise ValueError("use_short should be a boolean value, but got " + "use_short={}".format(use_short)) self.sampler = sampler self.batch_size = batch_size self.drop_last = drop_last - self.MG_sampler = multiGridSampler() - self.alpha = self.MG_sampler.mod_long - 1 + self.mg_helper = multiGridHelper() + # single grid setting + self.alpha = self.mg_helper.mod_long - 1 + self.beta = self.mg_helper.mod_short - 1 + self.short_cycle_label = False + if use_long: + self.activate_long_cycle() + if use_short: + self.activate_short_cycle() + self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta) + + def activate_short_cycle(self): + self.short_cycle_label = True self.beta = 0 - self.batch_scale = 
self.MG_sampler.get_scale(self.alpha, self.beta) - self.label = True - def deactivate(self): - self.label = False - self.alpha = self.MG_sampler.mod_long - 1 - - def activate(self): - self.label = True + def activate_long_cycle(self): self.alpha = 0 + def deactivate(self): + self.alpha = self.mg_helper.mod_long - 1 + self.beta = self.mg_helper.mod_short - 1 + self.short_cycle_label = False + def __iter__(self): batch = [] - if self.label: + if self.short_cycle_label: self.beta = 0 else: - self.beta = self.MG_sampler.mod_short - 1 - self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta) + self.beta = self.mg_helper.mod_short - 1 + self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta) for idx in self.sampler: batch.append([idx, self.alpha, self.beta]) if len(batch) == self.batch_size*self.batch_scale: yield batch batch = [] - if self.label: - self.beta = (self.beta + 1)%self.MG_sampler.mod_short - self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta) + if self.short_cycle_label: + self.beta = (self.beta + 1) % self.mg_helper.mod_short + self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta) if len(batch) > 0 and not self.drop_last: yield batch - def step_alpha(self): - self.alpha = (self.alpha + 1)%self.MG_sampler.mod_long - - def compute_lr_milestone(self, lr_milestone): - """ - long cycle milestones - """ - self.len_long = self.MG_sampler.mod_long - self.n_epoch_long = 0 - for x in range(self.len_long): - self.n_epoch_long += self.MG_sampler.get_scale_alpha(x) - lr_long_cycle = [] - for i, _ in enumerate(lr_milestone): - if i == 0: - pre = 0 - else: - pre = lr_milestone[i-1] - cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long - bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long - for j in range(self.len_long)[::-1]: - pre = pre + cycle_length*(2**j) + bonus - if j == 0: - pre = lr_milestone[i] - lr_long_cycle.append(pre) - lr_long_cycle.append(0) - lr_long_cycle = sorted(lr_long_cycle) - return lr_long_cycle + def step_long_cycle(self): + self.alpha = (self.alpha + 1) % self.mg_helper.mod_long + + # def compute_lr_milestone(self, lr_milestone): + # """ + # long cycle milestones, deprecated. 
Define long cycle in config files + # """ + # self.len_long = self.mg_helper.mod_long + # self.n_epoch_long = 0 + # for x in range(self.len_long): + # self.n_epoch_long += self.mg_helper.get_scale_alpha(x) + # lr_long_cycle = [] + # for i, _ in enumerate(lr_milestone): + # if i == 0: + # pre = 0 + # else: + # pre = lr_milestone[i-1] + # cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long + # bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long + # for j in range(self.len_long)[::-1]: + # pre = pre + cycle_length*(2**j) + bonus + # if j == 0: + # pre = lr_milestone[i] + # lr_long_cycle.append(pre) + # lr_long_cycle.append(0) + # lr_long_cycle = sorted(lr_long_cycle) + # return lr_long_cycle def __len__(self): - self.len_short = self.MG_sampler.mod_short - self.n_epoch_short = 0 - for x in range(self.len_short): - self.n_epoch_short += self.MG_sampler.get_scale_beta(x) - short_batch_size = self.batch_size * self.MG_sampler.get_scale_alpha(self.alpha) - num_short = len(self.sampler) // short_batch_size - - total = num_short // self.n_epoch_short * self.len_short - remain = self.n_epoch_short - for x in range(self.len_short): - remain = remain - (2**x) + scale_per_short_cycle = 0 + for x in range(self.mg_helper.mod_short): + scale_per_short_cycle += self.mg_helper.get_scale(self.alpha, x) + num_full_short_cycle = len(self.sampler) // (self.batch_size * scale_per_short_cycle) + + total = num_full_short_cycle * self.mg_helper.mod_short + remain = len(self.sampler) % (self.batch_size * scale_per_short_cycle) + for x in range(self.mg_helper.mod_short): + remain = remain - self.mg_helper.get_scale(self.alpha, x)*self.batch_size if remain <= 0: + if remain == 0 or self.drop_last is False: + total += 1 break - total = total + int(num_short%self.n_epoch_short >= remain) + else: + total += 1 + assert remain <= 0 return total diff --git a/gluoncv/torch/engine/config.py b/gluoncv/torch/engine/config.py index ec838cc220..7daf997de0 100644 --- a/gluoncv/torch/engine/config.py +++ b/gluoncv/torch/engine/config.py @@ -68,6 +68,11 @@ # Resume training from a specific epoch. Set to -1 means train from beginning. _C.CONFIG.TRAIN.RESUME_EPOCH: -1 +# Whether to use multigrid training to speed up. +_C.CONFIG.TRAIN.MULTIGRID = CN(new_allowed=True) +_C.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = False +_C.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = False +_C.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30] _C.CONFIG.VAL = CN(new_allowed=True) # Evaluate model on test data every eval period epochs. @@ -91,8 +96,6 @@ _C.CONFIG.DATA.VAL_DATA_PATH = '' # The number of classes to predict for the model. _C.CONFIG.DATA.NUM_CLASSES = 400 -# Whether to use multigrid training to speed up. -_C.CONFIG.DATA.MULTIGRID = False # The number of frames of the input clip. _C.CONFIG.DATA.CLIP_LEN = 16 # The video sampling rate of the input clip. 
diff --git a/scripts/action-recognition/train_ddp_shortonly_pytorch.py b/scripts/action-recognition/train_ddp_shortonly_pytorch.py index c6774a280a..b49fb3342e 100644 --- a/scripts/action-recognition/train_ddp_shortonly_pytorch.py +++ b/scripts/action-recognition/train_ddp_shortonly_pytorch.py @@ -14,6 +14,7 @@ from gluoncv.torch.engine.config import get_cfg_defaults from gluoncv.torch.engine.launch import spawn_workers from gluoncv.torch.utils.utils import build_log_dir +from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler def main_worker(cfg): @@ -67,6 +68,10 @@ def main_worker(cfg): else: scheduler.step() + if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE: + if epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH: + mg_sampler.step_long_cycle() + if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1: validation_classification(model, val_loader, epoch, criterion, cfg, writer) From ffd16f4e42fc1430ca9948de009fd51d0f78924f Mon Sep 17 00:00:00 2001 From: Chunhui Liu Date: Mon, 22 Feb 2021 16:27:44 -0800 Subject: [PATCH 2/3] tiny bug fix --- gluoncv/torch/data/video_cls/dataset_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluoncv/torch/data/video_cls/dataset_classification.py b/gluoncv/torch/data/video_cls/dataset_classification.py index 379dff1f95..97351b72c3 100644 --- a/gluoncv/torch/data/video_cls/dataset_classification.py +++ b/gluoncv/torch/data/video_cls/dataset_classification.py @@ -242,7 +242,7 @@ def build_dataloader(cfg): data_path=cfg.CONFIG.DATA.TRAIN_DATA_PATH, mode='train', use_multigrid=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE \ - or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE , + or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE, clip_len=cfg.CONFIG.DATA.CLIP_LEN, frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE, num_segment=cfg.CONFIG.DATA.NUM_SEGMENT, From f787dc94a4862374cafc623f1edc86a165dd2490 Mon Sep 17 00:00:00 2001 From: Chunhui Liu Date: Mon, 22 Feb 2021 16:35:01 -0800 Subject: [PATCH 3/3] tiny bug fix --- gluoncv/torch/data/video_cls/multigrid_helper.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/gluoncv/torch/data/video_cls/multigrid_helper.py b/gluoncv/torch/data/video_cls/multigrid_helper.py index 77c0134eeb..05fc2533bf 100644 --- a/gluoncv/torch/data/video_cls/multigrid_helper.py +++ b/gluoncv/torch/data/video_cls/multigrid_helper.py @@ -165,12 +165,9 @@ def __len__(self): remain = len(self.sampler) % (self.batch_size * scale_per_short_cycle) for x in range(self.mg_helper.mod_short): remain = remain - self.mg_helper.get_scale(self.alpha, x)*self.batch_size + if remain >= 0 or (remain < 0 and self.drop_last is False): + total += 1 if remain <= 0: - if remain == 0 or self.drop_last is False: - total += 1 break - else: - total += 1 assert remain <= 0 - return total
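
Note (illustration only, not part of the patch series): a minimal sketch of how the new multigrid options introduced above are expected to be wired together once this series is applied, mirroring build_dataloader and train_ddp_shortonly_pytorch.py. The DummyVideoDataset below is a hypothetical stand-in for VideoClsDataset used only to keep the snippet self-contained, and the sketch assumes the stock GluonCV defaults already define CONFIG.TRAIN.BATCH_SIZE and CONFIG.TRAIN.EPOCH_NUM, as the training script does.

import torch
from gluoncv.torch.engine.config import get_cfg_defaults
from gluoncv.torch.data import MultiGridBatchSampler


class DummyVideoDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for VideoClsDataset: with multigrid enabled the
    batch sampler hands the dataset an (index, alpha, beta) triple instead of
    a plain integer index."""
    def __len__(self):
        return 1024

    def __getitem__(self, index):
        idx, alpha, beta = index          # same unpacking as VideoClsDataset.__getitem__
        clip = torch.zeros(3, 8, 112, 112)  # fake C x T x H x W clip
        return clip, 0


cfg = get_cfg_defaults()
cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = True
cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = True
cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]

train_dataset = DummyVideoDataset()
train_sampler = torch.utils.data.SequentialSampler(train_dataset)
mg_sampler = MultiGridBatchSampler(train_sampler,
                                   batch_size=cfg.CONFIG.TRAIN.BATCH_SIZE,
                                   drop_last=True,
                                   use_long=cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
                                   use_short=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE)
# batch_sampler takes the place of batch_size/shuffle, as in build_dataloader
# (num_workers reduced to 0 here just to keep the sketch portable)
train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=0,
                                           pin_memory=True, batch_sampler=mg_sampler)

for epoch in range(cfg.CONFIG.TRAIN.EPOCH_NUM):
    for clips, labels in train_loader:
        # with USE_SHORT_CYCLE the leading batch dimension changes from
        # iteration to iteration according to mg_helper.get_scale(alpha, beta)
        pass
    # the long cycle is now stepped at user-defined epochs instead of being
    # derived from the lr milestones (compute_lr_milestone is deprecated)
    if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE and \
            epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
        mg_sampler.step_long_cycle()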