Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto module update #1596

Merged
merged 13 commits into from
Jan 24, 2021
2 changes: 2 additions & 0 deletions gluoncv/auto/data/auto_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from ...utils.download import download
from ...utils.filesystem import unzip, untar, PathTree

__all__ = ['url_data']

def url_data(url, path=None, overwrite=False, overwrite_folder=False, sha1_hash=None, root=None, disp_depth=1):
"""Download an given URL

Expand Down
5 changes: 3 additions & 2 deletions gluoncv/auto/estimators/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __setstate__(self, state):
self.__dict__.update(state)
# logger
self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
self._logger.setLevel(logging.INFO)
self._logger.setLevel(logging.ERROR)
fh = logging.FileHandler(self._log_file)
self._logger.addHandler(fh)
try:
Expand All @@ -275,7 +275,7 @@ def __setstate__(self, state):
with temporary_filename() as tfile:
with open(tfile, 'wb') as fo:
fo.write(net_params)
self.net.load_parameters(tfile)
self.net.load_parameters(tfile, ignore_extra=True)
trainer_state = state['trainer']
self._init_trainer()
with temporary_filename() as tfile:
Expand All @@ -284,3 +284,4 @@ def __setstate__(self, state):
self.trainer.load_states(tfile)
except ImportError:
pass
self._logger.setLevel(logging.INFO)
6 changes: 6 additions & 0 deletions gluoncv/auto/estimators/center_net/center_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def _predict(self, x):
short_size = min(self._cfg.center_net.data_shape)
if isinstance(x, str):
x = load_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, np.ndarray):
return self._predict(mx.nd.array(x))
elif isinstance(x, mx.nd.NDArray):
if len(x.shape) != 3 or x.shape[-1] != 3:
raise ValueError('array input with shape (h, w, 3) is required for predict')
x = transform_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, pd.DataFrame):
assert 'image' in x.columns, "Expect column `image` for input images"
Expand All @@ -56,6 +60,8 @@ def _predict_merge(x):
y['image'] = x
return y
return pd.concat([_predict_merge(xx) for xx in x['image']]).reset_index(drop=True)
elif isinstance(x, (list, tuple)):
return pd.concat([self._predict(xx) for xx in x]).reset_index(drop=True)
else:
raise ValueError('Input is not supported: {}'.format(type(x)))
height, width = x.shape[2:4]
Expand Down
6 changes: 6 additions & 0 deletions gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def _predict(self, x):
short_size = self.net.short[-1] if isinstance(self.net.short, (tuple, list)) else self.net.short
if isinstance(x, str):
x = load_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, np.ndarray):
if len(x.shape) != 3 or x.shape[-1] != 3:
raise ValueError('array input with shape (h, w, 3) is required for predict')
return self._predict(mx.nd.array(x))
elif isinstance(x, mx.nd.NDArray):
x = transform_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, pd.DataFrame):
Expand All @@ -285,6 +289,8 @@ def _predict_merge(x):
y['image'] = x
return y
return pd.concat([_predict_merge(xx) for xx in x['image']]).reset_index(drop=True)
elif isinstance(x, (list, tuple)):
return pd.concat([self._predict(xx) for xx in x]).reset_index(drop=True)
else:
raise ValueError('Input is not supported: {}'.format(type(x)))
height, width = x.shape[2:4]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,11 @@ def _predict(self, x):
resize = int(math.ceil(self.input_size / self._cfg.train.crop_ratio))
if isinstance(x, str):
x = transform_eval(mx.image.imread(x), resize_short=resize, crop_size=self.input_size)
elif isinstance(x, np.ndarray):
return self._predict(mx.nd.array(x))
elif isinstance(x, mx.nd.NDArray):
if len(x.shape) != 3 or x.shape[-1] != 3:
raise ValueError('array input with shape (h, w, 3) is required for predict')
x = transform_eval(x, resize_short=resize, crop_size=self.input_size)
elif isinstance(x, pd.DataFrame):
assert 'image' in x.columns, "Expect column `image` for input images"
Expand Down
6 changes: 6 additions & 0 deletions gluoncv/auto/estimators/ssd/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ def _predict(self, x):
short_size = int(self._cfg.ssd.data_shape)
if isinstance(x, str):
x = load_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, np.ndarray):
return self._predict(mx.nd.array(x))
elif isinstance(x, mx.nd.NDArray):
if len(x.shape) != 3 or x.shape[-1] != 3:
raise ValueError('array input with shape (h, w, 3) is required for predict')
x = transform_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, pd.DataFrame):
assert 'image' in x.columns, "Expect column `image` for input images"
Expand All @@ -259,6 +263,8 @@ def _predict_merge(x):
y['image'] = x
return y
return pd.concat([_predict_merge(xx) for xx in x['image']]).reset_index(drop=True)
elif isinstance(x, (list, tuple)):
return pd.concat([self._predict(xx) for xx in x]).reset_index(drop=True)
else:
raise ValueError('Input is not supported: {}'.format(type(x)))
height, width = x.shape[2:4]
Expand Down
6 changes: 6 additions & 0 deletions gluoncv/auto/estimators/yolo/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ def _predict(self, x):
short_size = int(self._cfg.yolo3.data_shape)
if isinstance(x, str):
x = load_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, np.ndarray):
return self._predict(mx.nd.array(x))
elif isinstance(x, mx.nd.NDArray):
if len(x.shape) != 3 or x.shape[-1] != 3:
raise ValueError('array input with shape (h, w, 3) is required for predict')
x = transform_test(x, short=short_size, max_size=1024)[0]
elif isinstance(x, pd.DataFrame):
assert 'image' in x.columns, "Expect column `image` for input images"
Expand All @@ -260,6 +264,8 @@ def _predict_merge(x):
y['image'] = x
return y
return pd.concat([_predict_merge(xx) for xx in x['image']]).reset_index(drop=True)
elif isinstance(x, (list, tuple)):
return pd.concat([self._predict(xx) for xx in x]).reset_index(drop=True)
else:
raise ValueError('Input is not supported: {}'.format(type(x)))
height, width = x.shape[2:4]
Expand Down
35 changes: 17 additions & 18 deletions gluoncv/auto/tasks/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class LiteConfig:
lr : Union[ag.Space, float] = 1e-2
num_trials : int = 1
epochs : Union[ag.Space, int] = 5
batch_size : Union[ag.Space, int] = 3 # 2 ** 3 == 8
batch_size : Union[ag.Space, int] = 8
nthreads_per_trial : int = 32
ngpus_per_trial : int = 0
time_limits : int = 7 * 24 * 60 * 60 # 7 days
Expand All @@ -44,7 +44,7 @@ class DefaultConfig:
lr : Union[ag.Space, float] = ag.Categorical(1e-2, 5e-2)
num_trials : int = 3
epochs : Union[ag.Space, int] = 15
batch_size : Union[ag.Space, int] = 4 # 2 ** 4 = 16
batch_size : Union[ag.Space, int] = 16
nthreads_per_trial : int = 128
ngpus_per_trial : int = 8
time_limits : int = 7 * 24 * 60 * 60 # 7 days
Expand All @@ -61,6 +61,13 @@ def _train_image_classification(args, reporter):
# train, val data
train_data = args.pop('train_data')
val_data = args.pop('val_data')
# exponential batch size for Int() space batch sizes
try:
exp_batch_size = args.pop('exp_batch_size')
except AttributeError:
exp_batch_size = False
if exp_batch_size and 'batch_size' in args:
args['batch_size'] = 2 ** args['batch_size']
try:
task = args.pop('task')
dataset = args.pop('dataset')
Expand Down Expand Up @@ -141,6 +148,8 @@ def __init__(self, config=None, estimator=None, logger=None):
else:
if not config.get('dist_ip_addrs', None):
ngpus_per_trial = config.get('ngpus_per_trial', gpu_count)
if ngpus_per_trial > gpu_count:
ngpus_per_trial = gpu_count
if ngpus_per_trial < 1:
self._logger.info('No GPU detected/allowed, using most conservative search space.')
default_config = LiteConfig()
Expand All @@ -167,10 +176,6 @@ def __init__(self, config=None, estimator=None, logger=None):
# additional configs
config['num_workers'] = nthreads_per_trial
config['gpus'] = [int(i) for i in range(ngpus_per_trial)]
# if config['gpus']:
# config['batch_size'] = config.get('batch_size', 8) * len(config['gpus'])
# self._logger.info('Increase batch size to %d based on the number of gpus %d',
# config['batch_size'], len(config['gpus']))
config['seed'] = config.get('seed', np.random.randint(32,767))
self._config = config

Expand Down Expand Up @@ -242,17 +247,6 @@ def fit(self, train_data, val_data=None, train_size=0.9, random_state=None):
len(train), len(val))
train_data, val_data = train, val

# automatically suggest some hyperparameters based on the dataset statistics(experimental)
# estimator = self._config.get('estimator', None)
# if estimator is None:
# estimator = [ImageClassificationEstimator]
# elif isinstance(estimator, (tuple, list)):
# pass
# else:
# assert issubclass(estimator, BaseEstimator)
# estimator = [estimator]
# self._config['estimator'] = ag.Categorical(*estimator)

estimator = self._config.get('estimator', None)
if estimator is None:
estimator = [ImageClassificationEstimator]
Expand All @@ -266,7 +260,12 @@ def fit(self, train_data, val_data=None, train_size=0.9, random_state=None):
estimator[i] = ImageClassificationEstimator
else:
estimator.pop(e)
self._config['estimator'] = ag.Categorical(*estimator)
if not estimator:
raise ValueError('Unable to determine the estimator for fit function.')
if len(estimator) == 1:
self._config['estimator'] = estimator[0]
else:
self._config['estimator'] = ag.Categorical(*estimator)

# register args
config = self._config.copy()
Expand Down
43 changes: 9 additions & 34 deletions gluoncv/auto/tasks/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,39 +59,16 @@ def _train_object_detection(args, reporter):
----------
args: <class 'autogluon.utils.edict.EasyDict'>
"""
# pruning for batch size
# if args.get('batch_size', None):
# if args.estimator == FasterRCNNEstimator and args.batch_size not in [4, 8]:
# logging.info('Estimator and batch size are not matched, this trial is skipped.')
# return
# elif args.estimator != FasterRCNNEstimator and args.batch_size in [4, 8]:
# logging.info('Estimator and batch size are not matched, this trial is skipped.')
# return

# pruning for base network
# if args.get('base_network', None):
# if args.estimator == SSDEstimator and \
# args.base_network not in ['vgg16_atrous', 'resnet18_v1', 'resnet50_v1',
# 'resnet101_v2', 'resnet152_v2', 'resnet34_v1b']:
# logging.info('Estimator and base network are not matched, this trial is skipped.')
# return
# elif args.estimator == YOLOv3Estimator and \
# args.base_network not in ['darknet53']:
# logging.info('Estimator and base network are not matched, this trial is skipped.')
# return
# elif args.estimator == FasterRCNNEstimator and \
# args.base_network not in ['resnet50_v1b', 'resnet101_v1d',
# 'resnest50', 'resnest101', 'resnest269']:
# logging.info('Estimator and base network are not matched, this trial is skipped.')
# return
# elif args.estimator == CenterNetEstimator and \
# args.base_network not in ['resnet18_v1b', 'resnet50_v1b', 'resnet101_v1b', 'dla34']:
# logging.info('Estimator and base network are not matched, this trial is skipped.')
# return

# train, val data
train_data = args.pop('train_data')
val_data = args.pop('val_data')
# exponential batch size for Int() space batch sizes
try:
exp_batch_size = args.pop('exp_batch_size')
except AttributeError:
exp_batch_size = False
if exp_batch_size and 'batch_size' in args:
args['batch_size'] = 2 ** args['batch_size']
try:
task = args.pop('task')
dataset = args.pop('dataset')
Expand Down Expand Up @@ -164,6 +141,8 @@ def __init__(self, config=None, logger=None):
else:
if not config.get('dist_ip_addrs', None):
ngpus_per_trial = config.get('ngpus_per_trial', gpu_count)
if ngpus_per_trial > gpu_count:
ngpus_per_trial = gpu_count
if ngpus_per_trial < 1:
self._logger.info('No GPU detected/allowed, using most conservative search space.')
default_config = LiteConfig()
Expand Down Expand Up @@ -215,10 +194,6 @@ def __init__(self, config=None, logger=None):
# additional configs
config['num_workers'] = nthreads_per_trial
config['gpus'] = [int(i) for i in range(ngpus_per_trial)]
# if config['gpus']:
# config['batch_size'] = config.get('batch_size', 8) * len(config['gpus'])
# self._logger.info('Increase batch size to %d based on the number of gpus %d',
# config['batch_size'], len(config['gpus']))
config['seed'] = config.get('seed', np.random.randint(32,767))
self._config = config

Expand Down
58 changes: 20 additions & 38 deletions gluoncv/auto/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ def auto_suggest(config, estimator, logger):
estimator[i] = FasterRCNNEstimator
elif e == 'center_net':
estimator[i] = CenterNetEstimator
config['estimator'] = ag.Categorical(*estimator)
if not estimator:
raise ValueError('Unable to determine the estimator for fit function.')
if len(estimator) == 1:
config['estimator'] = estimator[0]
else:
config['estimator'] = ag.Categorical(*estimator)

# get dataset statistics
# user needs to define a Dataset object "train_dataset" when using custom dataset
Expand Down Expand Up @@ -217,36 +222,6 @@ def get_recursively(search_dict, field):

def config_to_nested(config):
"""Convert config to nested version"""
# estimator = config.get('estimator', None)
# if estimator is None:
# transfer = config.get('transfer', None)
# assert transfer is not None, "estimator or transfer is required in search space"
# if transfer.startswith('ssd'):
# estimator = SSDEstimator
# elif transfer.startswith('faster_rcnn'):
# estimator = FasterRCNNEstimator
# elif transfer.startswith('yolo3'):
# estimator = YOLOv3Estimator
# elif transfer.startswith('center_net'):
# estimator = CenterNetEstimator
# else:
# estimator = ImageClassificationEstimator
# else:
# # str to instance
# if isinstance(estimator, str):
# if estimator == 'ssd':
# estimator = SSDEstimator
# elif estimator == 'faster_rcnn':
# estimator = FasterRCNNEstimator
# elif estimator == 'yolo3':
# estimator = YOLOv3Estimator
# elif estimator == 'center_net':
# estimator = CenterNetEstimator
# elif estimator == 'img_cls':
# estimator = ImageClassificationEstimator
# else:
# raise ValueError(f'Unknown estimator: {estimator}')

estimator = config.get('estimator', None)
transfer = config.get('transfer', None)
# choose hyperparameters based on pretrained model in transfer learning
Expand Down Expand Up @@ -294,21 +269,28 @@ def config_to_nested(config):
else:
assert issubclass(estimator, BaseEstimator)

# batch size is the power of 2
if config.get('batch_size', None):
config['batch_size'] = 2 ** config['batch_size']

cfg_map = estimator._default_cfg.asdict()

def _recursive_update(config, key, value):
def _recursive_update(config, key, value, auto_strs, auto_ints):
for k, v in config.items():
if k in auto_strs:
config[k] = 'auto'
if k in auto_ints:
config[k] = -1
if key == k:
config[key] = value
elif isinstance(v, dict):
_recursive_update(v, key, value)
_recursive_update(v, key, value, auto_strs, auto_ints)

if 'use_rec' in config:
auto_strs = ['data_dir']
auto_ints = []
else:
auto_strs = ['data_dir', 'rec_train', 'rec_train_idx', 'rec_val', 'rec_val_idx',
'dataset', 'dataset_root']
auto_ints = ['num_training_samples']
for k, v in config.items():
_recursive_update(cfg_map, k, v)
_recursive_update(cfg_map, k, v, auto_strs, auto_ints)
cfg_map['estimator'] = estimator
return cfg_map

Expand Down