[Auto] Autogluon compatibility update #1669

Merged: 15 commits, Jun 6, 2021
6 changes: 3 additions & 3 deletions .github/workflows/build_docs.sh
@@ -10,15 +10,15 @@ EFS=/mnt/efs
mkdir -p ~/.mxnet/datasets
for f in $EFS/.mxnet/datasets/*; do
if [ -d "$f" ]; then
# Will not run if no directories are available
ln -s $f ~/.mxnet/datasets/$(basename "$f")
fi
done

-python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme
+python3 -m pip install sphinx==3.5.4 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark scipy mxtheme autogluon.core

export MXNET_CUDNN_AUTOTUNE_DEFAULT=0
cd docs
make html
COMMAND_EXIT_CODE=$?

2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
@@ -1,5 +1,5 @@
name: Build Test
on:
workflow_run:
workflows: ["Unit Test"]
types:
12 changes: 9 additions & 3 deletions .github/workflows/gpu_test.sh
@@ -10,7 +10,13 @@ for f in $EFS/.mxnet/models/*.params; do
ln -s $f ~/.mxnet/models/$(basename "$f")
done

export MXNET_CUDNN_AUTOTUNE_DEFAULT=0
export MPLBACKEND=Agg
export KMP_DUPLICATE_LIB_OK=TRUE

+if [[ $TESTS_PATH == *"auto"* ]]; then
+    echo "Installing autogluon.core for auto module"
+    pip3 install autogluon.core==0.2.0
+fi

nosetests --with-timer --timer-ok 5 --timer-warning 20 -x --with-coverage --cover-package $COVER_PACKAGE -v $TESTS_PATH
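The auto tests now depend on `autogluon.core`, which the script installs only when `$TESTS_PATH` mentions the auto module. A minimal sketch of how a test file could guard on that optional dependency (the test class and names below are illustrative, not taken from the repo, and `autogluon.core.__version__` is assumed to exist):

import unittest

try:
    import autogluon.core as ag  # installed by gpu_test.sh only for auto tests
    HAS_AUTOGLUON = True
except ImportError:
    HAS_AUTOGLUON = False

@unittest.skipUnless(HAS_AUTOGLUON, "autogluon.core is required for gluoncv.auto tests")
class TestAutoModule(unittest.TestCase):
    def test_version(self):
        # the pin in gpu_test.sh is 0.2.0, so the major.minor prefix should match
        self.assertTrue(ag.__version__.startswith('0.2'))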
2 changes: 2 additions & 0 deletions docs/install/install-include.rst
@@ -48,6 +48,8 @@ Select your preferences and run the install command.

- Requires `pip >= 9. <https://pip.pypa.io/en/stable/installing/>`_.

+- Note that you can install all the extra optional requirements at once by replacing "pip install gluoncv" with "pip install gluoncv[full]".

.. container:: nightly

- Nightly build provides latest features for enthusiasts.
44 changes: 29 additions & 15 deletions gluoncv/auto/data/dataset.py
@@ -48,16 +48,23 @@ class ImageClassificationDataset(pd.DataFrame):
The input data.
classes : list of str, optional
The class synsets for this dataset, if `None`, it will infer from the data.
+image_column : str, default is 'image'
+    The name of the column that stores image paths.
+label_column : str, default is 'label'
+    The name of the label column; leave it at the default if there is no label column. Note that
+    in that case you won't be able to train with this dataset, but you can still visualize the images.

"""
# preserved properties that will be copied to a new instance
-_metadata = ['classes', 'to_mxnet', 'show_images', 'random_split']
+_metadata = ['classes', 'to_mxnet', 'show_images', 'random_split', 'IMG_COL', 'LABEL_COL']

-def __init__(self, data, classes=None, **kwargs):
+def __init__(self, data, classes=None, image_column='image', label_column='label', **kwargs):
root = kwargs.pop('root', None)
if isinstance(data, str) and data.endswith('csv'):
-    data = self.from_csv(data, root=root)
+    data = self.from_csv(data, root=root, image_column=image_column, label_column=label_column)
self.classes = classes
+self.IMG_COL = image_column
+self.LABEL_COL = label_column
super().__init__(data, **kwargs)

@property
@@ -127,20 +134,21 @@ def show_images(self, indices=None, nsample=16, ncol=4, shuffle=True, resize=224
indices = list(range(len(self)))
np.random.shuffle(indices)
indices = indices[:min(nsample, len(indices))]
-images = [cv2.cvtColor(cv2.resize(cv2.imread(self.at[idx, 'image']), (resize, resize), \
+images = [cv2.cvtColor(cv2.resize(cv2.imread(self.at[idx, self.IMG_COL]), (resize, resize), \
interpolation=cv2.INTER_AREA), cv2.COLOR_BGR2RGB) for idx in indices if idx < len(self)]
titles = None
-if 'label' in self.columns:
-    titles = [self.classes[int(self.at[idx, 'label'])] + ': ' + str(self.at[idx, 'label']) \
+if self.LABEL_COL in self.columns:
+    titles = [self.classes[int(self.at[idx, self.LABEL_COL])] + ': ' + str(self.at[idx, self.LABEL_COL]) \
for idx in indices if idx < len(self)]
_show_images(images, cols=ncol, titles=titles, fontsize=fontsize)

def to_mxnet(self):
"""Return a mxnet based iterator that returns ndarray and labels"""
-return _MXImageClassificationDataset(self)
+df = self.rename(columns={self.IMG_COL: "image", self.LABEL_COL: "label"}, errors='ignore')
+return _MXImageClassificationDataset(df)

@classmethod
-def from_csv(cls, csv_file, root=None):
+def from_csv(cls, csv_file, root=None, image_column='image', label_column='label'):
r"""Create from csv file.

Parameters
@@ -149,19 +157,23 @@ def from_csv(cls, csv_file, root=None):
The path for csv file.
root : str
The relative root for image paths stored in csv file.

+image_column : str, default is 'image'
+    The name of the column that stores image paths.
+label_column : str, default is 'label'
+    The name of the label column; leave it at the default if there is no label column. Note that
+    in that case you won't be able to train with this dataset, but you can still visualize the images.
"""
if is_url(csv_file):
csv_file = url_data(csv_file, disp_depth=0)
df = pd.read_csv(csv_file)
-assert 'image' in df.columns, "`image` column is required, used for accessing the original images"
-if not 'label' in df.columns:
+assert image_column in df.columns, f"`{image_column}` column is required, used for accessing the original images"
+if not label_column in df.columns:
logger.info('label not in columns, no access to labels of images')
classes = None
else:
-    classes = df['label'].unique()
-df = _absolute_pathify(df, root=root, column='image')
-return cls(df, classes=classes)
+    classes = df[label_column].unique().tolist()
+df = _absolute_pathify(df, root=root, column=image_column)
+return cls(df, classes=classes, image_column=image_column, label_column=label_column)

@classmethod
def from_folder(cls, root, exts=('.jpg', '.jpeg', '.png')):
@@ -211,7 +223,7 @@ def from_folder(cls, root, exts=('.jpg', '.jpeg', '.png')):

@classmethod
def from_folders(cls, root, train='train', val='val', test='test', exts=('.jpg', '.jpeg', '.png')):
"""Method for loading splited datasets under root.
"""Method for loading (already) splited datasets under root.
like::
root/train/car/0001.jpg
root/train/car/xxxa.jpg
@@ -222,6 +234,8 @@ def from_folders(cls, root, train='train', val='val', test='test', exts=('.jpg',
will be loaded into three splits, with 3/1/2 images, respectively.
You can specify the sub-folder names of `train`/`val`/`test` individually. If one particular sub-folder is not
found, the corresponding returned dataset will be `None`.
+Note: if your existing dataset isn't organized in this layout, load it with the `from_folder` method
+and then apply `random_split` to create the splits.

Example:
>>> train_data, val_data, test_data = ImageClassificationDataset.from_folders('./data', val='validation')
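Taken together, these changes let a dataset keep arbitrary column names end to end. A minimal usage sketch, assuming a CSV with hypothetical columns `img_path` and `category` (and that `random_split` keeps its existing keyword arguments):

from gluoncv.auto.data.dataset import ImageClassificationDataset

# 'img_path' and 'category' are hypothetical column names
train = ImageClassificationDataset.from_csv(
    'train.csv', root='./images',
    image_column='img_path', label_column='category')

train.show_images(nsample=8)  # titles come from the remembered label column
train_set, val_set, test_set = train.random_split(val_size=0.1, test_size=0.1)

# to_mxnet() renames the columns back to 'image'/'label' before wrapping
mx_dataset = train_set.to_mxnet()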
4 changes: 2 additions & 2 deletions gluoncv/auto/estimators/base_estimator.py
@@ -238,7 +238,7 @@ def _resume_fit(self, train_data, val_data, time_limit=math.inf):
def _evaluate(self, val_data):
raise NotImplementedError

-def _init_network(self):
+def _init_network(self, **kwargs):
raise NotImplementedError

def _init_trainer(self):
@@ -367,7 +367,7 @@ def __setstate__(self, state):
try:
import mxnet as _
net_params = state['net']
-self._init_network()
+self._init_network(load_only=True)
with temporary_filename() as tfile:
with open(tfile, 'wb') as fo:
fo.write(net_params)
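The motivation for the new keyword: when an estimator is unpickled, `__setstate__` already holds the trained parameters in `state['net']`, so downloading pretrained weights inside `_init_network` would be wasted bandwidth. Subclasses receive `load_only=True` and are expected to build the graph without fetching weights. A condensed sketch of the contract, with `MyEstimator` and `self._cfg.transfer` as hypothetical stand-ins for the estimator-specific pieces:

from gluoncv.auto.estimators.base_estimator import BaseEstimator
from gluoncv.model_zoo import get_model

class MyEstimator(BaseEstimator):
    def _init_network(self, **kwargs):
        load_only = kwargs.get('load_only', False)
        # skip the pretrained-weight download when only the graph is needed,
        # e.g. during deserialization where the saved parameters follow immediately
        self.net = get_model(self._cfg.transfer, pretrained=(not load_only))
        if load_only:
            self.net.initialize()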
10 changes: 7 additions & 3 deletions gluoncv/auto/estimators/center_net/center_net.py
@@ -295,7 +295,8 @@ def _evaluate(self, val_data):
eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
return eval_metric.get()

-def _init_network(self):
+def _init_network(self, **kwargs):
+    load_only = kwargs.get('load_only', False)
if not self.num_class:
raise ValueError('Unable to create network when `num_class` is unknown. \
It should be inferred from dataset or resumed from saved states.')
@@ -317,7 +318,10 @@ def _init_network(self):
assert isinstance(self._cfg.center_net.transfer, str)
self._logger.info('Using transfer learning from %s, ignoring some of the network configs',
self._cfg.center_net.transfer)
-net = get_model(self._cfg.center_net.transfer, pretrained=True)
+net = get_model(self._cfg.center_net.transfer, pretrained=(not load_only))
+if load_only:
+    net.initialize()
+    net(mx.nd.zeros((1, 3, self._cfg.center_net.data_shape[0], self._cfg.center_net.data_shape[1])))
net.reset_class(self.classes, reuse_weights=[cname for cname in self.classes if cname in net.classes])
else:
net_name = '_'.join(('center_net', self._cfg.center_net.base_network, self.dataset))
@@ -327,7 +331,7 @@
('wh', {'num_output': self._cfg.center_net.heads.wh_outputs}),
('reg', {'num_output': self._cfg.center_net.heads.reg_outputs})])
base_network = get_base_network(self._cfg.center_net.base_network,
-pretrained=self._cfg.train.pretrained_base)
+pretrained=self._cfg.train.pretrained_base and not load_only)
net = get_center_net(self._cfg.center_net.base_network,
self.dataset,
base_network=base_network,
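Because Gluon defers parameter shape inference, a network built with `pretrained=False` has no concrete weights for `reset_class` to copy until a forward pass has run; the dummy `mx.nd.zeros` batch above exists to materialize those shapes. The same trick in isolation (the model name is just an example from the zoo, and 512 stands in for the configured data shape):

import mxnet as mx
from gluoncv.model_zoo import get_model

net = get_model('center_net_resnet18_v1b_coco', pretrained=False)
net.initialize()
# one dummy batch at the configured data shape materializes every deferred parameter
net(mx.nd.zeros((1, 3, 512, 512)))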
12 changes: 9 additions & 3 deletions gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py
@@ -347,7 +347,8 @@ def _predict_merge(x, ctx_id=0):
valid_df = df[df['predict_score'] > 0].reset_index(drop=True)
return valid_df

-def _init_network(self):
+def _init_network(self, **kwargs):
+    load_only = kwargs.get('load_only', False)
if not self.num_class:
raise ValueError('Unable to create network when `num_class` is unknown. \
It should be inferred from dataset or resumed from saved states.')
@@ -400,10 +401,15 @@ def _init_network(self):
self._cfg.faster_rcnn.use_fpn = 'fpn' in self._cfg.faster_rcnn.transfer
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
-self.net = get_model(self._cfg.faster_rcnn.transfer, pretrained=True,
+self.net = get_model(self._cfg.faster_rcnn.transfer, pretrained=(not load_only),
per_device_batch_size=self.batch_size // self.num_gpus,
**kwargs)
self.net.sampler._max_num_gt = self._cfg.faster_rcnn.max_num_gt
+if load_only:
+    self.net.initialize()
+    self.net.set_nms(nms_thresh=0)
+    self.net(mx.nd.zeros((1, 3, 600, 800)))
+    self.net.set_nms(nms_thresh=self._cfg.faster_rcnn.nms_thresh)
self.net.reset_class(self.classes,
reuse_weights=[cname for cname in self.classes if cname in self.net.classes])
else:
@@ -427,7 +433,7 @@ def _init_network(self):
warnings.simplefilter("always")
self.net = get_model('custom_faster_rcnn_fpn', classes=self.classes, transfer=None,
dataset=self._cfg.dataset,
-pretrained_base=self._cfg.train.pretrained_base,
+pretrained_base=self._cfg.train.pretrained_base and not load_only,
base_network_name=self._cfg.faster_rcnn.base_network,
norm_layer=norm_layer, norm_kwargs=norm_kwargs,
sym_norm_layer=sym_norm_layer, sym_norm_kwargs=sym_norm_kwargs,
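The Faster R-CNN variant adds one wrinkle: the dummy pass runs with `set_nms(nms_thresh=0)` and the configured threshold is restored right afterwards, presumably so the shape-priming batch doesn't run meaningful suppression on random outputs. The same pattern in isolation, assuming a zoo model in place of the transfer setting:

import mxnet as mx
from gluoncv.model_zoo import get_model

net = get_model('faster_rcnn_resnet50_v1b_coco', pretrained=False)
net.initialize()
net.set_nms(nms_thresh=0)           # temporary value for the shape-priming pass
net(mx.nd.zeros((1, 3, 600, 800)))  # same dummy input shape the PR uses
net.set_nms(nms_thresh=0.5)         # restore a working threshold (0.5 assumed here)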
@@ -245,7 +245,8 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf):
return {'train_acc': train_metric_score, 'valid_acc': self._best_acc,
'time': self._time_elapsed, 'checkpoint': cp_name}

-def _init_network(self):
+def _init_network(self, **kwargs):
+    load_only = kwargs.get('load_only', False)
if not self.num_class:
raise ValueError('Unable to create network when `num_class` is unknown. \
It should be inferred from dataset or resumed from saved states.')
@@ -287,7 +288,8 @@ def _init_network(self):
if input_size != self.input_size:
self._logger.info(f'Change input size to {self.input_size}, given model type: {model_name}')

-if self._cfg.img_cls.use_pretrained:
+use_pretrained = not load_only and self._cfg.img_cls.use_pretrained
+if use_pretrained:
kwargs = {'ctx': self.ctx, 'pretrained': True, 'classes': 1000 if 'cifar' not in model_name else 10}
else:
kwargs = {'ctx': self.ctx, 'pretrained': False, 'classes': self.num_class}
@@ -303,7 +305,7 @@

if model_name:
self.net = get_model(model_name, **kwargs)
-if model_name and self._cfg.img_cls.use_pretrained:
+if model_name and use_pretrained:
# reset last fully connected layer
fc_layer_found = False
for fc_name in ('output', 'fc'):
@@ -337,7 +339,7 @@
self.net.cast(self._cfg.train.dtype)

# teacher model for distillation training
-if self._cfg.train.teacher is not None and self._cfg.train.hard_weight < 1.0 and self.num_class == 1000:
+if not load_only and self._cfg.train.teacher is not None and self._cfg.train.hard_weight < 1.0 and self.num_class == 1000:
teacher_name = self._cfg.train.teacher
self.teacher = get_model(teacher_name, pretrained=True, classes=self.num_class, ctx=self.ctx)
self.teacher.cast(self._cfg.train.dtype)
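The hunks above (their file header was not captured in this page, but the `img_cls` config keys point at the image classification estimator) route `load_only` through a `use_pretrained` flag, which also controls whether the pretrained 1000-way (or 10-way CIFAR) head gets replaced. A condensed sketch of that head reset, with a hypothetical 5-class target:

import mxnet as mx
from gluoncv.model_zoo import get_model

num_class = 5  # hypothetical number of target classes
net = get_model('resnet50_v1', pretrained=True, classes=1000)
with net.name_scope():
    # swap the ImageNet output layer for a freshly initialized one
    net.output = mx.gluon.nn.Dense(num_class)
net.output.initialize(mx.init.Xavier())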
25 changes: 18 additions & 7 deletions gluoncv/auto/estimators/ssd/ssd.py
@@ -324,7 +324,8 @@ def _predict_merge(x, ctx_id=0):
valid_df = df[df['predict_score'] > 0].reset_index(drop=True)
return valid_df

-def _init_network(self):
+def _init_network(self, **kwargs):
+    load_only = kwargs.get('load_only', False)
if not self.num_class:
raise ValueError('Unable to create network when `num_class` is unknown. \
It should be inferred from dataset or resumed from saved states.')
@@ -355,15 +356,25 @@ def _init_network(self):
if self._cfg.ssd.syncbn and len(self.ctx) > 1:
with warnings.catch_warnings(record=True) as _:
warnings.simplefilter("always")
-self.net = get_model(self._cfg.ssd.transfer, pretrained=True,
+self.net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only),
norm_layer=gluon.contrib.nn.SyncBatchNorm,
norm_kwargs={'num_devices': len(self.ctx)})
-self.async_net = get_model(self._cfg.ssd.transfer, pretrained=True) # used by cpu worker
+self.async_net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only)) # used by cpu worker
+if load_only:
+    self.net.initialize()
+    self.net.set_nms(nms_thresh=0)
+    self.net(mx.nd.zeros((1, 3, self._cfg.ssd.data_shape, self._cfg.ssd.data_shape)))
+    self.net.set_nms(nms_thresh=0.3)
self.net.reset_class(self.classes,
reuse_weights=[cname for cname in self.classes if cname in self.net.classes])
else:
-self.net = get_model(self._cfg.ssd.transfer, pretrained=True, norm_layer=gluon.nn.BatchNorm)
-self.async_net = get_model(self._cfg.ssd.transfer, pretrained=True, norm_layer=gluon.nn.BatchNorm)
+self.net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only), norm_layer=gluon.nn.BatchNorm)
+self.async_net = get_model(self._cfg.ssd.transfer, pretrained=(not load_only), norm_layer=gluon.nn.BatchNorm)
+if load_only:
+    self.net.initialize()
+    self.net.set_nms(nms_thresh=0)
+    self.net(mx.nd.zeros((1, 3, self._cfg.ssd.data_shape, self._cfg.ssd.data_shape)))
+    self.net.set_nms(nms_thresh=0.3)
self.net.reset_class(self.classes,
reuse_weights=[cname for cname in self.classes if cname in self.net.classes])
# elif self._cfg.ssd.custom_model:
@@ -379,7 +390,7 @@ def _init_network(self):
steps=self._cfg.ssd.steps,
classes=self.classes,
dataset='auto',
-pretrained_base=True,
+pretrained_base=(not load_only),
norm_layer=gluon.contrib.nn.SyncBatchNorm,
norm_kwargs={'num_devices': len(self.ctx)})
self.async_net = custom_ssd(base_network_name=self._cfg.ssd.base_network,
@@ -402,7 +413,7 @@ def _init_network(self):
steps=self._cfg.ssd.steps,
classes=self.classes,
dataset=self._cfg.dataset,
-pretrained_base=True,
+pretrained_base=(not load_only),
norm_layer=gluon.nn.BatchNorm)
self.async_net = self.net

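Across all three detectors the end state is the same: once the network exists (downloaded, or shape-primed under `load_only`), `reset_class` re-targets it to the task's classes, and `reuse_weights` preserves the weights of any class name the pretrained model already knows. An illustration with an SSD zoo model and a hypothetical class list:

from gluoncv.model_zoo import get_model

net = get_model('ssd_512_resnet50_v1_coco', pretrained=True)
classes = ['person', 'dog', 'red_panda']  # hypothetical target classes
# 'person' and 'dog' keep their pretrained class weights; 'red_panda' starts fresh
net.reset_class(classes, reuse_weights=[c for c in classes if c in net.classes])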