Skip to content

Commit

Permalink
Add max_reward to auto task schedulers (#1598)
Browse files Browse the repository at this point in the history
* max_reward

* Trigger Build
  • Loading branch information
zhreshold authored Jan 27, 2021
1 parent 00f0a91 commit 2672f9a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
9 changes: 3 additions & 6 deletions gluoncv/auto/tasks/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,7 @@ def __init__(self, config=None, estimator=None, logger=None):

# scheduler options
self.search_strategy = config.get('search_strategy', 'random')
self.search_options = config.get('search_options', None)
if self.search_options:
self.search_options.update({'debug_log': True})
else:
self.search_options = {'debug_log': True}
self.search_options = config.get('search_options', {})
self.scheduler_options = {
'resource': {'num_cpus': nthreads_per_trial, 'num_gpus': ngpus_per_trial},
'checkpoint': config.get('checkpoint', 'checkpoint/exp1.ag'),
Expand All @@ -197,7 +193,8 @@ def __init__(self, config=None, estimator=None, logger=None):
'reward_attr': 'acc_reward',
'dist_ip_addrs': config.get('dist_ip_addrs', None),
'searcher': self.search_strategy,
'search_options': self.search_options}
'search_options': self.search_options,
'max_reward': config.get('max_reward', 0.95)}
if self.search_strategy == 'hyperband':
self.scheduler_options.update({
'searcher': 'random',
Expand Down
9 changes: 3 additions & 6 deletions gluoncv/auto/tasks/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,7 @@ def __init__(self, config=None, logger=None):

# scheduler options
self.search_strategy = config.get('search_strategy', 'random')
self.search_options = config.get('search_options', None)
if self.search_options:
self.search_options.update({'debug_log': True})
else:
self.search_options = {'debug_log': True}
self.search_options = config.get('search_options', {})
self.scheduler_options = {
'resource': {'num_cpus': nthreads_per_trial, 'num_gpus': ngpus_per_trial},
'checkpoint': config.get('checkpoint', 'checkpoint/exp1.ag'),
Expand All @@ -215,7 +211,8 @@ def __init__(self, config=None, logger=None):
'reward_attr': 'map_reward',
'dist_ip_addrs': config.get('dist_ip_addrs', None),
'searcher': self.search_strategy,
'search_options': self.search_options}
'search_options': self.search_options,
'max_reward': config.get('max_reward', 0.9)}
if self.search_strategy == 'hyperband':
self.scheduler_options.update({
'searcher': 'random',
Expand Down

0 comments on commit 2672f9a

Please sign in to comment.