diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 05e26d8b05d..5e9e72a8882 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -198,7 +198,7 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config
 
-    def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
+    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
         if operator_name in self.local_config:
             logger.warning("The configuration for %s has already been set, update it.", operator_name)
         self.local_config[operator_name] = config
@@ -392,14 +392,16 @@ def _get_op_name_op_type_config(self):
         op_name_config_dict = dict()
         for name, config in self.local_config.items():
             if self._is_op_type(name):
-                op_type_config_dict[name] = config
+                # Convert the Callable to String.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
             else:
                 op_name_config_dict[name] = config
         return op_type_config_dict, op_name_config_dict
 
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-    ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
+    ) -> OrderedDict[str, OrderedDict[str, BaseConfig]]:
         config_mapping = OrderedDict()
         if config_list is None:
             config_list = [self]
@@ -416,6 +418,14 @@ def to_config_mapping(
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping
 
+    @staticmethod
+    def _op_type_to_str(op_type: Callable) -> str:
+        # * Ort and TF may override this method.
+        op_type_name = getattr(op_type, "__name__", "")
+        if op_type_name == "":
+            logger.warning("The op_type %s has no attribute __name__.", op_type)
+        return op_type_name
+
     @staticmethod
     def _is_op_type(name: str) -> bool:
         # * Ort and TF may override this method.
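For reference, the added `_op_type_to_str` helper resolves an op type class to its name via `__name__`. A minimal sketch of that behavior (the `torch.nn.Linear` example is illustrative, not part of the patch):

```python
import torch

# Classes such as `torch.nn.Linear` carry `__name__`, so the helper
# yields the bare class name that is then used as the op-type key.
assert getattr(torch.nn.Linear, "__name__", "") == "Linear"

# An *instance* has no `__name__`; the helper falls back to "" and logs a warning.
assert getattr(torch.nn.Linear(2, 2), "__name__", "") == ""
```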
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index 9fd2540b803..0f17cc59f92 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -101,13 +101,13 @@ def set_module(model, op_name, new_module):
     setattr(second_last_module, name_list[-1], new_module)
 
 
-def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, Callable]]:
+def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
     module_dict = dict(model.named_modules())
     filter_result = []
     filter_result_set = set()
     for op_name, module in module_dict.items():
         if isinstance(module, tuple(white_module_list)):
-            pair = (op_name, type(module))
+            pair = (op_name, type(module).__name__)
             if pair not in filter_result_set:
                 filter_result_set.add(pair)
                 filter_result.append(pair)
diff --git a/test/3x/common/test_common.py b/test/3x/common/test_common.py
index 68c25abd532..d1df7d98b1d 100644
--- a/test/3x/common/test_common.py
+++ b/test/3x/common/test_common.py
@@ -75,6 +75,29 @@ def __repr__(self) -> str:
         return "FakeModel"
 
 
+class FakeOpType:
+    def __init__(self) -> None:
+        self.name = "fake_module"
+
+    def __call__(self, x) -> Any:
+        return x
+
+    def __repr__(self) -> str:
+        return "FakeModule"
+
+
+class OP_TYPE1(FakeOpType):
+    pass
+
+
+class OP_TYPE2(FakeOpType):
+    pass
+
+
+def build_simple_fake_model():
+    return FakeModel()
+
+
 @register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
 class FakeAlgoConfig(BaseConfig):
     """Config class for fake algo."""
@@ -257,6 +280,32 @@ def test_mixed_two_algos(self):
         self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
         self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])
 
+    def test_set_local_op_name(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set `OP1_NAME`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local("OP1_NAME", fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 4)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+
+    def test_set_local_op_type(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set all `OP_TYPE1`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local(OP_TYPE1, fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+
 
 class TestConfigSet(unittest.TestCase):
     def setUp(self):
diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py
index f296dab88a9..c5bdc5261cf 100644
--- a/test/3x/torch/test_config.py
+++ b/test/3x/torch/test_config.py
@@ -147,8 +147,8 @@ def test_config_white_lst2(self):
         logger.info(quant_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4) + self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6) + self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4) def test_config_from_dict(self): quant_config = { @@ -253,16 +253,31 @@ def test_config_mapping(self): logger.info(quant_config) configs_mapping = quant_config.to_config_mapping(model_info=model_info) logger.info(configs_mapping) - self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6) - self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4) + self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6) + self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4) # test regular matching fc_config = RTNConfig(bits=5, dtype="int8") quant_config.set_local("fc", fc_config) configs_mapping = quant_config.to_config_mapping(model_info=model_info) logger.info(configs_mapping) - self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 5) - self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 5) - self.assertTrue(configs_mapping[("fc3", torch.nn.Linear)].bits == 5) + self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 5) + self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 5) + self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 5) + + def test_set_local_op_type(self): + quant_config = RTNConfig(bits=4, dtype="nf4") + # set all `Linear` + fc1_config = RTNConfig(bits=6, dtype="int8") + quant_config.set_local(torch.nn.Linear, fc1_config) + # get model and quantize + fp32_model = build_simple_torch_model() + model_info = get_model_info(fp32_model, white_module_list=[torch.nn.Linear]) + logger.info(quant_config) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.info(configs_mapping) + self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6) + self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 6) + self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 6) def test_gptq_config(self): gptq_config1 = GPTQConfig(bits=8, act_order=True)
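Taken together, these changes mean `set_local` now accepts either an op-name string or an op type class, while `to_config_mapping` keys always carry the op type as a string. A hedged usage sketch of the new behavior (the import paths and the two-layer model are assumptions based on the tests above, not taken verbatim from the patch):

```python
import torch

from neural_compressor.torch.quantization import RTNConfig  # import path assumed
from neural_compressor.torch.utils.utility import get_model_info

# Global 4-bit default, with a 6-bit override for every `torch.nn.Linear`.
quant_config = RTNConfig(bits=4, dtype="nf4")
quant_config.set_local(torch.nn.Linear, RTNConfig(bits=6, dtype="int8"))

# Hypothetical stand-in for the tests' `build_simple_torch_model()`.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
model_info = get_model_info(model, white_module_list=[torch.nn.Linear])
# With this patch, model_info is [("0", "Linear"), ("1", "Linear")]:
# the op type is the class *name*, not the class object.

configs_mapping = quant_config.to_config_mapping(model_info=model_info)
# Keys now pair the op name with the type name, e.g. ("0", "Linear"),
# so the Linear-wide override applies to both layers.
assert all(cfg.bits == 6 for cfg in configs_mapping.values())
```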