From afbf792ce3907dbe854effdfefe9cd10562c5bbe Mon Sep 17 00:00:00 2001 From: Weisu Yin Date: Thu, 22 Jul 2021 10:19:59 -0700 Subject: [PATCH] Torch classification (#1683) * fit, init_network, init_trainer. resume_fit in progress * todo * dataset * fix config * train and validate * remove split * fix * train and val * save andd load * resume * cpu fix * save and load * ctx * fix * predict * waarning * predict feature * parallel * parallel and save load * disable ema * test * rmse metric and removed aug_splits * fix * fix epoch * lint fix * docstring * fix lint * fix * dependency * conflict * fix ci * fix save load * fix * custom net, disable ocustom optimizer, fix test OOM * fix lint * fix * fix * fix --- .github/workflows/build_docs.sh | 2 +- .github/workflows/gpu_test.sh | 3 +- gluoncv/auto/data/dataset.py | 48 ++ gluoncv/auto/estimators/__init__.py | 2 + gluoncv/auto/estimators/base_estimator.py | 31 + .../torch_image_classification/__init__.py | 2 + .../torch_image_classification/default.py | 113 +++ .../torch_image_classification.py | 748 ++++++++++++++++++ .../utils/__init__.py | 5 + .../utils/constants.py | 3 + .../utils/metrics.py | 5 + .../torch_image_classification/utils/model.py | 44 ++ .../utils/optimizer.py | 11 + .../utils/scheduler.py | 81 ++ .../torch_image_classification/utils/utils.py | 70 ++ gluoncv/auto/estimators/utils.py | 17 +- tests/auto/test_torch_auto_estimators.py | 135 ++++ tests/py3_auto.yml | 4 + 18 files changed, 1320 insertions(+), 4 deletions(-) create mode 100644 gluoncv/auto/estimators/torch_image_classification/__init__.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/default.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/__init__.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/constants.py create mode 100644 
gluoncv/auto/estimators/torch_image_classification/utils/metrics.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/model.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/optimizer.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/scheduler.py create mode 100644 gluoncv/auto/estimators/torch_image_classification/utils/utils.py create mode 100644 tests/auto/test_torch_auto_estimators.py diff --git a/.github/workflows/build_docs.sh b/.github/workflows/build_docs.sh index a76007805b..d1b7f98bc0 100644 --- a/.github/workflows/build_docs.sh +++ b/.github/workflows/build_docs.sh @@ -15,7 +15,7 @@ for f in $EFS/.mxnet/datasets/*; do fi done -python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme autogluon.core +python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme autogluon.core timm export MXNET_CUDNN_AUTOTUNE_DEFAULT=0 cd docs diff --git a/.github/workflows/gpu_test.sh b/.github/workflows/gpu_test.sh index f4ddaf9255..d1ddaa6181 100644 --- a/.github/workflows/gpu_test.sh +++ b/.github/workflows/gpu_test.sh @@ -15,8 +15,9 @@ export MPLBACKEND=Agg export KMP_DUPLICATE_LIB_OK=TRUE if [[ $TESTS_PATH == *"auto"* ]]; then - echo "Installing autogluon.core for auto module" + echo "Installing autogluon.core and timm for auto module" pip3 install autogluon.core==0.2.0 + pip3 install timm==0.4.12 fi nosetests --with-timer --timer-ok 5 --timer-warning 20 -x --with-coverage --cover-package $COVER_PACKAGE -v $TESTS_PATH diff --git a/gluoncv/auto/data/dataset.py b/gluoncv/auto/data/dataset.py index a168423d78..1f98f5b2a6 100644 --- a/gluoncv/auto/data/dataset.py +++ b/gluoncv/auto/data/dataset.py @@ -22,6 +22,12 @@ except ImportError: MXDataset = object mx = None +try: + import torch + TorchDataset = torch.utils.data.Dataset +except ImportError: + TorchDataset = object + torch = None 
logger = logging.getLogger()
 
@@ -156,6 +162,12 @@ def to_mxnet(self):
         df = df.reset_index(drop=True)
         return _MXImageClassificationDataset(df)
 
+    def to_torch(self):
+        """Wrap this frame as a torch dataset (`_TorchImageClassificationDataset`) yielding (image, label)."""
+        df = self.rename(columns={self.IMG_COL: "image", self.LABEL_COL: "label"}, errors='ignore')
+        df = df.reset_index(drop=True)
+        return _TorchImageClassificationDataset(df)
+
     @classmethod
     def from_csv(cls, csv_file, root=None, image_column='image', label_column='label', no_class=False):
         r"""Create from csv file.
@@ -385,6 +397,42 @@ def __getitem__(self, idx):
             label = self._dataset['label'][idx]
         return img, label
 
+class _TorchImageClassificationDataset(TorchDataset):
+    """Internal wrapper that reads the rows of a pd.DataFrame as (image, label) samples.
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset
+        DataFrame as ImageClassificationDataset.
+
+    """
+    def __init__(self, dataset):
+        if torch is None:
+            raise RuntimeError('Unable to import pytorch which is required.')
+        assert isinstance(dataset, ImageClassificationDataset)
+        assert 'image' in dataset.columns
+        self._has_label = 'label' in dataset.columns
+        self._dataset = dataset
+        self.classes = self._dataset.classes
+        self._imread = Image.open
+        self.transform = None  # optional callable, applied to the RGB image in __getitem__
+
+    def __len__(self):
+        return self._dataset.shape[0]
+
+    def __getitem__(self, idx):
+        im_path = self._dataset['image'][idx]
+        img = self._imread(im_path).convert('RGB')
+        label = None  # placeholder only; both branches below assign it
+        # # pylint: disable=not-callable
+        if self.transform is not None:
+            img = self.transform(img)
+        if self._has_label:
+            label = self._dataset['label'][idx]
+        else:
+            label = torch.tensor(-1, dtype=torch.long)  # sentinel when there is no 'label' column
+        return img, label
+
 
 class ObjectDetectionDataset(pd.DataFrame):
     """ObjectDetection dataset as DataFrame.
diff --git a/gluoncv/auto/estimators/__init__.py b/gluoncv/auto/estimators/__init__.py
index 995db3597e..f4ed81de3b 100644
--- a/gluoncv/auto/estimators/__init__.py
+++ b/gluoncv/auto/estimators/__init__.py
@@ -1,7 +1,9 @@
 """Estimator implementations"""
+# FIXME: for quick test purpose only
 from .image_classification import ImageClassificationEstimator
 from .ssd import SSDEstimator
 from .yolo import YOLOv3Estimator
 from .faster_rcnn import FasterRCNNEstimator
 # from .mask_rcnn import MaskRCNNEstimator
 from .center_net import CenterNetEstimator
+from .torch_image_classification import TorchImageClassificationEstimator
diff --git a/gluoncv/auto/estimators/base_estimator.py b/gluoncv/auto/estimators/base_estimator.py
index d328bbc7d4..38668b8f7c 100644
--- a/gluoncv/auto/estimators/base_estimator.py
+++ b/gluoncv/auto/estimators/base_estimator.py
@@ -259,6 +259,22 @@ def _validate_gpus(self, gpu_ids):
             pass
         return valid_gpus
 
+    # FIXME: better design than a duplicate function?
+    def _torch_validate_gpus(self, gpu_ids):
+        """Return the requested gpu ids (as str) on which a torch tensor can be allocated."""
+        valid_gpus = []
+        try:
+            import torch
+            for gid in gpu_ids:
+                try:
+                    _ = torch.zeros(1, device=f'cuda:{gid}')
+                    valid_gpus.append(str(gid))
+                except Exception:  # pylint: disable=broad-except
+                    pass  # this gpu is unusable: skip it, but let Ctrl-C / SystemExit propagate
+        except ImportError:
+            pass  # torch is not installed: report no usable gpu
+        return valid_gpus
+
     def reset_ctx(self, ctx=None):
         """Reset model context.
 
@@ -289,6 +305,21 @@ def reset_ctx(self, ctx=None): done = True except ImportError: pass + try: + import torch + if isinstance(self.net, (torch.nn.Module, torch.nn.DataParallel)): + for c in ctx_list: + assert isinstance(c, torch.device) + if hasattr(self.net, 'reset_ctx'): + self.net.reset_ctx(ctx_list) + else: + if isinstance(self.net, torch.nn.DataParallel): + self.net = torch.nn.DataParallel(self.net.module, device_ids=[ctx.index for ctx in ctx_list]) + self.net.to(self.ctx[0]) + self.ctx = ctx_list + done = True + except ImportError: + pass if not done: raise RuntimeError("Unable to reset_ctx, no `mxnet` and `pytorch`.") diff --git a/gluoncv/auto/estimators/torch_image_classification/__init__.py b/gluoncv/auto/estimators/torch_image_classification/__init__.py new file mode 100644 index 0000000000..c4e4efb57a --- /dev/null +++ b/gluoncv/auto/estimators/torch_image_classification/__init__.py @@ -0,0 +1,2 @@ +"""Torch image classification estimator""" +from .torch_image_classification import TorchImageClassificationEstimator diff --git a/gluoncv/auto/estimators/torch_image_classification/default.py b/gluoncv/auto/estimators/torch_image_classification/default.py new file mode 100644 index 0000000000..08577ff426 --- /dev/null +++ b/gluoncv/auto/estimators/torch_image_classification/default.py @@ -0,0 +1,113 @@ +"""Default configs for torch image classification""" +# pylint: disable=bad-whitespace,missing-class-docstring +from typing import Union, Tuple +from autocfg import dataclass, field + +@dataclass +class ModelCfg: + model: str = 'resnet101' + pretrained: bool = False + global_pool_type: Union[str, None] = None # Global pool type, one of (fast, avg, max, avgmax). Model default if None + +@dataclass +class DatasetCfg: + img_size: Union[int, None] = None # Image patch size (default: None => model default) + input_size: Union[Tuple[int, int, int], None] = None # Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty + crop_pct: Union[float, None] = None # Input image center crop percent (for validation only) + mean: Union[Tuple, None] = None # Override mean pixel value of dataset + std : Union[Tuple, None] = None # Override std deviation of of dataset + interpolation: str = '' # Image resize interpolation type (overrides model) + validation_batch_size_multiplier: int = 1 # ratio of validation batch size to training batch size (default: 1) + +@dataclass +class OptimizerCfg: + opt: str = 'sgd' + opt_eps: Union[float, None] = None # Optimizer Epsilon (default: None, use opt default) + opt_betas: Union[Tuple, None] = None # Optimizer Betas (default: None, use opt default) + momentum: float = 0.9 + weight_decay: float = 0.0001 + clip_grad: Union[float, None] = None # Clip gradient norm (default: None, no clipping) + clip_mode: str = 'norm' # Gradient clipping mode. One of ("norm", "value", "agc") + +@dataclass +class TrainCfg: + batch_size: int = 32 + sched: str = 'step' # LR scheduler + lr: float = 0.01 + lr_noise: Union[Tuple, None] = None # learning rate noise on/off epoch percentages + lr_noise_pct: float = 0.67 # learning rate noise limit percent + lr_noise_std: float = 1.0 # learning rate noise std-dev + lr_cycle_mul: float = 1.0 # learning rate cycle len multiplier + lr_cycle_limit: int = 1 # learning rate cycle limit + warmup_lr: float = 0.0001 + min_lr: float = 1e-5 + epochs: int = 200 + start_epoch: int = 0 # manual epoch number (useful on restarts) + decay_epochs: int = 30 # epoch interval to decay LR + warmup_epochs: int = 3 # epochs to warmup LR, if scheduler supports + cooldown_epochs: int = 10 # epochs to cooldown LR at min_lr, after cyclic schedule ends + patience_epochs: int = 10 # patience epochs for Plateau LR scheduler + decay_rate: float = 0.1 + bn_momentum: Union[float, None] = None # BatchNorm momentum override + bn_eps: Union[float, None] = None # BatchNorm epsilon override + sync_bn: bool = False # Enable 
NVIDIA Apex or Torch synchronized BatchNorm + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly + +@dataclass +class AugmentationCfg: + no_aug: bool = False # Disable all training augmentation, override other train aug args + scale: Tuple[float, float] = (0.08, 1.0) # Random resize scale + ratio: Tuple[float, float] = (3./4., 4./3.) # Random resize aspect ratio (default: 0.75 1.33 + hflip: float = 0.5 # Horizontal flip training aug probability + vflip: float = 0.0 # Vertical flip training aug probability + color_jitter: float = 0.4 + auto_augment: Union[str, None] = None # Use AutoAugment policy. "v0" or "original + mixup: float = 0.0 # mixup alpha, mixup enabled if > 0 + cutmix: float = 0.0 # cutmix alpha, cutmix enabled if > 0 + cutmix_minmax: Union[Tuple, None] = None # cutmix min/max ratio, overrides alpha and enables cutmix if set + mixup_prob: float = 1.0 # Probability of performing mixup or cutmix when either/both is enabled + mixup_switch_prob: float = 0.5 # Probability of switching to cutmix when both mixup and cutmix enabled + mixup_mode: str = 'batch' # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" + mixup_off_epoch: int = 0 # Turn off mixup after this epoch, disabled if 0 + smoothing: float = 0.1 # Label smoothin + train_interpolation: str = 'random' # Training interpolation (random, bilinear, bicubic) + drop: float = 0.0 # Dropout rate + drop_path: Union[float, None] = None # Drop path rate + drop_block: Union[float, None] = None # Drop block rate + +@dataclass +class ModelEMACfg: + model_ema: bool = True # Enable tracking moving average of model weights + model_ema_force_cpu: bool = False # Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation + model_ema_decay: float = 0.9998 # decay factor for model weights moving average + +@dataclass +class MiscCfg: + seed: int = 42 + log_interval: int = 50 # how many batches to wait before logging training status + num_workers: int = 4 # how many training processes to use + save_images: bool = False # save images of input bathes every log interval for debugging + amp: bool = False # use NVIDIA Apex AMP or Native AMP for mixed precision training + apex_amp: bool = False # Use NVIDIA Apex AMP mixed precision + native_amp: bool = False # Use Native Torch AMP mixed precision + pin_mem: bool = False # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU + prefetcher: bool = False # use fast prefetcher + eval_metric: str = 'top1' # 'Best metric (default: "top1") + tta: int = 0 # Test/inference time augmentation (oversampling) factor. 
0=None + use_multi_epochs_loader: bool = False # use the multi-epochs-loader to save time at the beginning of every epoch + torchscript: bool = False # keep false, convert model torchscript for inference + +@dataclass +class TorchImageClassificationCfg: + model : ModelCfg = field(default_factory=ModelCfg) + dataset: DatasetCfg = field(default_factory=DatasetCfg) + optimizer: OptimizerCfg = field(default_factory=OptimizerCfg) + train: TrainCfg = field(default_factory=TrainCfg) + augmentation: AugmentationCfg = field(default_factory=AugmentationCfg) + model_ema: ModelEMACfg = field(default_factory=ModelEMACfg) + misc: MiscCfg = field(default_factory=MiscCfg) + gpus : Union[Tuple, list] = (0, ) # gpu individual ids, not necessarily consecutive diff --git a/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py new file mode 100644 index 0000000000..f9fe82b313 --- /dev/null +++ b/gluoncv/auto/estimators/torch_image_classification/torch_image_classification.py @@ -0,0 +1,748 @@ +"""Torch Classification Estimator""" +# pylint: disable=unused-variable,bad-whitespace,missing-function-docstring,logging-format-interpolation,arguments-differ,logging-not-lazy, not-callable +import math +import os +import logging +import time +import warnings +import pickle +from contextlib import suppress +from PIL import Image + +import pandas as pd +import numpy as np +import torch +import torch.nn as nn +import torchvision.utils +from torch.optim.optimizer import Optimizer + +from timm.data import create_loader, Mixup, FastCollateMixup, AugMixDataset +from timm.models import create_model, safe_model_name, convert_splitbn_model, model_parameters +from timm.utils import random_seed, dispatch_clip_grad, accuracy, unwrap_model, get_state_dict +from timm.optim import create_optimizer_v2 +from timm.utils import ApexScaler, NativeScaler, ModelEmaV2, AverageMeter +from timm.loss import 
LabelSmoothingCrossEntropy, SoftTargetCrossEntropy + +from .default import TorchImageClassificationCfg +from .utils import resolve_data_config, update_cfg, optimizer_kwargs, \ + create_scheduler, rmse +from ..utils import EarlyStopperOnPlateau +from ..conf import _BEST_CHECKPOINT_FILE +from ..base_estimator import BaseEstimator, set_default +from ....utils.filesystem import try_import +problem_type_constants = try_import(package='autogluon.core.constants', + fromlist=['MULTICLASS', 'BINARY', 'REGRESSION'], + message='Failed to import problem type constants from autogluon.core.') +MULTICLASS = problem_type_constants.MULTICLASS +BINARY = problem_type_constants.BINARY +REGRESSION = problem_type_constants.REGRESSION + +warnings.filterwarnings('ignore', message='.*Argument interpolation should be of type InterpolationMode instead of int.*') + +try: + from apex import amp + from apex.parallel import convert_syncbn_model + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + + +@set_default(TorchImageClassificationCfg()) +class TorchImageClassificationEstimator(BaseEstimator): + """Torch Estimator implementation for Image Classification. + + Parameters + ---------- + config : dict + Config in nested dict. + logger : logging.Logger + Optional logger for this estimator, can be `None` when default setting is used. + reporter : callable + The reporter for metric checkpointing. + net : torch.nn.Module + The custom network. If defined, the model name in config will be ignored so your + custom network will be used for training rather than pulling it from model zoo. 
+ """ + def __init__(self, config, logger=None, reporter=None, net=None, optimizer=None, problem_type=None): + super().__init__(config, logger=logger, reporter=reporter, name=None) + if problem_type is None: + problem_type = MULTICLASS + self._problem_type = problem_type + self._feature_net = None + self._custom_net = False + + self._model_cfg = self._cfg.model + self._dataset_cfg = self._cfg.dataset + self._optimizer_cfg = self._cfg.optimizer + self._train_cfg = self._cfg.train + self._augmentation_cfg = self._cfg.augmentation + self._model_ema_cfg = self._cfg.model_ema + self._misc_cfg = self._cfg.misc + + # resolve AMP arguments based on PyTorch / Apex availability + self.use_amp = None + if self._misc_cfg.amp: + # `amp` chooses native amp before apex (APEX ver not actively maintained) + if self._misc_cfg.native_amp and has_native_amp: + self.use_amp = 'native' + elif self._misc_cfg.apex_amp and has_apex: + self.use_amp = 'apex' + elif self._misc_cfg.apex_amp or self._misc_cfg.native_amp: + self._logger.warning(f'Neither APEX or native Torch AMP is available, using float32. \ + Install NVIDA apex or upgrade to PyTorch 1.6') + # FIXME: will provided model conflict with config provided? + if net is not None: + assert isinstance(net, nn.Module), f"given custom network {type(net)}, `torch.nn` expected" + try: + net.to('cpu') + self._custom_net = True + except ValueError: + pass + self.net = net + if optimizer is not None: + self._logger.warning('Custom optimizer object not supported. 
Will follow the config instead.')
+        self._optimizer = None
+
+    def _fit(self, train_data, val_data, time_limit=math.inf):
+        tic = time.time()
+        self._cp_name = ''
+        self._best_acc = 0.0
+        self.epochs = self._train_cfg.epochs
+        self.epoch = 0
+        self.start_epoch = self._train_cfg.start_epoch
+        self._time_elapsed = 0
+        if max(self.start_epoch, self.epoch) >= self.epochs:
+            return {'time': self._time_elapsed}  # a dict, like the early exit of _resume_fit (was a set literal)
+        self._init_trainer()
+        self._init_model_ema()
+        self._time_elapsed += time.time() - tic
+        return self._resume_fit(train_data, val_data, time_limit=time_limit)
+
+    def _resume_fit(self, train_data, val_data, time_limit=math.inf):
+        tic = time.time()
+        # TODO: regression not implemented
+        if self._problem_type != REGRESSION and (not self.classes or not self.num_class):
+            raise ValueError('This is a classification problem and we are not able to determine classes of dataset')
+
+        if max(self.start_epoch, self.epoch) >= self.epochs:
+            return {'time': self._time_elapsed}
+
+        # wrap DP if possible
+        if self.found_gpu:
+            self.net = torch.nn.DataParallel(self.net, device_ids=[int(i) for i in self.valid_gpus])
+        self.net = self.net.to(self.ctx[0])
+
+        # prepare dataset
+        train_dataset = train_data.to_torch()
+        val_dataset = val_data.to_torch()
+
+        # setup mixup / cutmix
+        self._collate_fn = None
+        self._mixup_fn = None
+        self.mixup_active = self._augmentation_cfg.mixup > 0 or self._augmentation_cfg.cutmix > 0.
or self._augmentation_cfg.cutmix_minmax is not None + if self.mixup_active: + mixup_args = dict( + mixup_alpha=self._augmentation_cfg.mixup, cutmix_alpha=self._augmentation_cfg.cutmix, + cutmix_minmax=self._augmentation_cfg.cutmix_minmax, prob=self._augmentation_cfg.mixup_prob, + switch_prob=self._augmentation_cfg.mixup_switch_prob, mode=self._augmentation_cfg.mixup_mode, + label_smoothing=self._augmentation_cfg.smoothing, num_classes=self.num_class) + if self._misc_cfg.prefetcher: + self._collate_fn = FastCollateMixup(**mixup_args) + else: + self._mixup_fn = Mixup(**mixup_args) + + # create data loaders w/ augmentation pipeiine + train_interpolation = self._augmentation_cfg.train_interpolation + if self._augmentation_cfg.no_aug or not train_interpolation: + train_interpolation = self._dataset_cfg.interpolation + train_loader = create_loader( + train_dataset, + input_size=self._dataset_cfg.input_size, + batch_size=self._train_cfg.batch_size, + is_training=True, + use_prefetcher=self._misc_cfg.prefetcher, + no_aug=self._augmentation_cfg.no_aug, + scale=self._augmentation_cfg.scale, + ratio=self._augmentation_cfg.ratio, + hflip=self._augmentation_cfg.hflip, + vflip=self._augmentation_cfg.vflip, + color_jitter=self._augmentation_cfg.color_jitter, + auto_augment=self._augmentation_cfg.auto_augment, + interpolation=train_interpolation, + mean=self._dataset_cfg.mean, + std=self._dataset_cfg.std, + num_workers=self._misc_cfg.num_workers, + distributed=False, + collate_fn=self._collate_fn, + pin_memory=self._misc_cfg.pin_mem, + use_multi_epochs_loader=self._misc_cfg.use_multi_epochs_loader + ) + + val_loader = create_loader( + val_dataset, + input_size=self._dataset_cfg.input_size, + batch_size=self._dataset_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size, + is_training=False, + use_prefetcher=self._misc_cfg.prefetcher, + interpolation=self._dataset_cfg.interpolation, + mean=self._dataset_cfg.mean, + std=self._dataset_cfg.std, + 
num_workers=self._misc_cfg.num_workers, + distributed=False, + crop_pct=self._dataset_cfg.crop_pct, + pin_memory=self._misc_cfg.pin_mem, + ) + + self._time_elapsed += time.time() - tic + return self._train_loop(train_loader, val_loader, time_limit=time_limit) + + def _train_loop(self, train_loader, val_loader, time_limit=math.inf): + start_tic = time.time() + # setup loss function + if self.mixup_active: + # smoothing is handled with mixup target transform + train_loss_fn = SoftTargetCrossEntropy() + elif self._augmentation_cfg.smoothing: + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=self._augmentation_cfg.smoothing) + else: + train_loss_fn = nn.CrossEntropyLoss() + validate_loss_fn = nn.CrossEntropyLoss() + train_loss_fn = train_loss_fn.to(self.ctx[0]) + validate_loss_fn = validate_loss_fn.to(self.ctx[0]) + eval_metric = self._misc_cfg.eval_metric + early_stopper = EarlyStopperOnPlateau( + patience=self._train_cfg.early_stop_patience, + min_delta=self._train_cfg.early_stop_min_delta, + baseline_value=self._train_cfg.early_stop_baseline, + max_value=self._train_cfg.early_stop_max_value) + + self._logger.info('Start training from [Epoch %d]', max(self._train_cfg.start_epoch, self.epoch)) + + self._time_elapsed += time.time() - start_tic + for self.epoch in range(max(self.start_epoch, self.epoch), self.epochs): + epoch = self.epoch + if self._best_acc >= 1.0: + self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch)) + break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break + train_metrics = self.train_one_epoch( + epoch, self.net, train_loader, self._optimizer, train_loss_fn, + lr_scheduler=self._lr_scheduler, output_dir=self._logdir, + amp_autocast=self._amp_autocast, loss_scaler=self._loss_scaler, model_ema=self._model_ema, mixup_fn=self._mixup_fn, time_limit=time_limit) + # reaching time limit, exit early + if 
train_metrics['time_limit']: + self._logger.warning(f'`time_limit={time_limit}` reached, exit early...') + return {'train_acc': train_metrics['train_acc'], 'valid_acc': self._best_acc, + 'time': self._time_elapsed, 'checkpoint': self._cp_name} + post_tic = time.time() + + eval_metrics = self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) + + if self._model_ema is not None and not self._model_ema_cfg.model_ema_force_cpu: + ema_eval_metrics = self.validate( + self._model_ema.module, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) + eval_metrics = ema_eval_metrics + + val_acc = eval_metrics['top1'] + if self._reporter: + self._reporter(epoch=epoch, acc_reward=val_acc) + early_stopper.update(val_acc) + + if val_acc > self._best_acc: + self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE) + self._logger.info('[Epoch %d] Current best top-1: %f vs previous %f, saved to %s', + self.epoch, val_acc, self._best_acc, self._cp_name) + self.save(self._cp_name) + self._best_acc = val_acc + + if self._lr_scheduler is not None: + # step LR for next epoch + self._lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + self._time_elapsed += time.time() - post_tic + + if 'accuracy' in train_metrics: + return {'train_acc': train_metrics['accuracy'], 'valid_acc': self._best_acc, + 'time': self._time_elapsed, 'checkpoint': self._cp_name} + # rmse + else: + return {'train_score': train_metrics['rmse'], 'valid_acc': self._best_acc, + 'time': self._time_elapsed, 'checkpoint': self._cp_name} + + def train_one_epoch( + self, epoch, net, loader, optimizer, loss_fn, + lr_scheduler=None, output_dir=None, amp_autocast=suppress, + loss_scaler=None, model_ema=None, mixup_fn=None, time_limit=math.inf): + start_tic = time.time() + if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch: + if self._misc_cfg.prefetcher and loader.mixup_enabled: + loader.mixup_enabled = False + elif mixup_fn is not 
None: + mixup_fn.mixup_enabled = False + + second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + losses_m = AverageMeter() + train_metric_score_m = AverageMeter() + + net.train() + + num_updates = epoch * len(loader) + self._time_elapsed += time.time() - start_tic + tic = time.time() + last_tic = time.time() + train_metric_name = 'accuracy' + batch_idx = 0 + for batch_idx, (input, target) in enumerate(loader): + b_tic = time.time() + if self._time_elapsed > time_limit: + return {'train_acc': train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True} + if not self._misc_cfg.prefetcher: + # prefetcher would move data to cuda by default + input, target = input.to(self.ctx[0]), target.to(self.ctx[0]) + if mixup_fn is not None: + input, target = mixup_fn(input, target) + + with amp_autocast(): + output = net(input) + loss = loss_fn(output, target) + if output.shape == target.shape: + train_metric_name = 'rmse' + train_metric_score = rmse(output, target) + else: + train_metric_score = accuracy(output, target)[0] / 100 + + losses_m.update(loss.item(), input.size(0)) + train_metric_score_m.update(train_metric_score.item(), output.size(0)) + + optimizer.zero_grad() + if loss_scaler is not None: + loss_scaler( + loss, optimizer, + clip_grad=self._optimizer_cfg.clip_grad, clip_mode=self._optimizer_cfg.clip_mode, + parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode), + create_graph=second_order) + else: + loss.backward(create_graph=second_order) + if self._optimizer_cfg.clip_grad is not None: + dispatch_clip_grad( + model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode), + value=self._optimizer_cfg.clip_grad, mode=self._optimizer_cfg.clip_mode) + optimizer.step() + + if model_ema is not None: + model_ema.update(net) + + if self.found_gpu: + torch.cuda.synchronize() + + num_updates += 1 + if (batch_idx+1) % self._misc_cfg.log_interval == 0: + lrl = [param_group['lr'] for 
param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f', + epoch, batch_idx, + self._train_cfg.batch_size*self._misc_cfg.log_interval/(time.time()-last_tic), + train_metric_name, train_metric_score_m.avg, lr) + last_tic = time.time() + + if self._misc_cfg.save_images and output_dir: + torchvision.utils.save_image( + input, + os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), + padding=0, + normalize=True) + + if lr_scheduler is not None: + lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + + self._time_elapsed += time.time() - b_tic + + throughput = int(self._train_cfg.batch_size * batch_idx / (time.time() - tic)) + self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg) + self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic) + + end_time = time.time() + if hasattr(optimizer, 'sync_lookahead'): + optimizer.sync_lookahead() + + self._time_elapsed += time.time() - end_time + + return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False} + + def validate(self, net, loader, loss_fn, amp_autocast=suppress): + losses_m = AverageMeter() + top1_m = AverageMeter() + top5_m = AverageMeter() + + net.eval() + + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + if not self._misc_cfg.prefetcher: + input = input.to(self.ctx[0]) + target = target.to(self.ctx[0]) + + with amp_autocast(): + output = net(input) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = self._misc_cfg.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] + + loss = loss_fn(output, target) + acc1, acc5 = accuracy(output, target, topk=(1, min(5, self.num_class))) + acc1 /= 100 + 
acc5 /= 100 + + reduced_loss = loss.data + + if self.found_gpu: + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + top1_m.update(acc1.item(), output.size(0)) + top5_m.update(acc5.item(), output.size(0)) + + self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg) + # TODO: update early stoper + + return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg} + + def _init_network(self, **kwargs): + if self._problem_type == REGRESSION: + raise NotImplementedError + assert len(self.classes) == self.num_class + + # Disable syncBatchNorm as it's only supported on DDP + if self._train_cfg.sync_bn: + self._logger.info( + 'Disable Sync batch norm as it is not supported for now.') + update_cfg(self._cfg, {'train': {'sync_bn': False}}) + + # ctx + self.found_gpu = False + valid_gpus = [] + if self._cfg.gpus: + valid_gpus = self._torch_validate_gpus(self._cfg.gpus) + self.found_gpu = True + if not valid_gpus: + self.found_gpu = False + self._logger.warning( + 'No gpu detected, fallback to cpu. You can ignore this warning if this is intended.') + elif len(valid_gpus) != len(self._cfg.gpus): + self._logger.warning( + f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).') + self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus] if self.found_gpu else [torch.device('cpu')] + self.valid_gpus = valid_gpus + + if not self.found_gpu and self.use_amp: + self.use_amp = None + self._logger.warning('Training on cpu. AMP disabled.') + update_cfg(self._cfg, {'misc': {'amp': False, 'apex_amp': False, 'native_amp': False}}) + + if not self.found_gpu and self._misc_cfg.prefetcher: + self._logger.warning( + 'Training on cpu. Prefetcher disabled.') + update_cfg(self._cfg, {'misc': {'prefetcher': False}}) + self._logger.warning( + 'Training on cpu. 
SyncBatchNorm disabled.') + update_cfg(self._cfg, {'train': {'sync_bn': False}}) + + if not self.net: + self.net = create_model( + self._model_cfg.model, + pretrained=self._model_cfg.pretrained, + num_classes=self.num_class, + global_pool=self._model_cfg.global_pool_type, + drop_rate=self._augmentation_cfg.drop, + drop_path_rate=self._augmentation_cfg.drop_path, + drop_block_rate=self._augmentation_cfg.drop_block, + bn_momentum=self._train_cfg.bn_momentum, + bn_eps=self._train_cfg.bn_eps, + scriptable=self._misc_cfg.torchscript + ) + + self._logger.info(f'Model {safe_model_name(self._model_cfg.model)} created, param count: \ + {sum([m.numel() for m in self.net.parameters()])}') + else: + self._logger.info(f'Use user provided model. Neglect model in config.') + + resolve_data_config(self._cfg, model=self.net) + + self.net = self.net.to(self.ctx[0]) + + # setup synchronized BatchNorm + if self._train_cfg.sync_bn: + if has_apex and self.use_amp != 'native': + # Apex SyncBN preferred unless native amp is activated + self.net = convert_syncbn_model(self.net) + else: + self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net) + self._logger.info( + 'Converted model to use Synchronized BatchNorm. 
def _init_trainer(self):
    """Create the optimizer (if missing), the loss scaler and the LR scheduler.

    An optimizer that is already set on ``self._optimizer`` is reused;
    otherwise one is built from the config. The scheduler factory also
    returns the number of epochs, which is stored in ``self.epochs``.
    """
    if self._optimizer is None:
        self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg))
    self._init_loss_scaler()
    self._lr_scheduler, self.epochs = create_scheduler(self._cfg, self._optimizer)
    # The factory yields `None` when the configured schedule name matches no
    # known scheduler; the training loop already tolerates that, so do not
    # crash here either.
    if self._lr_scheduler is not None:
        # NOTE(review): the second positional argument looks like it lands in
        # the scheduler's `metric` slot rather than an epoch -- confirm.
        self._lr_scheduler.step(self.start_epoch, self.epoch)
Training in float32.') + + def _init_model_ema(self): + # Disable for now + if self._model_ema_cfg.model_ema: + self._logger.info('Disable EMA as it is not supported for now.') + update_cfg(self._cfg, {'model_ema': {'model_ema': False}}) + # setup exponential moving average of model weights, SWA could be used here too + self._model_ema = None + if self._model_ema_cfg.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper + self._model_ema = ModelEmaV2( + self.net, decay=self._model_ema_cfg.model_ema_decay, device='cpu' if self._model_ema_cfg.model_ema_force_cpu else None) + + + def evaluate(self, val_data): + return self._evaluate(val_data) + + def _evaluate(self, val_data): + validate_loss_fn = nn.CrossEntropyLoss() + validate_loss_fn = validate_loss_fn.to(self.ctx[0]) + return self.validate(self.net, val_data, validate_loss_fn, amp_autocast=self._amp_autocast) + + def _predict(self, x, **kwargs): + if isinstance(x, pd.DataFrame): + assert 'image' in x.columns, "Expect column `image` for input images" + df = self._predict(tuple(x['image'])) + return df.reset_index(drop=True) + elif isinstance(x, (list, tuple)): + loader = create_loader( + ImageListDataset(x), + input_size=self._dataset_cfg.input_size, + batch_size=self._train_cfg.batch_size, + use_prefetcher=self._misc_cfg.prefetcher, + interpolation=self._dataset_cfg.interpolation, + mean=self._dataset_cfg.mean, + std=self._dataset_cfg.std, + num_workers=self._misc_cfg.num_workers, + crop_pct=self._dataset_cfg.crop_pct + ) + + self.net.eval() + + topk = min(5, self.num_class) + results = [] + idx = 0 + with torch.no_grad(): + for input, _ in loader: + input = input.to(self.ctx[0]) + labels = self.net(input) + for l in labels: + probs = nn.functional.softmax(l, dim=0).cpu().numpy().flatten() + topk_inds = l.topk(topk)[1].cpu().numpy().flatten() + results.extend([{'class': self.classes[topk_inds[k]], + 'score': probs[topk_inds[k]], + 'id': topk_inds[k], + 
'image': x[idx]} + for k in range(topk)]) + idx += 1 + return pd.DataFrame(results) + elif not isinstance(x, torch.tensor): + raise ValueError('Input is not supported: {}'.format(type(x))) + with torch.no_grad(): + input = x.to(self.ctx[0]) + label = self.net(input) + topk = min(5, self.num_class) + probs = nn.functional.softmax(label, dim=0).cpu().numpy().flatten() + topk_inds = label.topk(topk)[1].cpu().numpy().flatten() + df = pd.DataFrame([{'class': self.classes[topk_inds[k]], + 'score': probs[topk_inds[k]], + 'id': topk_inds[k]} + for k in range(topk)]) + return df + + + def _predict_feature(self, x, **kwargs): + if isinstance(x, pd.DataFrame): + assert 'image' in x.columns, "Expect column `image` for input images" + df = self._predict_feature(tuple(x['image'])) + df = df.set_index(x.index) + df['image'] = x['image'] + return df + elif isinstance(x, (list, tuple)): + assert isinstance(x[0], str), "expect image paths in list/tuple input" + loader = create_loader( + ImageListDataset(x), + input_size=self._dataset_cfg.input_size, + batch_size=self._train_cfg.batch_size, + use_prefetcher=self._misc_cfg.prefetcher, + interpolation=self._dataset_cfg.interpolation, + mean=self._dataset_cfg.mean, + std=self._dataset_cfg.std, + num_workers=self._misc_cfg.num_workers, + crop_pct=self._dataset_cfg.crop_pct + ) + + self.net.eval() + + results = [] + with torch.no_grad(): + for input, _ in loader: + input = input.to(self.ctx[0]) + try: + features = self.net.forward_features(input) + except AttributeError: + features = self.net.module.forward_features(input) + for f in features: + f = f.cpu().numpy().flatten() + results.append({'image_feature': f}) + df = pd.DataFrame(results) + df['image'] = x + return df + elif not isinstance(x, torch.tensor): + raise ValueError('Input is not supported: {}'.format(type(x))) + with torch.no_grad(): + input = x.to(self.ctx[0]) + feature = self.net.forward_features(input) + result = [{'image_feature': feature}] + df = pd.DataFrame(result) + 
return df + + def _reconstruct_state_dict(self, state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + return new_state_dict + + # pylint: disable=redefined-outer-name, reimported + def __getstate__(self): + d = self.__dict__.copy() + try: + import torch + net = d.pop('net', None) + model_ema = d.pop('_model_ema', None) + optimizer = d.pop('_optimizer', None) + loss_scaler = d.pop('_loss_scaler', None) + save_state = {} + if net is not None: + if not self._custom_net: + if isinstance(net, torch.nn.DataParallel): + save_state['state_dict'] = get_state_dict(net.module, unwrap_model) + else: + save_state['state_dict'] = get_state_dict(net, unwrap_model) + else: + net_pickle = pickle.dumps(net) + save_state['net_pickle'] = net_pickle + if optimizer is not None: + save_state['optimizer'] = optimizer.state_dict() + if loss_scaler is not None: + save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict() + if model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model) + except ImportError: + pass + d['save_state'] = save_state + d['_logger'] = None + d['_reporter'] = None + return d + + def __setstate__(self, state): + save_state = state.pop('save_state', None) + self.__dict__.update(state) + # logger + self._logger = logging.getLogger(state.get('_name', self.__class__.__name__)) + self._logger.setLevel(logging.ERROR) + try: + fh = logging.FileHandler(self._log_file) + self._logger.addHandler(fh) + #pylint: disable=bare-except + except: + pass + if not save_state: + self.net = None + self._optimizer = None + self._logger.setLevel(logging.INFO) + return + try: + import torch + self.net = None + self._optimizer = None + if self._custom_net: + if save_state.get('net_pickle', None): + self.net = pickle.loads(save_state['net_pickle']) + else: + if save_state.get('state_dict', None): + self._init_network() + net_state_dict = 
class ImageListDataset(torch.utils.data.Dataset):
    """An internal image list dataset for batch predict.

    Parameters
    ----------
    imlist : sequence
        Image locations; each item is opened with ``Image.open``.
    """
    def __init__(self, imlist):
        self._imlist = imlist
        # No transform by default.
        # NOTE(review): presumably assigned later by the loader factory -- confirm.
        self.transform = None

    def __getitem__(self, idx):
        """Return ``(image, label)`` where the label is always the placeholder ``-1``."""
        # Close the file handle as soon as the pixels are decoded; `convert`
        # returns a new image, so it stays usable after the handle is closed.
        with Image.open(self._imlist[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(-1, dtype=torch.long)

    def __len__(self):
        return len(self._imlist)
# FIXME: shouldn't need this anymore
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, logger=None, log_info=False):
    """Load model (and optionally optimizer / loss-scaler) state from a checkpoint file.

    Parameters
    ----------
    model : object
        Receives the weights through ``load_state_dict``.
    checkpoint_path : str
        File to read with ``torch.load`` (mapped to CPU).
    optimizer : object, optional
        Restored from the checkpoint's ``'optimizer'`` entry when present.
    loss_scaler : object, optional
        Restored from the entry named by ``loss_scaler.state_dict_key`` when present.
    logger : logging.Logger, optional
        Destination for messages; a module logger is used when omitted.
    log_info : bool, default False
        Whether to emit progress messages.

    Returns
    -------
    int or None
        Epoch to resume from, or ``None`` when the checkpoint stores no epoch
        (including the case of a bare state dict).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` is not an existing file.
    """
    if logger is None:
        # Function-scope import: only needed for the fallback logger, which
        # keeps the documented default `logger=None` usable.
        import logging
        logger = logging.getLogger(__name__)

    if not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        logger.error(message)
        raise FileNotFoundError(message)

    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if not (isinstance(checkpoint, dict) and 'state_dict' in checkpoint):
        # Anything else is handed to the model as a bare state dict.
        model.load_state_dict(checkpoint)
        if log_info:
            logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return None

    if log_info:
        logger.info('Restoring model state from checkpoint...')
    # Strip the exact `module.` prefix (with the dot) so that keys merely
    # starting with "module" are not truncated.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)

    if optimizer is not None and 'optimizer' in checkpoint:
        if log_info:
            logger.info('Restoring optimizer state from checkpoint...')
        optimizer.load_state_dict(checkpoint['optimizer'])

    if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
        if log_info:
            logger.info('Restoring AMP loss scaler state from checkpoint...')
        loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

    resume_epoch = None
    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

    if log_info:
        # `.get` keeps logging from raising when the checkpoint has no epoch.
        logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
    return resume_epoch
num_epochs = cfg.train.epochs + + if cfg.train.lr_noise is not None: + lr_noise = cfg.train.lr_noise + if isinstance(lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = lr_noise * num_epochs + else: + noise_range = None + + lr_scheduler = None + if cfg.train.sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(cfg.train, 'lr_cycle_mul', 1.), + lr_min=cfg.train.min_lr, + decay_rate=cfg.train.decay_rate, + warmup_lr_init=cfg.train.warmup_lr, + warmup_t=cfg.train.warmup_epochs, + cycle_limit=getattr(cfg.train, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(cfg.train, 'lr_noise_pct', 0.67), + noise_std=getattr(cfg.train, 'lr_noise_std', 1.), + noise_seed=getattr(cfg.misc, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + cfg.train.cooldown_epochs + elif cfg.train.sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(cfg.train, 'lr_cycle_mul', 1.), + lr_min=cfg.train.min_lr, + warmup_lr_init=cfg.train.warmup_lr, + warmup_t=cfg.train.warmup_epochs, + cycle_limit=getattr(cfg.train, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(cfg.train, 'lr_noise_pct', 0.67), + noise_std=getattr(cfg.train, 'lr_noise_std', 1.), + noise_seed=getattr(cfg.misc, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + cfg.train.cooldown_epochs + elif cfg.train.sched == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=cfg.train.decay_epochs, + decay_rate=cfg.train.decay_rate, + warmup_lr_init=cfg.train.warmup_lr, + warmup_t=cfg.train.warmup_epochs, + noise_range_t=noise_range, + noise_pct=getattr(cfg.train, 'lr_noise_pct', 0.67), + noise_std=getattr(cfg.train, 'lr_noise_std', 1.), + noise_seed=getattr(cfg.misc, 'seed', 42), + ) + elif cfg.train.sched == 
def _resolve_norm_stat(configured, model_defaults, key, fallback, in_chans):
    """Choose one per-channel normalization statistic (mean or std)."""
    if configured is not None:
        stat = tuple(configured)
        if len(stat) == 1:
            # A single value is repeated for every input channel.
            return tuple(list(stat) * in_chans)
        assert len(stat) == in_chans
        return stat
    if key in model_defaults:
        return model_defaults[key]
    return fallback


def resolve_data_config(cfg, default_cfg=None, model=None, use_test_size=False):
    """Fill the ``dataset`` section of ``cfg`` with resolved preprocessing values.

    For each of ``input_size``, ``interpolation``, ``mean``, ``std`` and
    ``crop_pct`` the value is taken from ``cfg.dataset`` when it is not
    ``None``, otherwise from the defaults mapping, otherwise from a built-in
    fallback. Each result is written back through ``update_cfg``.

    Parameters
    ----------
    cfg : config object
        Read via ``cfg.dataset.<name>`` and updated in place.
    default_cfg : dict, optional
        Defaults mapping. When empty or omitted, ``model.default_cfg`` is
        used if the model has one.
    model : object, optional
        Source of ``default_cfg`` as described above.
    use_test_size : bool, default False
        Prefer the defaults' ``test_input_size`` over ``input_size``.

    Raises
    ------
    AssertionError
        If a configured size or statistic has an unexpected type or length.
    """
    # `None` replaces the former shared `{}` default; an explicit empty dict
    # passed by a caller behaves exactly as before.
    if not default_cfg:
        default_cfg = getattr(model, 'default_cfg', None) or {}

    # Resolve input/image size
    in_chans = 3
    input_size = (in_chans, 224, 224)
    if cfg.dataset.input_size is not None:
        assert isinstance(cfg.dataset.input_size, (tuple, list))
        assert len(cfg.dataset.input_size) == 3
        input_size = tuple(cfg.dataset.input_size)
        in_chans = input_size[0]  # input_size overrides in_chans
    elif cfg.dataset.img_size is not None:
        assert isinstance(cfg.dataset.img_size, int)
        input_size = (in_chans, cfg.dataset.img_size, cfg.dataset.img_size)
    elif use_test_size and 'test_input_size' in default_cfg:
        input_size = default_cfg['test_input_size']
    elif 'input_size' in default_cfg:
        input_size = default_cfg['input_size']
    update_cfg(cfg, {'dataset': {'input_size': input_size}})

    # resolve interpolation method
    interpolation = 'bicubic'
    if cfg.dataset.interpolation is not None:
        interpolation = cfg.dataset.interpolation
    elif 'interpolation' in default_cfg:
        interpolation = default_cfg['interpolation']
    update_cfg(cfg, {'dataset': {'interpolation': interpolation}})

    # resolve dataset + model mean for normalization
    mean = _resolve_norm_stat(cfg.dataset.mean, default_cfg, 'mean', IMAGENET_DEFAULT_MEAN, in_chans)
    update_cfg(cfg, {'dataset': {'mean': mean}})

    # resolve dataset + model std deviation for normalization
    std = _resolve_norm_stat(cfg.dataset.std, default_cfg, 'std', IMAGENET_DEFAULT_STD, in_chans)
    update_cfg(cfg, {'dataset': {'std': std}})

    # resolve default crop percentage
    crop_pct = DEFAULT_CROP_PCT
    if cfg.dataset.crop_pct is not None:
        crop_pct = cfg.dataset.crop_pct
    elif 'crop_pct' in default_cfg:
        crop_pct = default_cfg['crop_pct']
    update_cfg(cfg, {'dataset': {'crop_pct': crop_pct}})
mode = 'gpu' + if mode == 'cpu': + return [torch.device('cpu')] + if mode == 'gpu': + return [torch.device(f'cuda:{gid}') for gid in range(torch.cuda.device_count())] + if isinstance(mode, (list, tuple)): + if not all(isinstance(i, int) for i in mode): + raise ValueError('Requires integer gpu id, given {}'.format(mode)) + return [torch.device(f'cuda:{gid}') for gid in mode if gid in range(torch.cuda.device_count())] + return None diff --git a/tests/auto/test_torch_auto_estimators.py b/tests/auto/test_torch_auto_estimators.py new file mode 100644 index 0000000000..0967d10823 --- /dev/null +++ b/tests/auto/test_torch_auto_estimators.py @@ -0,0 +1,135 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test auto estimators""" +from nose.tools import nottest +from gluoncv.auto.estimators import TorchImageClassificationEstimator +from gluoncv.auto.data.dataset import ImageClassificationDataset +import autogluon.core as ag +from autogluon.core.scheduler.resource import get_cpu_count, get_gpu_count +import mxnet as mx + +IMAGE_CLASS_DATASET, _, IMAGE_CLASS_TEST = ImageClassificationDataset.from_folders( + 'https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip') + +def test_image_classification_estimator(): + est = TorchImageClassificationEstimator({'model': {'model': 'resnet18'}, 'train': {'epochs': 1}, 'gpus': list(range(get_gpu_count()))}) + res = est.fit(IMAGE_CLASS_DATASET) + est.predict(IMAGE_CLASS_TEST) + est.predict_feature(IMAGE_CLASS_TEST) + _save_load_test(est, 'test.pkl') + +def test_image_classification_estimator_cpu(): + est = TorchImageClassificationEstimator({'model': {'model': 'resnet18'}, 'train': {'epochs': 1}, 'gpus': ()}) + res = est.fit(IMAGE_CLASS_DATASET) + est.predict(IMAGE_CLASS_TEST) + est.predict_feature(IMAGE_CLASS_TEST) + _save_load_test(est, 'test.pkl') + +@nottest +def test_config_combination(): + for _ in range(100): + test_config = build_config().rand + est = TorchImageClassificationEstimator(test_config) + res = est.fit(IMAGE_CLASS_DATASET) + est.predict(IMAGE_CLASS_TEST) + est.predict_feature(IMAGE_CLASS_TEST) + _save_load_test(est, 'test.pkl') + +def _save_load_test(est, filename): + est._cfg.unfreeze() + est._cfg.gpus = list(range(16)) # invalid cfg, check if load can restore succesfully + est.save(filename) + est2 = est.__class__.load(filename) + return est2 + +@ag.func( + pretrained=ag.space.Categorical(True, False), + global_pool_type=ag.space.Categorical('fast', 'avg', 'max', 'avgmax'), + sync_bn=ag.space.Categorical(True, False), + no_aug=ag.space.Categorical(True, False), + mixup=ag.space.Categorical(0.0, 0.5), + cutmix=ag.space.Categorical(0.0, 0.5), + model_ema=ag.space.Categorical(True, False), + 
model_ema_force_cpu=ag.space.Categorical(True, False), + save_images=ag.space.Categorical(True, False), + pin_mem=ag.space.Categorical(True, False), + use_multi_epochs_loader=ag.space.Categorical(True, False), + amp=ag.space.Categorical(True, False), + apex_amp=ag.space.Categorical(True, False), + native_amp=ag.space.Categorical(True, False), + prefetcher=ag.space.Categorical(True, False), + interpolation=ag.space.Categorical('random', 'bilinear', 'bicubic'), + batch_size=ag.space.Categorical(1,2,4,8,16,32), + hflip=ag.space.Categorical(0.0, 0.5, 1.0), + vflip=ag.space.Categorical(0.0, 0.5, 1.0), + train_interpolation=ag.space.Categorical('random', 'bilinear', 'bicubic'), + num_workers=ag.space.Categorical(1,2,4,8), + tta=ag.space.Categorical(0,1) +) +def build_config(pretrained, global_pool_type, sync_bn, no_aug, mixup, cutmix, model_ema, + model_ema_force_cpu, save_images, pin_mem, use_multi_epochs_loader, + amp, apex_amp, native_amp, prefetcher, interpolation, batch_size, hflip, vflip, + train_interpolation, num_workers, tta): + config = { + 'model': { + 'model': 'resnet50', + 'pretrained': pretrained, + 'global_pool_type': global_pool_type, + }, + 'dataset': { + 'interpolation': interpolation + }, + 'train': { + 'batch_size': batch_size, + 'sync_bn': sync_bn, + }, + 'augmentation': { + 'no_aug': no_aug, + 'mixup': mixup, + 'cutmix': cutmix, + 'hflip': hflip, + 'vflip': vflip, + 'train_interpolation': train_interpolation, + }, + 'model_ema': { + 'model_ema': model_ema, + 'model_ema_force_cpu': model_ema_force_cpu, + }, + 'misc': { + 'num_workers': num_workers, + 'save_images': save_images, + 'pin_mem': pin_mem, + 'tta': tta, + 'use_multi_epochs_loader': use_multi_epochs_loader, + 'amp': amp, + 'apex_amp': apex_amp, + 'native_amp': native_amp, + 'prefetcher': prefetcher, + } + } + if config['augmentation']['mixup'] or config['augmentation']['cutmix']: + config['train']['batch_size'] = 2 + config['train']['epochs'] = 1 + config['gpus'] = 
list(range(get_gpu_count())) + config['misc']['apex_amp'] = False # apex amp cause mem leak: https://github.com/NVIDIA/apex/issues/439 + return config + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/py3_auto.yml b/tests/py3_auto.yml index a08a88672e..b9ea3f06db 100644 --- a/tests/py3_auto.yml +++ b/tests/py3_auto.yml @@ -2,6 +2,7 @@ name: gluon_cv_py3_mxnet channels: - conda-forge - defaults + - pytorch dependencies: - python=3.6 - perl @@ -16,6 +17,8 @@ dependencies: - tqdm - pillow - pandas==1.3 + - pytorch==1.6.0 + - torchvision==0.7.0 - pip: - https://repo.mxnet.io/dist/python/cu100mkl/mxnet_cu100mkl-1.6.0b20191010-py2.py3-none-manylinux1_x86_64.whl - coverage-badge @@ -26,3 +29,4 @@ dependencies: - portalocker - autocfg>=0.0.6 - autogluon.core==0.2.0 + - timm==0.4.12