diff --git a/gluoncv/auto/estimators/center_net/center_net.py b/gluoncv/auto/estimators/center_net/center_net.py index 8ff23cb1d4..619e2039bc 100644 --- a/gluoncv/auto/estimators/center_net/center_net.py +++ b/gluoncv/auto/estimators/center_net/center_net.py @@ -130,13 +130,16 @@ def _train_loop(self, train_data, val_data, train_eval_data): self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) for self.epoch in range(max(self._cfg.train.start_epoch, self.epoch), self._cfg.train.epochs): + epoch = self.epoch + if self._best_map >= 1.0: + self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch) + break wh_metric.reset() center_reg_metric.reset() heatmap_loss_metric.reset() tic = time.time() btic = time.time() self.net.hybridize() - epoch = self.epoch for i, batch in enumerate(train_data): split_data = [ @@ -225,9 +228,12 @@ def _evaluate(self, val_data): self._cfg.valid.batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=self._cfg.valid.num_workers) for batch in val_data: - data = gluon.utils.split_and_load(batch[0], ctx_list=self.ctx, batch_axis=0, + val_ctx = self.ctx + if batch[0].shape[0] < len(val_ctx): + val_ctx = val_ctx[:batch[0].shape[0]] + data = gluon.utils.split_and_load(batch[0], ctx_list=val_ctx, batch_axis=0, even_split=False) - label = gluon.utils.split_and_load(batch[1], ctx_list=self.ctx, batch_axis=0, + label = gluon.utils.split_and_load(batch[1], ctx_list=val_ctx, batch_axis=0, even_split=False) det_bboxes = [] det_ids = [] diff --git a/gluoncv/auto/estimators/faster_rcnn/default.py b/gluoncv/auto/estimators/faster_rcnn/default.py index 7dbb4eb09b..bd20aa90cc 100644 --- a/gluoncv/auto/estimators/faster_rcnn/default.py +++ b/gluoncv/auto/estimators/faster_rcnn/default.py @@ -70,9 +70,9 @@ class FasterRCNN: num_box_head_dense_filters : int = 1024 # Input image short side size. - image_short : int = 800 + image_short : int = 600 # Maximum size of input image long side. - image_max_size : int = 1333 + image_max_size : int = 1000 # Whether to enable custom model. # custom_model = True @@ -143,7 +143,7 @@ class TrainCfg: # Misc # ---- # log interval in terms of iterations - log_interval : int = 100 + log_interval : int = 10 # Random seed to be fixed. seed : int = 233 # Whether to enable verbose logging @@ -202,4 +202,4 @@ class FasterRCNNCfg: # dist_async are available. kv_store : str = 'nccl' # Whether to disable hybridize the model. Memory usage and speed will decrese. - disable_hybridization : bool = False + disable_hybridization : bool = True diff --git a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py index 8173516618..1640bfada0 100644 --- a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py +++ b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py @@ -1,7 +1,8 @@ """Faster RCNN Estimator.""" -# pylint: disable=logging-not-lazy,abstract-method +# pylint: disable=logging-not-lazy,abstract-method,unused-variable import os import time +import warnings import pandas as pd import numpy as np @@ -54,6 +55,7 @@ class FasterRCNNEstimator(BaseEstimator): """ def __init__(self, config, logger=None, reporter=None): super(FasterRCNNEstimator, self).__init__(config, logger, reporter) + self.batch_size = self._cfg.train.batch_size def _fit(self, train_data, val_data): """Fit Faster R-CNN model.""" @@ -62,8 +64,10 @@ def _fit(self, train_data, val_data): self._time_elapsed = 0 if max(self._cfg.train.start_epoch, self.epoch) >= self._cfg.train.epochs: return {'time', self._time_elapsed} - self.net.collect_params().setattr('grad_req', 'null') - self.net.collect_train_params().setattr('grad_req', 'write') + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net.collect_params().setattr('grad_req', 'null') + self.net.collect_train_params().setattr('grad_req', 'write') self._init_trainer() return self._resume_fit(train_data, val_data) @@ -78,11 +82,6 @@ def _resume_fit(self, train_data, val_data): val_dataset = val_data.to_mxnet() # dataloader - self.batch_size = self._cfg.train.batch_size // min(1, self.num_gpus) \ - if self._cfg.horovod else self._cfg.train.batch_size - if not self._cfg.horovod: - if self._cfg.train.batch_size == 1 and self.num_gpus > 1: - self.batch_size *= self.num_gpus train_loader, val_loader, train_eval_loader = _get_dataloader( self.net, train_dataset, val_dataset, FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform, self.batch_size, len(self.ctx), self._cfg) @@ -121,6 +120,9 @@ def _train_loop(self, train_data, val_data, train_eval_data): for self.epoch in range(max(self._cfg.train.start_epoch, self.epoch), self._cfg.train.epochs): epoch = self.epoch + if self._best_map >= 1.0: + self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch) + break btic = time.time() rcnn_task = ForwardBackwardTask(self.net, self.trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0, @@ -189,7 +191,7 @@ def _train_loop(self, train_data, val_data, train_eval_data): metrics + metrics2]) self._logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format( epoch, i, - self._cfg.train.log_interval * self._cfg.train.batch_size / ( + self._cfg.train.log_interval * self.batch_size / ( time.time() - btic), msg)) btic = time.time() @@ -250,7 +252,9 @@ def _evaluate(self, val_data): gt_difficults = [] for x, y, im_scale in zip(*batch): # get prediction results - ids, scores, bboxes = self.net(x) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ids, scores, bboxes = self.net(x) det_ids.append(ids) det_scores.append(scores) # clip to image size @@ -318,6 +322,7 @@ def _init_network(self): else: ctx = [mx.gpu(int(i)) for i in self._cfg.gpus] self.ctx = ctx if ctx else [mx.cpu()] + # network kwargs = {} module_list = [] @@ -330,15 +335,29 @@ def _init_network(self): self.num_gpus = hvd.size() if self._cfg.horovod else len(self.ctx) + # adjust batch size + self.batch_size = self._cfg.train.batch_size // min(1, self.num_gpus) \ + if self._cfg.horovod else self._cfg.train.batch_size + if not self._cfg.horovod: + if self._cfg.train.batch_size == 1 and self.num_gpus > 1: + self.batch_size *= self.num_gpus + elif self._cfg.train.batch_size < self.num_gpus: + self.batch_size = self.num_gpus + if self.batch_size % self.num_gpus != 0: + raise ValueError(f"batch_size {self._cfg.train.batch_size} must be divisible by # gpu {self.num_gpus}") + if self._cfg.faster_rcnn.transfer is not None: assert isinstance(self._cfg.faster_rcnn.transfer, str) self._logger.info( f'Using transfer learning from {self._cfg.faster_rcnn.transfer}, ' + 'the other network parameters are ignored.') self._cfg.faster_rcnn.use_fpn = 'fpn' in self._cfg.faster_rcnn.transfer - self.net = get_model(self._cfg.faster_rcnn.transfer, pretrained=True, - per_device_batch_size=self._cfg.train.batch_size // self.num_gpus, - **kwargs) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net = get_model(self._cfg.faster_rcnn.transfer, pretrained=True, + per_device_batch_size=self.batch_size // self.num_gpus, + **kwargs) + self.net.sampler._max_num_gt = self._cfg.faster_rcnn.max_num_gt self.net.reset_class(self.classes, reuse_weights=[cname for cname in self.classes if cname in self.net.classes]) else: @@ -358,43 +377,45 @@ def _init_network(self): norm_kwargs = None sym_norm_layer = None sym_norm_kwargs = None - self.net = get_model('custom_faster_rcnn_fpn', classes=self.classes, transfer=None, - dataset=self._cfg.dataset, - pretrained_base=self._cfg.train.pretrained_base, - base_network_name=self._cfg.faster_rcnn.base_network, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, - sym_norm_layer=sym_norm_layer, sym_norm_kwargs=sym_norm_kwargs, - num_fpn_filters=self._cfg.faster_rcnn.num_fpn_filters, - num_box_head_conv=self._cfg.faster_rcnn.num_box_head_conv, - num_box_head_conv_filters= - self._cfg.faster_rcnn.num_box_head_conv_filters, - num_box_head_dense_filters= - self._cfg.faster_rcnn.num_box_head_dense_filters, - short=self._cfg.faster_rcnn.image_short, - max_size=self._cfg.faster_rcnn.image_max_size, - min_stage=2, max_stage=6, - nms_thresh=self._cfg.faster_rcnn.nms_thresh, - nms_topk=self._cfg.faster_rcnn.nms_topk, - roi_mode=self._cfg.faster_rcnn.roi_mode, - roi_size=self._cfg.faster_rcnn.roi_size, - strides=self._cfg.faster_rcnn.strides, - clip=self._cfg.faster_rcnn.clip, - rpn_channel=self._cfg.faster_rcnn.rpn_channel, - base_size=self._cfg.faster_rcnn.anchor_base_size, - scales=self._cfg.faster_rcnn.anchor_scales, - ratios=self._cfg.faster_rcnn.anchor_aspect_ratio, - alloc_size=self._cfg.faster_rcnn.anchor_alloc_size, - rpn_nms_thresh=self._cfg.faster_rcnn.rpn_nms_thresh, - rpn_train_pre_nms=self._cfg.train.rpn_train_pre_nms, - rpn_train_post_nms=self._cfg.train.rpn_train_post_nms, - rpn_test_pre_nms=self._cfg.valid.rpn_test_pre_nms, - rpn_test_post_nms=self._cfg.valid.rpn_test_post_nms, - rpn_min_size=self._cfg.train.rpn_min_size, - per_device_batch_size=self._cfg.train.batch_size // self.num_gpus, - num_sample=self._cfg.train.rcnn_num_samples, - pos_iou_thresh=self._cfg.train.rcnn_pos_iou_thresh, - pos_ratio=self._cfg.train.rcnn_pos_ratio, - max_num_gt=self._cfg.faster_rcnn.max_num_gt) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net = get_model('custom_faster_rcnn_fpn', classes=self.classes, transfer=None, + dataset=self._cfg.dataset, + pretrained_base=self._cfg.train.pretrained_base, + base_network_name=self._cfg.faster_rcnn.base_network, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, + sym_norm_layer=sym_norm_layer, sym_norm_kwargs=sym_norm_kwargs, + num_fpn_filters=self._cfg.faster_rcnn.num_fpn_filters, + num_box_head_conv=self._cfg.faster_rcnn.num_box_head_conv, + num_box_head_conv_filters= + self._cfg.faster_rcnn.num_box_head_conv_filters, + num_box_head_dense_filters= + self._cfg.faster_rcnn.num_box_head_dense_filters, + short=self._cfg.faster_rcnn.image_short, + max_size=self._cfg.faster_rcnn.image_max_size, + min_stage=2, max_stage=6, + nms_thresh=self._cfg.faster_rcnn.nms_thresh, + nms_topk=self._cfg.faster_rcnn.nms_topk, + roi_mode=self._cfg.faster_rcnn.roi_mode, + roi_size=self._cfg.faster_rcnn.roi_size, + strides=self._cfg.faster_rcnn.strides, + clip=self._cfg.faster_rcnn.clip, + rpn_channel=self._cfg.faster_rcnn.rpn_channel, + base_size=self._cfg.faster_rcnn.anchor_base_size, + scales=self._cfg.faster_rcnn.anchor_scales, + ratios=self._cfg.faster_rcnn.anchor_aspect_ratio, + alloc_size=self._cfg.faster_rcnn.anchor_alloc_size, + rpn_nms_thresh=self._cfg.faster_rcnn.rpn_nms_thresh, + rpn_train_pre_nms=self._cfg.train.rpn_train_pre_nms, + rpn_train_post_nms=self._cfg.train.rpn_train_post_nms, + rpn_test_pre_nms=self._cfg.valid.rpn_test_pre_nms, + rpn_test_post_nms=self._cfg.valid.rpn_test_post_nms, + rpn_min_size=self._cfg.train.rpn_min_size, + per_device_batch_size=self.batch_size // self.num_gpus, + num_sample=self._cfg.train.rcnn_num_samples, + pos_iou_thresh=self._cfg.train.rcnn_pos_iou_thresh, + pos_ratio=self._cfg.train.rcnn_pos_ratio, + max_num_gt=self._cfg.faster_rcnn.max_num_gt) if self._cfg.resume.strip(): self.net.load_parameters(self._cfg.resume.strip()) diff --git a/gluoncv/auto/estimators/image_classification/image_classification.py b/gluoncv/auto/estimators/image_classification/image_classification.py index 8ef885631c..6af6b3b7f9 100644 --- a/gluoncv/auto/estimators/image_classification/image_classification.py +++ b/gluoncv/auto/estimators/image_classification/image_classification.py @@ -124,6 +124,10 @@ def _train_loop(self, train_data, val_data): self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) for self.epoch in range(max(self._cfg.train.start_epoch, self.epoch), self._cfg.train.epochs): epoch = self.epoch + if self._best_acc >= 1.0: + self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch)) + break + mx.nd.waitall() tic = time.time() btic = time.time() if self._cfg.train.use_rec: diff --git a/gluoncv/auto/estimators/image_classification/utils.py b/gluoncv/auto/estimators/image_classification/utils.py index 9c303f75c7..958aa11aa6 100644 --- a/gluoncv/auto/estimators/image_classification/utils.py +++ b/gluoncv/auto/estimators/image_classification/utils.py @@ -71,6 +71,9 @@ def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num return train_data, val_data, rec_batch_fn def loader_batch_fn(batch, ctx): + if batch[0].shape[0] < len(ctx): + # if # sample is less than # ctx, reduce the # ctx + ctx = ctx[:batch[0].shape[0]] data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False) label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False) return data, label diff --git a/gluoncv/auto/estimators/ssd/ssd.py b/gluoncv/auto/estimators/ssd/ssd.py index be6488ea19..677da55a25 100644 --- a/gluoncv/auto/estimators/ssd/ssd.py +++ b/gluoncv/auto/estimators/ssd/ssd.py @@ -123,6 +123,9 @@ def _train_loop(self, train_data, val_data, train_eval_data): self.net.collect_params().reset_ctx(self.ctx) for self.epoch in range(max(self._cfg.train.start_epoch, self.epoch), self._cfg.train.epochs): epoch = self.epoch + if self._best_map >= 1.0: + self._logger.info('[Epoch {}] Early stopping as mAP is reaching 1.0'.format(epoch)) + break while lr_steps and epoch >= lr_steps[0]: new_lr = self.trainer.learning_rate * lr_decay lr_steps.pop(0) @@ -222,8 +225,11 @@ def _evaluate(self, val_data): self.net.collect_params().reset_ctx(self.ctx) self.net.hybridize(static_alloc=True, static_shape=True) for batch in val_data: - data = gluon.utils.split_and_load(batch[0], ctx_list=self.ctx, batch_axis=0, even_split=False) - label = gluon.utils.split_and_load(batch[1], ctx_list=self.ctx, batch_axis=0, even_split=False) + val_ctx = self.ctx + if batch[0].shape[0] < len(val_ctx): + val_ctx = val_ctx[:batch[0].shape[0]] + data = gluon.utils.split_and_load(batch[0], ctx_list=val_ctx, batch_axis=0, even_split=False) + label = gluon.utils.split_and_load(batch[1], ctx_list=val_ctx, batch_axis=0, even_split=False) det_bboxes = [] det_ids = [] det_scores = [] @@ -300,9 +306,12 @@ def _init_network(self): self._logger.info( f'Using transfer learning from {self._cfg.ssd.transfer}, the other network parameters are ignored.') if self._cfg.ssd.syncbn and len(self.ctx) > 1: - self.net = get_model(self._cfg.ssd.transfer, pretrained=True, norm_layer=gluon.contrib.nn.SyncBatchNorm, - norm_kwargs={'num_devices': len(self.ctx)}) - self.async_net = get_model(self._cfg.ssd.transfer, pretrained=True) # used by cpu worker + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net = get_model(self._cfg.ssd.transfer, pretrained=True, + norm_layer=gluon.contrib.nn.SyncBatchNorm, + norm_kwargs={'num_devices': len(self.ctx)}) + self.async_net = get_model(self._cfg.ssd.transfer, pretrained=True) # used by cpu worker self.net.reset_class(self.classes, reuse_weights=[cname for cname in self.classes if cname in self.net.classes]) else: @@ -313,38 +322,42 @@ def _init_network(self): # elif self._cfg.ssd.custom_model: else: if self._cfg.ssd.syncbn and len(self.ctx) > 1: - self.net = custom_ssd(base_network_name=self._cfg.ssd.base_network, - base_size=self._cfg.ssd.data_shape, - filters=self._cfg.ssd.filters, - sizes=self._cfg.ssd.sizes, - ratios=self._cfg.ssd.ratios, - steps=self._cfg.ssd.steps, - classes=self.classes, - dataset='auto', - pretrained_base=True, - norm_layer=gluon.contrib.nn.SyncBatchNorm, - norm_kwargs={'num_devices': len(self.ctx)}) - self.async_net = custom_ssd(base_network_name=self._cfg.ssd.base_network, - base_size=self._cfg.ssd.data_shape, - filters=self._cfg.ssd.filters, - sizes=self._cfg.ssd.sizes, - ratios=self._cfg.ssd.ratios, - steps=self._cfg.ssd.steps, - classes=self.classes, - dataset='auto', - pretrained_base=False) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net = custom_ssd(base_network_name=self._cfg.ssd.base_network, + base_size=self._cfg.ssd.data_shape, + filters=self._cfg.ssd.filters, + sizes=self._cfg.ssd.sizes, + ratios=self._cfg.ssd.ratios, + steps=self._cfg.ssd.steps, + classes=self.classes, + dataset='auto', + pretrained_base=True, + norm_layer=gluon.contrib.nn.SyncBatchNorm, + norm_kwargs={'num_devices': len(self.ctx)}) + self.async_net = custom_ssd(base_network_name=self._cfg.ssd.base_network, + base_size=self._cfg.ssd.data_shape, + filters=self._cfg.ssd.filters, + sizes=self._cfg.ssd.sizes, + ratios=self._cfg.ssd.ratios, + steps=self._cfg.ssd.steps, + classes=self.classes, + dataset='auto', + pretrained_base=False) else: - self.net = custom_ssd(base_network_name=self._cfg.ssd.base_network, - base_size=self._cfg.ssd.data_shape, - filters=self._cfg.ssd.filters, - sizes=self._cfg.ssd.sizes, - ratios=self._cfg.ssd.ratios, - steps=self._cfg.ssd.steps, - classes=self.classes, - dataset=self._cfg.dataset, - pretrained_base=True, - norm_layer=gluon.nn.BatchNorm) - self.async_net = self.net + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net = custom_ssd(base_network_name=self._cfg.ssd.base_network, + base_size=self._cfg.ssd.data_shape, + filters=self._cfg.ssd.filters, + sizes=self._cfg.ssd.sizes, + ratios=self._cfg.ssd.ratios, + steps=self._cfg.ssd.steps, + classes=self.classes, + dataset=self._cfg.dataset, + pretrained_base=True, + norm_layer=gluon.nn.BatchNorm) + self.async_net = self.net with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/gluoncv/auto/estimators/yolo/yolo.py b/gluoncv/auto/estimators/yolo/yolo.py index caa29bfd38..a986511951 100644 --- a/gluoncv/auto/estimators/yolo/yolo.py +++ b/gluoncv/auto/estimators/yolo/yolo.py @@ -114,6 +114,9 @@ def _train_loop(self, train_data, val_data, train_eval_data): self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) for self.epoch in range(max(self._cfg.train.start_epoch, self.epoch), self._cfg.train.epochs): epoch = self.epoch + if self._best_map >= 1.0: + self._logger.info('[Epoch {}] Early stopping as mAP is reaching 1.0'.format(epoch)) + break tic = time.time() btic = time.time() if self._cfg.train.mixup: @@ -222,8 +225,11 @@ def _evaluate(self, val_data): mx.nd.waitall() self.net.hybridize() for batch in val_data: - data = gluon.utils.split_and_load(batch[0], ctx_list=self.ctx, batch_axis=0, even_split=False) - label = gluon.utils.split_and_load(batch[1], ctx_list=self.ctx, batch_axis=0, even_split=False) + val_ctx = self.ctx + if batch[0].shape[0] < len(val_ctx): + val_ctx = val_ctx[:batch[0].shape[0]] + data = gluon.utils.split_and_load(batch[0], ctx_list=val_ctx, batch_axis=0, even_split=False) + label = gluon.utils.split_and_load(batch[1], ctx_list=val_ctx, batch_axis=0, even_split=False) det_bboxes = [] det_ids = [] det_scores = [] diff --git a/gluoncv/auto/tasks/object_detection.py b/gluoncv/auto/tasks/object_detection.py index 329227d6d4..6bd2b9dcfc 100644 --- a/gluoncv/auto/tasks/object_detection.py +++ b/gluoncv/auto/tasks/object_detection.py @@ -81,6 +81,11 @@ def _train_object_detection(args, reporter): tic = time.time() try: estimator_cls = args.pop('estimator', None) + if estimator_cls == FasterRCNNEstimator: + # safe guard if too many GT in dataset + train_dataset = train_data.to_mxnet() + max_gt_count = max([y[1].shape[0] for y in train_dataset]) + 20 + args['faster_rcnn']['max_num_gt'] = max_gt_count estimator = estimator_cls(args, reporter=reporter) # training result = estimator.fit(train_data=train_data, val_data=val_data) diff --git a/gluoncv/auto/tasks/utils.py b/gluoncv/auto/tasks/utils.py index 45e5564a5d..02b38c3edb 100644 --- a/gluoncv/auto/tasks/utils.py +++ b/gluoncv/auto/tasks/utils.py @@ -326,7 +326,8 @@ def config_to_nested_v0(config): 'anchor_base_size', 'anchor_aspect_ratio', 'anchor_scales', 'anchor_alloc_size', 'rpn_channel', 'rpn_nms_thresh', 'max_num_gt', 'norm_layer', 'use_fpn', 'num_fpn_filters', 'num_box_head_conv', 'num_box_head_conv_filters', 'num_box_head_dense_filters', - 'image_short', 'image_max_size', 'custom_model', 'amp', 'static_alloc'], + 'image_short', 'image_max_size', 'custom_model', 'amp', 'static_alloc', + 'disable_hybridization'], 'train': ['pretrained_base', 'batch_size', 'start_epoch', 'epochs', 'lr', 'lr_decay', 'lr_decay_epoch', 'lr_mode', 'lr_warmup', 'lr_warmup_factor', 'momentum', 'wd', 'rpn_train_pre_nms', 'rpn_train_post_nms', 'rpn_smoothl1_rho', 'rpn_min_size',