Skip to content

Commit

Permalink
add parallel search for custom model (#4725)
Browse files Browse the repository at this point in the history
  • Loading branch information
shane-huang authored May 30, 2022
1 parent a4e334e commit 6ae9938
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 29 deletions.
23 changes: 23 additions & 0 deletions python/nano/src/bigdl/nano/automl/hpo/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,29 @@ def _model_build(self, trial):
self._model_compile(model, trial)
return model

def _get_model_build_args(self):
return {'lazyobj': self._lazyobj}

@staticmethod
def _get_model_builder(model_build_args,
compile_args,
compile_kwargs,
backend):

lazyobj = model_build_args.get('lazyobj')

def model_builder(trial):
model = backend.instantiate(trial, lazyobj)
# self._model_compile(model, trial)
# instantiate optimizers if it is autoobj
optimizer = compile_kwargs.get('optimizer', None)
if optimizer and isinstance(optimizer, AutoObject):
optimizer = backend.instantiate(trial, optimizer)
compile_kwargs['optimizer'] = optimizer
model.compile(*compile_args, **compile_kwargs)
return model
return model_builder

return TFAutoMdl

return registered_class
Expand Down
4 changes: 3 additions & 1 deletion python/nano/src/bigdl/nano/automl/pytorch/hposearcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def _run_search_n_procs(self, n_procs=4):
if n_trials:
subp_n_trials = math.ceil(n_trials / n_procs)
new_searcher.run_kwargs['n_trials'] = subp_n_trials
run_parallel(args=new_searcher, n_procs=n_procs)
run_parallel(func=new_searcher._run_search,
kwargs={},
n_procs=n_procs)

def search(self,
model,
Expand Down
91 changes: 71 additions & 20 deletions python/nano/src/bigdl/nano/automl/tf/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,79 @@ def _fix_target_metric(self, target_metric, fit_kwargs):
invalidInputError(False, "invalid target metric")
return target_metric

def _create_objective(self, target_metric, create_kwargs, fit_kwargs):
isprune = True if create_kwargs.get('pruner', None) else False
self.objective = Objective(
model=self._model_build,
@staticmethod
def _create_objective(model_builder,
target_metric,
isprune,
fit_kwargs,
backend):
objective = Objective(
model=model_builder,
target_metric=target_metric,
pruning=isprune,
backend=self.backend,
backend=backend,
**fit_kwargs,
)

# def _run_search_n_procs(self, n_procs=4):
# new_searcher = copy.deepcopy(self)
# n_trials = new_searcher.run_kwargs.get('n_trials', None)
# if n_trials:
# subp_n_trials = math.ceil(n_trials / n_procs)
# new_searcher.run_kwargs['n_trials'] = subp_n_trials
# run_parallel(args=new_searcher, n_procs=n_procs)

def _run_search(self):
return objective

@staticmethod
def _run_search_subproc(study,
                        get_model_builder_func,
                        get_model_builder_func_args,
                        backend,
                        target_metric,
                        isprune,
                        fit_kwargs,
                        run_kwargs):
    """Worker entry point for parallel search.

    Rebuilds the model-builder from its pickled arguments, wraps it in
    an objective, and drives the shared study. Kept static (no ``self``)
    so it can be cloudpickled and executed inside a subprocess.
    """
    # recreate the per-trial model builder inside the worker
    builder = get_model_builder_func(**get_model_builder_func_args)

    trial_objective = HPOMixin._create_objective(builder,
                                                 target_metric,
                                                 isprune,
                                                 fit_kwargs,
                                                 backend)

    study.optimize(trial_objective, **run_kwargs)

def _run_search_n_procs(self, isprune, n_procs=4):
    """Fan the hyperparameter search out over ``n_procs`` subprocesses.

    Splits the total trial budget (rounded up) across the workers and
    packs everything a worker needs into a plain dict so it can be
    pickled over to ``_run_search_subproc``.
    """
    worker_run_kwargs = copy.deepcopy(self.run_kwargs)
    total_trials = worker_run_kwargs.get('n_trials', None)
    if total_trials:
        # each subprocess runs its (rounded-up) share of the trials
        worker_run_kwargs['n_trials'] = math.ceil(total_trials / n_procs)

    builder_args = {'model_build_args': self._get_model_build_args(),
                    'compile_args': self.compile_args,
                    'compile_kwargs': self.compile_kwargs,
                    'backend': self.backend}
    worker_kwargs = {'study': self.study,
                     'get_model_builder_func': self._get_model_builder,
                     'get_model_builder_func_args': builder_args,
                     'backend': self.backend,
                     'target_metric': self.target_metric,
                     'isprune': isprune,
                     'fit_kwargs': self.fit_kwargs,
                     'run_kwargs': worker_run_kwargs}

    run_parallel(func=self._run_search_subproc,
                 kwargs=worker_kwargs,
                 n_procs=n_procs)

def _run_search(self, isprune):
if self.objective is None:
self.objective = self._create_objective(self._model_build,
self.target_metric,
isprune,
self.fit_kwargs,
self.backend)
self.study.optimize(self.objective, **self.run_kwargs)

def search(
Expand Down Expand Up @@ -170,13 +223,11 @@ def search(
if self.study is None:
self.study = _create_study(resume, self.create_kwargs, self.backend)

if self.objective is None:
self._create_objective(self.target_metric, self.create_kwargs, self.fit_kwargs)

isprune = True if self.create_kwargs.get('pruner', None) else False
if n_parallels and n_parallels > 1:
self._run_search_n_procs(n_parallels)
self._run_search_n_procs(isprune, n_procs=n_parallels)
else:
self._run_search()
self._run_search(isprune)

self.tune_end = False

Expand Down
8 changes: 5 additions & 3 deletions python/nano/src/bigdl/nano/automl/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
log = logging.getLogger(__name__)


def run_parallel(args, n_procs):
def run_parallel(func, kwargs, n_procs):
"""
Utility to Run a number of parallel processes.
Expand All @@ -39,8 +39,10 @@ def run_parallel(args, n_procs):
log.info("-" * 100)

with TemporaryDirectory() as temp_dir:
with open(os.path.join(temp_dir, "searcher.pkl"), 'wb') as f:
cloudpickle.dump(args, f)
with open(os.path.join(temp_dir, "search_kwargs.pkl"), 'wb') as f:
cloudpickle.dump(kwargs, f)
with open(os.path.join(temp_dir, "search_func.pkl"), 'wb') as f:
cloudpickle.dump(func, f)

processes = _run_subprocess(temp_dir, n_procs)

Expand Down
10 changes: 5 additions & 5 deletions python/nano/src/bigdl/nano/automl/utils/parallel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
if __name__ == '__main__':
temp_dir = sys.argv[1]

with open(os.path.join(temp_dir, "searcher.pkl"), 'rb') as f:
args = cloudpickle.load(f)
with open(os.path.join(temp_dir, "search_kwargs.pkl"), 'rb') as f:
kwargs = cloudpickle.load(f)
with open(os.path.join(temp_dir, "search_func.pkl"), 'rb') as f:
func = cloudpickle.load(f)

searcher = args
# do we need to reset seed?
# reset_seed()

searcher._run_search()
func(**kwargs)

0 comments on commit 6ae9938

Please sign in to comment.