Enhance the set_local for operator type (#1745)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yiliu30 and pre-commit-ci[bot] authored Apr 25, 2024
1 parent fdb5097 commit a58638c
Showing 4 changed files with 86 additions and 12 deletions.
16 changes: 13 additions & 3 deletions neural_compressor/common/base_config.py
@@ -198,7 +198,7 @@ def local_config(self):
def local_config(self, config):
self._local_config = config

-def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
+def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
if operator_name in self.local_config:
logger.warning("The configuration for %s has already been set, update it.", operator_name)
self.local_config[operator_name] = config
@@ -392,14 +392,16 @@ def _get_op_name_op_type_config(self):
op_name_config_dict = dict()
for name, config in self.local_config.items():
if self._is_op_type(name):
-op_type_config_dict[name] = config
+# Convert the Callable to String.
+new_name = self._op_type_to_str(name)
+op_type_config_dict[new_name] = config
else:
op_name_config_dict[name] = config
return op_type_config_dict, op_name_config_dict

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
+) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
config_mapping = OrderedDict()
if config_list is None:
config_list = [self]
@@ -416,6 +418,14 @@
config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
return config_mapping

+@staticmethod
+def _op_type_to_str(op_type: Callable) -> str:
+    # * Ort and TF may override this method.
+    op_type_name = getattr(op_type, "__name__", "")
+    if op_type_name == "":
+        logger.warning("The op_type %s has no attribute __name__.", op_type)
+    return op_type_name

@staticmethod
def _is_op_type(name: str) -> bool:
# * Ort and TF may override this method.
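The net effect in `base_config.py`: `local_config` keys may now be operator-name strings or module classes, and class keys are normalized to their `__name__` string before the per-type mapping is built. A minimal, self-contained sketch of that normalization (the `split_local_config` helper and the `Linear` stand-in class are hypothetical; the real `_is_op_type` check is backend-specific, so plain `callable()` stands in for it here):

```python
import logging
from typing import Callable, Dict, Tuple, Union

logger = logging.getLogger(__name__)


def op_type_to_str(op_type: Callable) -> str:
    # Same fallback as the new _op_type_to_str helper: empty string
    # plus a warning when the callable carries no __name__.
    op_type_name = getattr(op_type, "__name__", "")
    if op_type_name == "":
        logger.warning("The op_type %s has no attribute __name__.", op_type)
    return op_type_name


def split_local_config(local_config: Dict[Union[str, Callable], str]) -> Tuple[dict, dict]:
    # Class keys land in the op-type dict under their string name;
    # plain strings are treated as operator names.
    op_type_cfg, op_name_cfg = {}, {}
    for name, cfg in local_config.items():
        if callable(name):
            op_type_cfg[op_type_to_str(name)] = cfg
        else:
            op_name_cfg[name] = cfg
    return op_type_cfg, op_name_cfg


class Linear:  # stand-in for a framework module class
    pass


print(split_local_config({Linear: "6-bit", "fc1": "8-bit"}))
# -> ({'Linear': '6-bit'}, {'fc1': '8-bit'})
```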
4 changes: 2 additions & 2 deletions neural_compressor/torch/utils/utility.py
@@ -101,13 +101,13 @@ def set_module(model, op_name, new_module):
setattr(second_last_module, name_list[-1], new_module)


-def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, Callable]]:
+def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
module_dict = dict(model.named_modules())
filter_result = []
filter_result_set = set()
for op_name, module in module_dict.items():
if isinstance(module, tuple(white_module_list)):
-pair = (op_name, type(module))
+pair = (op_name, type(module).__name__)
if pair not in filter_result_set:
filter_result_set.add(pair)
filter_result.append(pair)
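With `get_model_info` now recording `type(module).__name__` instead of the class object, the `(op_name, op_type)` pairs it returns match the string keys produced by `_op_type_to_str` above. A quick check against a toy model, assuming only that `torch` is installed (the function body below restates the patched helper rather than importing it):

```python
from typing import Callable, List, Tuple

import torch


def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
    # Collect (op_name, type_name) pairs for whitelisted module types,
    # de-duplicating while preserving discovery order.
    filter_result = []
    filter_result_set = set()
    for op_name, module in dict(model.named_modules()).items():
        if isinstance(module, tuple(white_module_list)):
            pair = (op_name, type(module).__name__)
            if pair not in filter_result_set:
                filter_result_set.add(pair)
                filter_result.append(pair)
    return filter_result


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
print(get_model_info(model, white_module_list=[torch.nn.Linear]))
# -> [('0', 'Linear'), ('2', 'Linear')]
```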
49 changes: 49 additions & 0 deletions test/3x/common/test_common.py
@@ -75,6 +75,29 @@ def __repr__(self) -> str:
return "FakeModel"


+class FakeOpType:
+    def __init__(self) -> None:
+        self.name = "fake_module"

+    def __call__(self, x) -> Any:
+        return x

+    def __repr__(self) -> str:
+        return "FakeModule"


+class OP_TYPE1(FakeOpType):
+    pass


+class OP_TYPE2(FakeOpType):
+    pass


+def build_simple_fake_model():
+    return FakeModel()


@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
class FakeAlgoConfig(BaseConfig):
"""Config class for fake algo."""
@@ -257,6 +280,32 @@ def test_mixed_two_algos(self):
self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])

+def test_set_local_op_name(self):
+    quant_config = FakeAlgoConfig(weight_bits=4)
+    # set `OP1_NAME`
+    fc1_config = FakeAlgoConfig(weight_bits=6)
+    quant_config.set_local("OP1_NAME", fc1_config)
+    model_info = FAKE_MODEL_INFO
+    logger.info(quant_config)
+    configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+    logger.info(configs_mapping)
+    self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+    self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 4)
+    self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)

+def test_set_local_op_type(self):
+    quant_config = FakeAlgoConfig(weight_bits=4)
+    # set all `OP_TYPE1`
+    fc1_config = FakeAlgoConfig(weight_bits=6)
+    quant_config.set_local(OP_TYPE1, fc1_config)
+    model_info = FAKE_MODEL_INFO
+    logger.info(quant_config)
+    configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+    logger.info(configs_mapping)
+    self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+    self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 6)
+    self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)


class TestConfigSet(unittest.TestCase):
def setUp(self):
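These tests cover the happy path, where the class has a usable `__name__`. The fallback branch of `_op_type_to_str` fires for callables without one, e.g. a `functools.partial` object; a small probe of that edge (hypothetical names, standard library only):

```python
import functools


def fake_op(x):
    return x


no_name = functools.partial(fake_op)

# The same getattr lookup _op_type_to_str relies on:
print(getattr(fake_op, "__name__", ""))  # -> 'fake_op'
print(getattr(no_name, "__name__", ""))  # -> '' (would trigger the warning and an empty key)
```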
29 changes: 22 additions & 7 deletions test/3x/torch/test_config.py
@@ -147,8 +147,8 @@ def test_config_white_lst2(self):
logger.info(quant_config)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
logger.info(configs_mapping)
-self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)

def test_config_from_dict(self):
quant_config = {
@@ -253,16 +253,31 @@ def test_config_mapping(self):
logger.info(quant_config)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
logger.info(configs_mapping)
-self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)
# test regular matching
fc_config = RTNConfig(bits=5, dtype="int8")
quant_config.set_local("fc", fc_config)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
logger.info(configs_mapping)
-self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 5)
-self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 5)
-self.assertTrue(configs_mapping[("fc3", torch.nn.Linear)].bits == 5)
+self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 5)
+self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 5)
+self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 5)

+def test_set_local_op_type(self):
+    quant_config = RTNConfig(bits=4, dtype="nf4")
+    # set all `Linear`
+    fc1_config = RTNConfig(bits=6, dtype="int8")
+    quant_config.set_local(torch.nn.Linear, fc1_config)
+    # get model and quantize
+    fp32_model = build_simple_torch_model()
+    model_info = get_model_info(fp32_model, white_module_list=[torch.nn.Linear])
+    logger.info(quant_config)
+    configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+    logger.info(configs_mapping)
+    self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+    self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 6)
+    self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 6)

def test_gptq_config(self):
gptq_config1 = GPTQConfig(bits=8, act_order=True)
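End to end, `test_set_local_op_type` shows the new contract: passing a class to `set_local` fans the override out to every module of that type, and the resulting mapping is keyed by the class's string name. A condensed usage sketch under the same assumptions as the test (import paths follow the 3.x layout and may differ; `ToyModel` is an illustrative stand-in for `build_simple_torch_model`):

```python
import torch

# Assumed 3.x import paths; adjust to your install if they differ.
from neural_compressor.torch.quantization import RTNConfig
from neural_compressor.torch.utils import get_model_info


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)
        self.fc3 = torch.nn.Linear(8, 4)


quant_config = RTNConfig(bits=4, dtype="nf4")
# A class key now applies to every module of that type:
quant_config.set_local(torch.nn.Linear, RTNConfig(bits=6, dtype="int8"))

model_info = get_model_info(ToyModel(), white_module_list=[torch.nn.Linear])
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
assert configs_mapping[("fc1", "Linear")].bits == 6  # type-level override applied
```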
