Torch classification (#1683)
* fit, init_network, init_trainer. resume_fit in progress

* todo

* dataset

* fix config

* train and validate

* remove split

* fix

* train and val

* save and load

* resume

* cpu fix

* save and load

* ctx

* fix

* predict

* warning

* predict feature

* parallel

* parallel and save load

* disable ema

* test

* rmse metric and removed aug_splits

* fix

* fix epoch

* lint fix

* docstring

* fix lint

* fix

* dependency

* conflict

* fix ci

* fix save load

* fix

* custom net, disable custom optimizer, fix test OOM

* fix lint

* fix

* fix

* fix
yinweisu authored Jul 22, 2021
1 parent 33c0081 commit afbf792
Showing 18 changed files with 1,320 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.sh
@@ -15,7 +15,7 @@ for f in $EFS/.mxnet/datasets/*; do
fi
done

python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme autogluon.core
python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme autogluon.core timm

export MXNET_CUDNN_AUTOTUNE_DEFAULT=0
cd docs
3 changes: 2 additions & 1 deletion .github/workflows/gpu_test.sh
@@ -15,8 +15,9 @@ export MPLBACKEND=Agg
export KMP_DUPLICATE_LIB_OK=TRUE

if [[ $TESTS_PATH == *"auto"* ]]; then
echo "Installing autogluon.core for auto module"
echo "Installing autogluon.core and timm for auto module"
pip3 install autogluon.core==0.2.0
pip3 install timm==0.4.12
fi

nosetests --with-timer --timer-ok 5 --timer-warning 20 -x --with-coverage --cover-package $COVER_PACKAGE -v $TESTS_PATH
48 changes: 48 additions & 0 deletions gluoncv/auto/data/dataset.py
@@ -22,6 +22,12 @@
except ImportError:
MXDataset = object
mx = None
try:
import torch
TorchDataset = torch.utils.data.Dataset
except ImportError:
TorchDataset = object
torch = None

logger = logging.getLogger()

@@ -156,6 +162,12 @@ def to_mxnet(self):
df = df.reset_index(drop=True)
return _MXImageClassificationDataset(df)

def to_torch(self):
"""Return a pytorch based iterator that returns ndarray and labels"""
df = self.rename(columns={self.IMG_COL: "image", self.LABEL_COL: "label"}, errors='ignore')
df = df.reset_index(drop=True)
return _TorchImageClassificationDataset(df)

@classmethod
def from_csv(cls, csv_file, root=None, image_column='image', label_column='label', no_class=False):
r"""Create from csv file.
@@ -385,6 +397,42 @@ def __getitem__(self, idx):
label = self._dataset['label'][idx]
return img, label

class _TorchImageClassificationDataset(TorchDataset):
"""Internal wrapper read entries in pd.DataFrame as images/labels.
Parameters
----------
dataset : ImageClassificationDataset
DataFrame as ImageClassificationDataset.
"""
def __init__(self, dataset):
if torch is None:
raise RuntimeError('Unable to import pytorch which is required.')
assert isinstance(dataset, ImageClassificationDataset)
assert 'image' in dataset.columns
self._has_label = 'label' in dataset.columns
self._dataset = dataset
self.classes = self._dataset.classes
self._imread = Image.open
self.transform = None

def __len__(self):
return self._dataset.shape[0]

def __getitem__(self, idx):
im_path = self._dataset['image'][idx]
img = self._imread(im_path).convert('RGB')
label = None
# pylint: disable=not-callable
if self.transform is not None:
img = self.transform(img)
if self._has_label:
label = self._dataset['label'][idx]
else:
label = torch.tensor(-1, dtype=torch.long)
return img, label


class ObjectDetectionDataset(pd.DataFrame):
"""ObjectDetection dataset as DataFrame.
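The new to_torch() hook mirrors the existing to_mxnet() path: it renames the frame's columns to 'image'/'label', resets the index, and wraps the result in a torch.utils.data.Dataset whose __getitem__ decodes the image with PIL, applies an optional transform, and returns a -1 label tensor when the split is unlabeled. Below is a minimal consumption sketch, assuming torchvision is available; the train.csv path and the transform choices are placeholders, not part of this commit.

import torch
from torchvision import transforms
from gluoncv.auto.data.dataset import ImageClassificationDataset

# Hypothetical CSV with 'image' and 'label' columns (see the from_csv defaults above).
train = ImageClassificationDataset.from_csv('train.csv')
torch_ds = train.to_torch()

# The wrapper exposes a mutable `transform` attribute; set it before batching so
# every decoded PIL image becomes a tensor of identical shape.
torch_ds.transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

loader = torch.utils.data.DataLoader(torch_ds, batch_size=32, shuffle=True, num_workers=4)
for images, labels in loader:
    # images: float tensor of shape (batch, 3, 224, 224); labels: long tensor,
    # filled with -1 when the source frame has no 'label' column.
    break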
2 changes: 2 additions & 0 deletions gluoncv/auto/estimators/__init__.py
@@ -1,7 +1,9 @@
"""Estimator implementations"""
# FIXME: for quick test purpose only
from .image_classification import ImageClassificationEstimator
from .ssd import SSDEstimator
from .yolo import YOLOv3Estimator
from .faster_rcnn import FasterRCNNEstimator
# from .mask_rcnn import MaskRCNNEstimator
from .center_net import CenterNetEstimator
from .torch_image_classification import TorchImageClassificationEstimator
31 changes: 31 additions & 0 deletions gluoncv/auto/estimators/base_estimator.py
@@ -259,6 +259,22 @@ def _validate_gpus(self, gpu_ids):
pass
return valid_gpus

#FIXME: better design than a duplicate function?
def _torch_validate_gpus(self, gpu_ids):
"""validate if requested gpus are actually available"""
valid_gpus = []
try:
import torch
for gid in gpu_ids:
try:
_ = torch.zeros(1, device=f'cuda:{gid}')
valid_gpus.append(str(gid))
except:
pass
except ImportError:
pass
return valid_gpus

def reset_ctx(self, ctx=None):
"""Reset model context.
@@ -289,6 +305,21 @@ def reset_ctx(self, ctx=None):
done = True
except ImportError:
pass
try:
import torch
if isinstance(self.net, (torch.nn.Module, torch.nn.DataParallel)):
for c in ctx_list:
assert isinstance(c, torch.device)
if hasattr(self.net, 'reset_ctx'):
self.net.reset_ctx(ctx_list)
else:
if isinstance(self.net, torch.nn.DataParallel):
self.net = torch.nn.DataParallel(self.net.module, device_ids=[ctx.index for ctx in ctx_list])
self.net.to(self.ctx[0])
self.ctx = ctx_list
done = True
except ImportError:
pass
if not done:
raise RuntimeError("Unable to reset_ctx, no `mxnet` and `pytorch`.")

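_torch_validate_gpus probes each requested device by allocating a one-element tensor on it, and the torch branch of reset_ctx rebuilds the DataParallel wrapper around net.module when the device list changes. The same probing idea as a standalone sketch; probe_gpus is a hypothetical helper, not part of this commit.

import torch

def probe_gpus(gpu_ids):
    """Return the subset of requested GPU ids that can actually allocate memory."""
    valid = []
    for gid in gpu_ids:
        try:
            # Allocating a tiny tensor forces CUDA context creation on that device.
            _ = torch.zeros(1, device=f'cuda:{gid}')
            valid.append(gid)
        except Exception:  # out-of-range id, missing driver, CPU-only build, ...
            pass
    return valid

ctx = [torch.device(f'cuda:{gid}') for gid in probe_gpus((0, 1))] or [torch.device('cpu')]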
2 changes: 2 additions & 0 deletions gluoncv/auto/estimators/torch_image_classification/__init__.py
@@ -0,0 +1,2 @@
"""Torch image classification estimator"""
from .torch_image_classification import TorchImageClassificationEstimator
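The package __init__ re-exports TorchImageClassificationEstimator so it is importable alongside the existing auto estimators. The estimator implementation itself is not shown in this excerpt; assuming it follows the same fit/predict surface as the other gluoncv.auto estimators, a usage sketch might look like the following (file paths and config values are illustrative placeholders).

from gluoncv.auto.data.dataset import ImageClassificationDataset
from gluoncv.auto.estimators import TorchImageClassificationEstimator

train = ImageClassificationDataset.from_csv('train.csv')
val = ImageClassificationDataset.from_csv('val.csv')

# Config keys mirror the sections in default.py below; 'model' picks a timm architecture.
estimator = TorchImageClassificationEstimator(
    {'model': {'model': 'resnet50', 'pretrained': True},
     'train': {'epochs': 10, 'batch_size': 32},
     'gpus': (0,)})
estimator.fit(train, val)
predictions = estimator.predict(val)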
113 changes: 113 additions & 0 deletions gluoncv/auto/estimators/torch_image_classification/default.py
@@ -0,0 +1,113 @@
"""Default configs for torch image classification"""
# pylint: disable=bad-whitespace,missing-class-docstring
from typing import Union, Tuple
from autocfg import dataclass, field

@dataclass
class ModelCfg:
model: str = 'resnet101'
pretrained: bool = False
global_pool_type: Union[str, None] = None # Global pool type, one of (fast, avg, max, avgmax). Model default if None

@dataclass
class DatasetCfg:
img_size: Union[int, None] = None # Image patch size (default: None => model default)
input_size: Union[Tuple[int, int, int], None] = None # Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty
crop_pct: Union[float, None] = None # Input image center crop percent (for validation only)
mean: Union[Tuple, None] = None # Override mean pixel value of dataset
std : Union[Tuple, None] = None # Override std deviation of dataset
interpolation: str = '' # Image resize interpolation type (overrides model)
validation_batch_size_multiplier: int = 1 # ratio of validation batch size to training batch size (default: 1)

@dataclass
class OptimizerCfg:
opt: str = 'sgd'
opt_eps: Union[float, None] = None # Optimizer Epsilon (default: None, use opt default)
opt_betas: Union[Tuple, None] = None # Optimizer Betas (default: None, use opt default)
momentum: float = 0.9
weight_decay: float = 0.0001
clip_grad: Union[float, None] = None # Clip gradient norm (default: None, no clipping)
clip_mode: str = 'norm' # Gradient clipping mode. One of ("norm", "value", "agc")

@dataclass
class TrainCfg:
batch_size: int = 32
sched: str = 'step' # LR scheduler
lr: float = 0.01
lr_noise: Union[Tuple, None] = None # learning rate noise on/off epoch percentages
lr_noise_pct: float = 0.67 # learning rate noise limit percent
lr_noise_std: float = 1.0 # learning rate noise std-dev
lr_cycle_mul: float = 1.0 # learning rate cycle len multiplier
lr_cycle_limit: int = 1 # learning rate cycle limit
warmup_lr: float = 0.0001
min_lr: float = 1e-5
epochs: int = 200
start_epoch: int = 0 # manual epoch number (useful on restarts)
decay_epochs: int = 30 # epoch interval to decay LR
warmup_epochs: int = 3 # epochs to warmup LR, if scheduler supports
cooldown_epochs: int = 10 # epochs to cooldown LR at min_lr, after cyclic schedule ends
patience_epochs: int = 10 # patience epochs for Plateau LR scheduler
decay_rate: float = 0.1
bn_momentum: Union[float, None] = None # BatchNorm momentum override
bn_eps: Union[float, None] = None # BatchNorm epsilon override
sync_bn: bool = False # Enable NVIDIA Apex or Torch synchronized BatchNorm
early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled
early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics
# the baseline value for metric, training won't stop if not reaching baseline
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly

@dataclass
class AugmentationCfg:
no_aug: bool = False # Disable all training augmentation, override other train aug args
scale: Tuple[float, float] = (0.08, 1.0) # Random resize scale
ratio: Tuple[float, float] = (3./4., 4./3.) # Random resize aspect ratio (default: 0.75 1.33)
hflip: float = 0.5 # Horizontal flip training aug probability
vflip: float = 0.0 # Vertical flip training aug probability
color_jitter: float = 0.4
auto_augment: Union[str, None] = None # Use AutoAugment policy. "v0" or "original"
mixup: float = 0.0 # mixup alpha, mixup enabled if > 0
cutmix: float = 0.0 # cutmix alpha, cutmix enabled if > 0
cutmix_minmax: Union[Tuple, None] = None # cutmix min/max ratio, overrides alpha and enables cutmix if set
mixup_prob: float = 1.0 # Probability of performing mixup or cutmix when either/both is enabled
mixup_switch_prob: float = 0.5 # Probability of switching to cutmix when both mixup and cutmix enabled
mixup_mode: str = 'batch' # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
mixup_off_epoch: int = 0 # Turn off mixup after this epoch, disabled if 0
smoothing: float = 0.1 # Label smoothing
train_interpolation: str = 'random' # Training interpolation (random, bilinear, bicubic)
drop: float = 0.0 # Dropout rate
drop_path: Union[float, None] = None # Drop path rate
drop_block: Union[float, None] = None # Drop block rate

@dataclass
class ModelEMACfg:
model_ema: bool = True # Enable tracking moving average of model weights
model_ema_force_cpu: bool = False # Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation
model_ema_decay: float = 0.9998 # decay factor for model weights moving average

@dataclass
class MiscCfg:
seed: int = 42
log_interval: int = 50 # how many batches to wait before logging training status
num_workers: int = 4 # how many training processes to use
save_images: bool = False # save images of input batches every log interval for debugging
amp: bool = False # use NVIDIA Apex AMP or Native AMP for mixed precision training
apex_amp: bool = False # Use NVIDIA Apex AMP mixed precision
native_amp: bool = False # Use Native Torch AMP mixed precision
pin_mem: bool = False # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU
prefetcher: bool = False # use fast prefetcher
eval_metric: str = 'top1' # Best metric (default: "top1")
tta: int = 0 # Test/inference time augmentation (oversampling) factor. 0=None
use_multi_epochs_loader: bool = False # use the multi-epochs-loader to save time at the beginning of every epoch
torchscript: bool = False # keep false, convert model torchscript for inference

@dataclass
class TorchImageClassificationCfg:
model : ModelCfg = field(default_factory=ModelCfg)
dataset: DatasetCfg = field(default_factory=DatasetCfg)
optimizer: OptimizerCfg = field(default_factory=OptimizerCfg)
train: TrainCfg = field(default_factory=TrainCfg)
augmentation: AugmentationCfg = field(default_factory=AugmentationCfg)
model_ema: ModelEMACfg = field(default_factory=ModelEMACfg)
misc: MiscCfg = field(default_factory=MiscCfg)
gpus : Union[Tuple, list] = (0, ) # gpu individual ids, not necessarily consecutive
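autocfg dataclasses nest like ordinary dataclasses, so the defaults above can be instantiated and overridden attribute by attribute. A small sketch, assuming standard dataclass-style attribute access; the specific overrides are illustrative only.

from gluoncv.auto.estimators.torch_image_classification.default import TorchImageClassificationCfg

cfg = TorchImageClassificationCfg()

# Override a few leaf fields; every other field keeps the default declared above.
cfg.model.model = 'resnet50'
cfg.model.pretrained = True
cfg.train.epochs = 20
cfg.train.lr = 0.005
cfg.misc.num_workers = 8
cfg.gpus = (0,)

print(cfg.optimizer.opt, cfg.train.sched)  # -> sgd step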