diff --git a/gluoncv/auto/data/dataset.py b/gluoncv/auto/data/dataset.py index ee69275cda..5c95862219 100644 --- a/gluoncv/auto/data/dataset.py +++ b/gluoncv/auto/data/dataset.py @@ -56,7 +56,7 @@ class ImageClassificationDataset(pd.DataFrame): """ # preserved properties that will be copied to a new instance - _metadata = ['classes', 'to_mxnet', 'show_images', 'random_split', 'IMG_COL', 'LABEL_COL'] + _metadata = ['classes', 'IMG_COL', 'LABEL_COL'] def __init__(self, data, classes=None, image_column='image', label_column='label', **kwargs): root = kwargs.pop('root', None) @@ -131,27 +131,29 @@ def show_images(self, indices=None, nsample=16, ncol=4, shuffle=True, resize=224 fontsize : int, optional The fontsize for the title """ + df = self.reset_index(drop=True) if indices is None: if not shuffle: indices = range(nsample) else: - indices = list(range(len(self))) + indices = list(range(len(df))) np.random.shuffle(indices) indices = indices[:min(nsample, len(indices))] - images = [cv2.cvtColor(cv2.resize(cv2.imread(self.at[idx, self.IMG_COL]), (resize, resize), \ - interpolation=cv2.INTER_AREA), cv2.COLOR_BGR2RGB) for idx in indices if idx < len(self)] + images = [cv2.cvtColor(cv2.resize(cv2.imread(df.at[idx, df.IMG_COL]), (resize, resize), \ + interpolation=cv2.INTER_AREA), cv2.COLOR_BGR2RGB) for idx in indices if idx < len(df)] titles = None - if self.LABEL_COL in self.columns: - if self.classes: - titles = [self.classes[int(self.at[idx, self.LABEL_COL])] + ': ' + str(self.at[idx, self.LABEL_COL]) \ - for idx in indices if idx < len(self)] + if df.LABEL_COL in df.columns: + if df.classes: + titles = [df.classes[int(df.at[idx, df.LABEL_COL])] + ': ' + str(df.at[idx, df.LABEL_COL]) \ + for idx in indices if idx < len(df)] else: - titles = [str(self.at[idx, self.LABEL_COL]) for idx in indices if idx < len(self)] + titles = [str(df.at[idx, df.LABEL_COL]) for idx in indices if idx < len(df)] _show_images(images, cols=ncol, titles=titles, fontsize=fontsize) def to_mxnet(self): """Return a mxnet based iterator that returns ndarray and labels""" df = self.rename(columns={self.IMG_COL: "image", self.LABEL_COL: "label"}, errors='ignore') + df = df.reset_index(drop=True) return _MXImageClassificationDataset(df) @classmethod @@ -398,8 +400,7 @@ class ObjectDetectionDataset(pd.DataFrame): """ # preserved properties that will be copied to a new instance - _metadata = ['dataset_type', 'classes', 'pack', 'unpack', 'is_packed', - 'to_mxnet', 'color_map', 'show_images', 'random_split'] + _metadata = ['dataset_type', 'classes', 'color_map'] def __init__(self, data, dataset_type=None, classes=None, **kwargs): # dataset_type will be used to determine metrics, if None then auto resolve at runtime @@ -606,7 +607,8 @@ def is_packed(self): def to_mxnet(self): """Return a mxnet based iterator that returns ndarray and labels""" - return _MXObjectDetectionDataset(self) + df = self.reset_index(drop=True) + return _MXObjectDetectionDataset(df) def random_split(self, test_size=0.1, val_size=0, random_state=None): r"""Randomly split the dataset into train/val/test sets. @@ -661,6 +663,7 @@ def show_images(self, indices=None, nsample=16, ncol=4, shuffle=True, resize=512 The fontsize for title """ df = self.pack() + df = df.reset_index(drop=True) if indices is None: if not shuffle: indices = range(nsample) diff --git a/gluoncv/auto/estimators/center_net/center_net.py b/gluoncv/auto/estimators/center_net/center_net.py index 1797711b62..c0ef9165cd 100644 --- a/gluoncv/auto/estimators/center_net/center_net.py +++ b/gluoncv/auto/estimators/center_net/center_net.py @@ -320,7 +320,9 @@ def _init_network(self, **kwargs): self._cfg.center_net.transfer) net = get_model(self._cfg.center_net.transfer, pretrained=(not load_only)) if load_only: - net.initialize() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + net.initialize() net(mx.nd.zeros((1, 3, self._cfg.center_net.data_shape[0], self._cfg.center_net.data_shape[1]))) net.reset_class(self.classes, reuse_weights=[cname for cname in self.classes if cname in net.classes]) else: diff --git a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py index bf920926e9..2a1b671a15 100644 --- a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py +++ b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py @@ -406,7 +406,9 @@ def _init_network(self, **kwargs): **kwargs) self.net.sampler._max_num_gt = self._cfg.faster_rcnn.max_num_gt if load_only: - self.net.initialize() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net.initialize() self.net.set_nms(nms_thresh=0) self.net(mx.nd.zeros((1, 3, 600, 800))) self.net.set_nms(nms_thresh=self._cfg.faster_rcnn.nms_thresh) diff --git a/gluoncv/auto/estimators/ssd/ssd.py b/gluoncv/auto/estimators/ssd/ssd.py index 92c13d7fea..efff5c2fa5 100644 --- a/gluoncv/auto/estimators/ssd/ssd.py +++ b/gluoncv/auto/estimators/ssd/ssd.py @@ -361,7 +361,9 @@ def _init_network(self, **kwargs): norm_kwargs={'num_devices': len(self.ctx)}) self.async_net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only)) # used by cpu worker if load_only: - self.net.initialize() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net.initialize() self.net.set_nms(nms_thresh=0) self.net(mx.nd.zeros((1, 3, self._cfg.ssd.data_shape, self._cfg.ssd.data_shape))) self.net.set_nms(nms_thresh=self._cfg.ssd.nms_thresh, nms_topk=self._cfg.ssd.nms_topk) @@ -371,7 +373,9 @@ def _init_network(self, **kwargs): self.net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only), norm_layer=gluon.nn.BatchNorm) self.async_net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only), norm_layer=gluon.nn.BatchNorm) if load_only: - self.net.initialize() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net.initialize() self.net.set_nms(nms_thresh=0) self.net(mx.nd.zeros((1, 3, self._cfg.ssd.data_shape, self._cfg.ssd.data_shape))) self.net.set_nms(nms_thresh=self._cfg.ssd.nms_thresh, nms_topk=self._cfg.ssd.nms_topk) diff --git a/gluoncv/auto/estimators/yolo/yolo.py b/gluoncv/auto/estimators/yolo/yolo.py index 163b92556b..5671818de4 100644 --- a/gluoncv/auto/estimators/yolo/yolo.py +++ b/gluoncv/auto/estimators/yolo/yolo.py @@ -391,7 +391,9 @@ def _init_network(self, **kwargs): self.net = get_model(self._cfg.yolo3.transfer, pretrained=(not load_only)) self.async_net = self.net if load_only: - self.net.initialize() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.net.initialize() self.net.set_nms(nms_thresh=0) self.net(mx.nd.zeros((1, 3, self._cfg.yolo3.data_shape, self._cfg.yolo3.data_shape))) self.net.set_nms(nms_thresh=self._cfg.yolo3.nms_thresh, nms_topk=self._cfg.yolo3.nms_topk) diff --git a/tests/auto/test_auto_estimators.py b/tests/auto/test_auto_estimators.py index a87421bd2b..b07d2437b4 100644 --- a/tests/auto/test_auto_estimators.py +++ b/tests/auto/test_auto_estimators.py @@ -18,6 +18,7 @@ # under the License. """Test auto estimators""" from PIL import Image +import numpy as np from gluoncv.auto.tasks import ImageClassification, ImagePrediction from gluoncv.auto.tasks import ObjectDetection from autogluon.core.scheduler.resource import get_cpu_count, get_gpu_count @@ -43,7 +44,7 @@ def test_image_regression_estimator(): est.predict_feature(IMAGE_REGRESS_TEST.iloc[0]['image']) # test save/load _save_load_test(est, 'img_regression.pkl') - + def test_image_classification_estimator(): from gluoncv.auto.estimators import ImageClassificationEstimator est = ImageClassificationEstimator({'train': {'epochs': 1, 'batch_size': 8}, 'gpus': list(range(get_gpu_count()))}) @@ -104,7 +105,7 @@ def test_ssd_estimator(): # test save/load est2 = _save_load_test(est, 'ssd.pkl') evaluate_result2 = est2.evaluate(OBJECT_DETECTION_VAL) - assert evaluate_result == evaluate_result2, f'{evaluate_result} != \n {evaluate_result2}' + np.testing.assert_array_equal(evaluate_result, evaluate_result2, err_msg=f'{evaluate_result} != \n {evaluate_result2}') def test_yolo3_estimator(): from gluoncv.auto.estimators import YOLOv3Estimator @@ -119,7 +120,7 @@ def test_yolo3_estimator(): # test save/load est2 = _save_load_test(est, 'yolo3.pkl') evaluate_result2 = est2.evaluate(OBJECT_DETECTION_VAL) - assert evaluate_result == evaluate_result2, f'{evaluate_result} != \n {evaluate_result2}' + np.testing.assert_array_equal(evaluate_result, evaluate_result2, err_msg=f'{evaluate_result} != \n {evaluate_result2}') def test_frcnn_estimator(): from gluoncv.auto.estimators import FasterRCNNEstimator