diff --git a/gluoncv/auto/tasks/image_classification.py b/gluoncv/auto/tasks/image_classification.py index 82afd90413..1e84dd8834 100644 --- a/gluoncv/auto/tasks/image_classification.py +++ b/gluoncv/auto/tasks/image_classification.py @@ -181,11 +181,7 @@ def __init__(self, config=None, estimator=None, logger=None): # scheduler options self.search_strategy = config.get('search_strategy', 'random') - self.search_options = config.get('search_options', None) - if self.search_options: - self.search_options.update({'debug_log': True}) - else: - self.search_options = {'debug_log': True} + self.search_options = config.get('search_options', {}) self.scheduler_options = { 'resource': {'num_cpus': nthreads_per_trial, 'num_gpus': ngpus_per_trial}, 'checkpoint': config.get('checkpoint', 'checkpoint/exp1.ag'), @@ -197,7 +193,8 @@ def __init__(self, config=None, estimator=None, logger=None): 'reward_attr': 'acc_reward', 'dist_ip_addrs': config.get('dist_ip_addrs', None), 'searcher': self.search_strategy, - 'search_options': self.search_options} + 'search_options': self.search_options, + 'max_reward': config.get('max_reward', 0.95)} if self.search_strategy == 'hyperband': self.scheduler_options.update({ 'searcher': 'random', diff --git a/gluoncv/auto/tasks/object_detection.py b/gluoncv/auto/tasks/object_detection.py index 223fec06d1..329227d6d4 100644 --- a/gluoncv/auto/tasks/object_detection.py +++ b/gluoncv/auto/tasks/object_detection.py @@ -199,11 +199,7 @@ def __init__(self, config=None, logger=None): # scheduler options self.search_strategy = config.get('search_strategy', 'random') - self.search_options = config.get('search_options', None) - if self.search_options: - self.search_options.update({'debug_log': True}) - else: - self.search_options = {'debug_log': True} + self.search_options = config.get('search_options', {}) self.scheduler_options = { 'resource': {'num_cpus': nthreads_per_trial, 'num_gpus': ngpus_per_trial}, 'checkpoint': config.get('checkpoint', 'checkpoint/exp1.ag'), @@ -215,7 +211,8 @@ def __init__(self, config=None, logger=None): 'reward_attr': 'map_reward', 'dist_ip_addrs': config.get('dist_ip_addrs', None), 'searcher': self.search_strategy, - 'search_options': self.search_options} + 'search_options': self.search_options, + 'max_reward': config.get('max_reward', 0.9)} if self.search_strategy == 'hyperband': self.scheduler_options.update({ 'searcher': 'random',