Skip to content

Commit

Permalink
fixed the key comparison in the Bayesian strategy (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
yiliu30 authored Nov 17, 2022
1 parent 43943d8 commit 1e9c12b
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions neural_compressor/strategy/bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sklearn.gaussian_process import GaussianProcessRegressor

from collections import OrderedDict
from copy import deepcopy

from ..utils import logger
from .strategy import strategy_registry, TuneStrategy
Expand Down Expand Up @@ -104,7 +105,10 @@ def params_to_tune_configs(self, params):
op_tuning_cfg[op_name_type] = configs[0]
else:
op_tuning_cfg[op_name_type] = configs[min(len(configs) - 1, int(params[op_name_type[0]]))]
calib_sampling_size = calib_sampling_size_lst[min(len(configs) - 1, int(params['calib_sampling_size']))]
if len(calib_sampling_size_lst) > 1:
calib_sampling_size = calib_sampling_size_lst[min(len(configs) - 1, int(params['calib_sampling_size']))]
else:
calib_sampling_size = calib_sampling_size_lst[0]
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
return op_tuning_cfg

Expand All @@ -115,7 +119,6 @@ def next_tune_cfg(self):
"""
params = None
pbounds = {}
from copy import deepcopy
tuning_space = self.tuning_space
calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options
op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg()
Expand All @@ -126,7 +129,8 @@ def next_tune_cfg(self):
for op_name_type, configs in self.op_configs.items():
if len(configs) > 1:
pbounds[op_name_type[0]] = (0, len(configs))
pbounds['calib_sampling_size'] = (0, len(calib_sampling_size_lst))
if len(calib_sampling_size_lst) > 1:
pbounds['calib_sampling_size'] = (0, len(calib_sampling_size_lst))
if len(pbounds) == 0:
yield self.params_to_tune_configs(params)
return
Expand Down Expand Up @@ -225,10 +229,11 @@ def __init__(self, pbounds, random_seed=9527):
"""
self.random_seed = random_seed
# Get the name of the parameters
self._keys = sorted(pbounds)
names = list(pbounds.keys())
self._keys = deepcopy(names)
# Create an array with parameters bounds
self._bounds = np.array(
[item[1] for item in sorted(pbounds.items(), key=lambda x: x[0])],
[pbounds[name] for name in names],
dtype=np.float
)

Expand Down Expand Up @@ -275,7 +280,7 @@ def params_to_array(self, params):
assert set(params) == set(self.keys)
except AssertionError:
raise ValueError(
"Parameters' keys ({}) do ".format(sorted(params)) +
"Parameters' keys ({}) do ".format(list(params.keys())) +
"not match the expected set of keys ({}).".format(self.keys)
)
return np.asarray([params[key] for key in self.keys])
Expand Down

0 comments on commit 1e9c12b

Please sign in to comment.