diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index b1eda3d71dc..1650bb4c97c 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -847,6 +847,12 @@ def __init__(self, framework_specific_info):
         self.fp32_results = []
         self.fp32_preds_as_label = False
 
+        if self.version.release >= Version("2.0").release:
+            static_quant_mapping = tq.quantization_mappings.get_default_static_quant_module_mappings()
+            self.fused_op_list = \
+                [static_quant_mapping[key] for key in static_quant_mapping if "intrinsic." in str(key)]
+            self.fused_dict = {}
+
     def calib_func(self, model, dataloader, tmp_iterations, conf=None):
         try:
             for idx, (input, label) in enumerate(dataloader):
@@ -1229,6 +1235,400 @@ def _combine_capability(self, bf16_ops, q_capability):
                 q_capability['optypewise'][bf16_op[1]] = [bf16_config, fp32_config]
         return q_capability
 
+    def get_fused_list(self, model):
+        """This is a helper function to get the fused op list.
+
+        Args:
+            model (object): input model
+
+        Returns:
+            dict mapping each fused op name to its list of FP32 op names
+        """
+        fused_dict = {}
+        for op_name, child in model.named_modules():
+            if type(child) in self.fused_op_list:
+                in_fused_loop = False
+                fp32_int8_ops = [op_name,]
+                type_name = str(child).split("(")[0]
+                prefix_index = op_name.rfind(".")
+                for fp32_op_name, module in self.pre_optimized_model.model.named_modules():
+                    fp32_type_name = str(module).split("(")[0]
+                    prefix_fp32_index = fp32_op_name.rfind(".")
+                    if op_name == fp32_op_name:
+                        in_fused_loop = True
+                        continue
+                    elif in_fused_loop and \
+                        op_name[: prefix_index if prefix_index > -1 else 0] == \
+                            fp32_op_name[: prefix_fp32_index if prefix_fp32_index > -1 else 0]:
+                        if "BatchNorm" in str(type(module)):
+                            fp32_int8_ops.append(fp32_op_name)
+                            continue
+                        elif fp32_type_name in type_name.split(".")[-1][-len(fp32_type_name) - 2:]:
+                            fp32_int8_ops.append(fp32_op_name)
+                            in_fused_loop = False
+                            break
+                        else:
+                            in_fused_loop = False
+                            break
+                    elif in_fused_loop:
+                        in_fused_loop = False
+                        break
+                fused_dict.update({op_name: fp32_int8_ops})
+        return fused_dict
+
+    def inspect_tensor(self,
+                       model,
+                       dataloader,
+                       op_list=None,
+                       iteration_list=None,
+                       inspect_type='activation',
+                       save_to_disk=False,
+                       save_path=None,
+                       quantization_cfg=None):
+        assert self.version.release >= Version("2.0").release, "Inspect_tensor only supports torch 2.0 or above!"
+        from neural_compressor.utils.utility import dump_data_to_local
+        from torch import dequantize
+        is_quantized = model.is_quantized
+        op_list_ = []
+        fp32_int8_map = {}
+        for op_name in op_list:
+            op_list_.append(op_name)
+            for key in self.fused_dict:
+                if op_name in self.fused_dict[key]:
+                    op_list_.pop()
+                    fp32_int8_map[op_name] = \
+                        {'activation': self.fused_dict[key][-1], 'weight': self.fused_dict[key][0]}
+                    if not is_quantized:
+                        op_list_.append(self.fused_dict[key][-1])
+                    elif self.fused_dict[key][0] not in op_list_:
+                        op_list_.append(self.fused_dict[key][0])
+                    break
+
+        assert min(iteration_list) > 0, \
+            "Iteration number should be greater than zero; 1 means the first iteration."
+        iterations = max(iteration_list) if iteration_list is not None else -1
+        new_model = self._pre_eval_hook(model, op_list=op_list_, iteration_list=iteration_list)
+        self.evaluate(new_model, dataloader, iteration=iterations)
+        observer_dict = {}
+        ret = {}
+        if inspect_type == 'activation' or inspect_type == 'all':
+            from torch.quantization.quantize import _get_observer_dict as get_observer_dict
+            ret['activation'] = []
+            get_observer_dict(new_model.model, observer_dict)
+            if iteration_list is None:
+                iteration_list = [1]
+            for i in iteration_list:
+                summary = OrderedDict()
+                for key in observer_dict:
+                    if isinstance(observer_dict[key], torch.nn.modules.linear.Identity):
+                        continue
+                    op_name = key.replace(".activation_post_process", "")
+                    value = observer_dict[key].get_tensor_value()[i]
+                    if op_name in op_list:
+                        if type(value) is list:
+                            summary[op_name] = {}
+                            for index in range(len(value)):
+                                summary[op_name].update({
+                                    op_name + ".output" + str(index):
+                                        dequantize(value[index]).numpy()
+                                        if value[index].is_quantized else value[index].numpy()
+                                })
+                        else:
+                            summary[op_name] = {
+                                op_name + ".output0":
+                                    dequantize(value).numpy() if value.is_quantized else value.numpy()
+                            }
+                    else:
+                        if bool(self.fused_dict):
+                            if is_quantized:
+                                for a in fp32_int8_map:
+                                    if op_name == fp32_int8_map[a]['weight']:
+                                        if type(value) is list:
+                                            summary[a] = {}
+                                            for index in range(len(value)):
+                                                summary[a].update({
+                                                    a + ".output" + str(index):
+                                                        dequantize(value[index]).numpy()
+                                                        if value[index].is_quantized else
+                                                        value[index].numpy()
+                                                })
+                                        else:
+                                            summary[a] = {
+                                                a + ".output0":
+                                                    dequantize(value).numpy()
+                                                    if value.is_quantized else value.numpy()
+                                            }
+                            else:
+                                for a in fp32_int8_map:  # pragma: no cover
+                                    if op_name == fp32_int8_map[a]['activation']:
+                                        if type(value) is list:
+                                            summary[a] = {}
+                                            for index in range(len(value)):
+                                                summary[a].update({
+                                                    a + ".output" + str(index):
+                                                        dequantize(value[index]).numpy()
+                                                        if value[index].is_quantized else
+                                                        value[index].numpy()
+                                                })
+                                        else:
+                                            summary[a] = {
+                                                a + ".output0":
+                                                    dequantize(value).numpy()
+                                                    if value.is_quantized else value.numpy()
+                                            }
+
+                ret['activation'].append(summary)
+
+        if inspect_type == 'weight' or inspect_type == 'all':
+            ret['weight'] = {}
+            state_dict = new_model._model.state_dict()
+
+            for key in state_dict:
+                if not isinstance(state_dict[key], torch.Tensor):
+                    continue
+                if 'weight' not in key and 'bias' not in key:
+                    continue
+
+                op = key[:key.rfind('.')]
+                op = op.replace('._packed_params', '')
+
+                if op in op_list:
+                    if op in ret['weight']:
+                        ret['weight'][op].update({
+                            key:
+                                dequantize(state_dict[key]).numpy()
+                                if state_dict[key].is_quantized else state_dict[key].detach().numpy()
+                        })
+                    else:
+                        ret['weight'][op] = {
+                            key:
+                                dequantize(state_dict[key]).numpy()
+                                if state_dict[key].is_quantized else state_dict[key].detach().numpy()
+                        }
+                else:
+                    if bool(self.fused_dict):
+                        if is_quantized:
+                            for a in fp32_int8_map:
+                                if op == fp32_int8_map[a]['weight']:
+                                    if a in ret['weight']:
+                                        ret['weight'][a].update({
+                                            key:
+                                                dequantize(state_dict[key]).numpy()
+                                                if state_dict[key].is_quantized else
+                                                state_dict[key].detach().numpy()
+                                        })
+                                    else:
+                                        ret['weight'][a] = \
+                                            {key: dequantize(state_dict[key]).numpy()
+                                                if state_dict[key].is_quantized else
+                                                state_dict[key].detach().numpy()}
+                                    break
+        else:
+            ret['weight'] = None
+
+        if save_to_disk:
+            if not save_path:
+                save_path = self.workspace_path
+            dump_data_to_local(ret, save_path, 'inspect_result.pkl')
+
+        return ret
+
+    def _pre_eval_hook(self, model, op_list=None, iteration_list=None):
+        """This function is used to do some preprocessing before the evaluation phase.
+           Here, it is used to add hooks that dump the output tensors of quantizable ops.
+
+        Args:
+            model (object): input model
+
+        Returns:
+            model (object): model with hook
+        """
+        from abc import ABCMeta
+
+        def _with_args(cls_or_self, **kwargs):
+            r"""Wrapper that allows creation of class factories.
+
+            This can be useful when there is a need to create classes with the same
+            constructor arguments, but different instances.
+
+            Example::
+
+                >>> Foo.with_args = classmethod(_with_args)
+                >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
+                >>> foo_instance1 = foo_builder()
+                >>> foo_instance2 = foo_builder()
+                >>> id(foo_instance1) == id(foo_instance2)
+                False
+            """
+            class _PartialWrapper(object):
+                def __init__(self, p):
+                    self.p = p
+
+                def __call__(self, *args, **keywords):
+                    return self.p(*args, **keywords)
+
+                def __repr__(self):
+                    return self.p.__repr__()
+
+                with_args = _with_args
+
+            r = _PartialWrapper(partial(cls_or_self, **kwargs))
+            return r
+
+        ABC = ABCMeta(str("ABC"), (object, ), {})  # compatible with Python 2 *and* 3
+
+        class _RecordingObserver(ABC, torch.nn.Module):
+            """The module is mainly for debug and records the tensor values during runtime.
+
+            Args:
+                iteration_list (list, optional): indices of the iterations for which to dump tensors.
+            """
+            def __init__(self, iteration_list=None, **kwargs):
+                super(_RecordingObserver, self).__init__(**kwargs)
+                self.output_tensors_dict = OrderedDict()
+                self.current_iter = 1
+                self.iteration_list = iteration_list
+
+            def forward(self, x):
+                if (self.iteration_list is None and self.current_iter == 1) or \
+                    (self.iteration_list is not None and
+                        self.current_iter in self.iteration_list):
+                    if type(x) is tuple or type(x) is list:
+                        self.output_tensors_dict[self.current_iter] = \
+                            [i.to("cpu") if i.device != 'cpu' else i.clone() for i in x]
+                    else:
+                        self.output_tensors_dict[self.current_iter] = \
+                            x.to("cpu") if x.device != "cpu" else x.clone()
+                self.current_iter += 1
+                return x
+
+            @torch.jit.export
+            def get_tensor_value(self):
+                return self.output_tensors_dict
+
+            with_args = classmethod(_with_args)
+
+        def _observer_forward_hook(module, input, output):
+            """Forward hook that calls the observer on the output.
+
+            Args:
+                module (object): input module
+                input (object): module input
+                output (object): module output
+
+            Returns:
+                module output tensor (object)
+            """
+            return module.activation_post_process(output)
+
+        def _add_observer_(module, op_list=None, prefix=""):
+            """Add an observer for the leaf children of the module.
+
+            This function inserts an observer module into every leaf child module that
+            has a valid qconfig attribute.
+
+            Args:
+                module (object): input module with qconfig attributes for all the leaf modules
+                                 whose tensors we want to dump
+                op_list (list, optional): list of ops in the module to be dumped
+                prefix (string): name of module
+
+            Returns:
+                None, module is modified inplace with added observer modules and forward_hooks
+            """
+            for name, child in module.named_children():
+                op_name = name if prefix == "" else prefix + "." + name
+                if isinstance(child, torch.nn.quantized.FloatFunctional) and \
+                        (op_list is None or op_name in op_list):
+                    if hasattr(child, 'qconfig') and child.qconfig is not None and (
+                            op_list is None or op_name in op_list):
+                        child.activation_post_process = \
+                            child.qconfig.activation()
+                elif hasattr(child, 'qconfig') and child.qconfig is not None and \
+                        (op_list is None or op_name in op_list):
+                    # observer and hook will be gone after we swap the module
+                    child.add_module('activation_post_process', child.qconfig.activation())
+                    child.register_forward_hook(_observer_forward_hook)
+                else:
+                    _add_observer_(child, op_list, op_name)
+
+        def _propagate_qconfig_helper(module,
+                                      qconfig_dict,
+                                      white_list=None,
+                                      qconfig_parent=None,
+                                      prefix='',
+                                      fused=False):
+            """This is a helper function for `propagate_qconfig_`.
+
+            Args:
+                module (object): input module
+                qconfig_dict (dictionary): dictionary that maps from name of submodule to
+                                           quantization configuration
+                white_list (list, optional): list of quantizable modules
+                qconfig_parent (object, optional): config of the parent module; we will fall back to
+                                                   this config when there is no specified config
+                                                   for the current module
+                prefix (string, optional): corresponding prefix of the current module,
+                                           used as key in qconfig_dict
+                fused (bool, optional): Indicates whether the module is fused or not
+
+            Return:
+                None, module is modified inplace with qconfig attached
+            """
+            module.qconfig = qconfig_parent
+            if hasattr(module, '_modules'):
+                for name, child in module.named_children():
+                    module_prefix = prefix + '.' + name if prefix else name
+                    _propagate_qconfig_helper(child, qconfig_dict, white_list, qconfig_parent,
+                                              module_prefix)
+
+        def _prepare(model, inplace=True, op_list=[], white_list=None):
+            """The model will be attached with observer or fake quant modules, and qconfig
+               will be propagated.
+
+            Args:
+                model (object): input model to be modified in-place
+                inplace (bool, optional): carry out model transformations in-place,
+                                          the original module is mutated
+                op_list (list, optional): list of ops in the module to be dumped
+                white_list (list, optional): list of quantizable modules
+
+            Returns:
+                model (object): model with qconfig
+            """
+            if not inplace:
+                model = copy.deepcopy(model)
+            _propagate_qconfig_helper(model,
+                                      qconfig_dict={},
+                                      white_list=white_list,
+                                      qconfig_parent=model.qconfig)
+            # sanity check common API misusage
+            if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):  # pragma: no cover
+                logger.warn("None of the submodule got qconfig applied. Make sure you "
+                            "passed correct configuration through `qconfig_dict` or "
+                            "by assigning the `.qconfig` attribute directly on submodules")
+            _add_observer_(model, op_list=op_list)
+            return model
+
+        # create properties
+        if self.version.release < Version("1.7.0").release:  # pragma: no cover
+            white_list = self.white_list | \
+                (set(torch.quantization.default_mappings.DEFAULT_MODULE_MAPPING.values()) |
+                 set(torch.quantization.default_mappings.DEFAULT_QAT_MODULE_MAPPING.values()) |
+                 set(torch.quantization.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING.values()))
+        elif self.version.release < Version("1.8.0").release:  # pragma: no cover
+            white_list = torch.quantization.get_compare_output_module_list()
+        else:
+            white_list = torch.quantization.get_default_compare_output_module_list()
+
+        model = model if model.is_quantized else copy.deepcopy(model)
+        model._model.qconfig = torch.quantization.QConfig(
+            weight=torch.quantization.default_debug_observer,
+            activation=_RecordingObserver.with_args(iteration_list=iteration_list))
+        _prepare(model._model, op_list=op_list, white_list=white_list)
+
+        return model
+
     def is_fused_module(self, module):
         """This is a helper function for `_propagate_qconfig_helper` to detecte
            if this module is fused.
@@ -1495,7 +1895,6 @@ def __init__(self, framework_specific_info):
 
         # for tensorboard
         self.dump_times = 0
-        self.fused_dict = {}
 
         self.optype_statistics = None
 
@@ -1604,6 +2003,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
             q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
 
+        self.fused_dict = self.get_fused_list(q_model.model)
         q_model.q_config = copy.deepcopy(self.tune_cfg)
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(q_model._model, q_model.q_config)
@@ -1852,7 +2252,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
         Returns:
             None
         """
-
         module_dict = dict(model.named_modules())
         for op_name, child in model.named_modules():
             if self.is_fused_module(child):
@@ -1860,10 +2259,6 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                     module_prefix = op_name + '.' + name
                     if module_prefix in module_dict:
                         module_dict.pop(module_prefix)  # remove sub-modules of fused modules
-                        if op_name in self.fused_dict:
-                            self.fused_dict[op_name] = [self.fused_dict[op_name], module_prefix]
-                        else:
-                            self.fused_dict[op_name] = module_prefix
         for op_name, child in module_dict.items():
             # there is accuracy issue in quantized LayerNorm op in pytorch <1.8.1,
             # so remove it here
@@ -2252,171 +2647,6 @@ def _post_eval_hook(self, model, **args):
     def save(self, model, path=None):
         pass
 
-    def inspect_tensor(self,
-                       model,
-                       dataloader,
-                       op_list=None,
-                       iteration_list=None,
-                       inspect_type='activation',
-                       save_to_disk=False):
-        if self.version.release >= Version("1.8.0").release:
-            from torch.fx import GraphModule
-            if type(model._model) == GraphModule:  # pragma: no cover
-                assert False, "Inspect_tensor didn't support fx graph model now!"
-        from torch import dequantize
-        import numpy as np
-        is_quantized = model.is_quantized
-        op_list_ = []
-        fp32_int8_map = {}
-        for op_name in op_list:
-            op_list_.append(op_name)
-            for key in self.fused_dict:
-                if op_name in self.fused_dict[key]:
-                    fp32_int8_map[op_name] = \
-                        {'activation': self.fused_dict[key][-1], 'weight': key}
-                    if is_quantized:
-                        op_list_.append(key)
-                        op_list_.remove(op_name)
-                    else:
-                        op_list_.append(self.fused_dict[key][-1])
-
-        new_model = model if is_quantized else copy.deepcopy(model)
-
-        assert min(iteration_list) > 0, \
-            "Iteration number should great zero, 1 means first iteration."
-        iterations = max(iteration_list) if iteration_list is not None else -1
-        new_model = self._pre_eval_hook(new_model, op_list=op_list_, iteration_list=iteration_list)
-        self.evaluate(new_model, dataloader, iteration=iterations)
-        observer_dict = {}
-        ret = {}
-        if inspect_type == 'activation' or inspect_type == 'all':
-            if self.version.release >= Version("2.0.0").release:
-                from torch.quantization.quantize import _get_observer_dict as get_observer_dict
-            else:
-                from torch.quantization import get_observer_dict
-            ret['activation'] = []
-            get_observer_dict(new_model._model, observer_dict)
-            if iteration_list is None:
-                iteration_list = [1]
-            for i in iteration_list:
-                summary = OrderedDict()
-                for key in observer_dict:
-                    if isinstance(observer_dict[key], torch.nn.modules.linear.Identity):
-                        continue
-                    op_name = key.replace(".activation_post_process", "")
-                    value = observer_dict[key].get_tensor_value()[i]
-                    if op_name in op_list:
-                        if type(value) is list:
-                            summary[op_name] = {}
-                            for index in range(len(value)):
-                                summary[op_name].update({
-                                    op_name + ".output" + str(index):
-                                        dequantize(value[index]).numpy()
-                                        if value[index].is_quantized else value[index].numpy()
-                                })
-                        else:
-                            summary[op_name] = {
-                                op_name + ".output0":
-                                    dequantize(value).numpy() if value.is_quantized else value.numpy()
-                            }
-                    else:
-                        if bool(self.fused_dict):
-                            if is_quantized:
-                                for a in fp32_int8_map:
-                                    if op_name == fp32_int8_map[a]['weight']:
-                                        if type(value) is list:
-                                            summary[a] = {}
-                                            for index in range(len(value)):
-                                                summary[a].update({
-                                                    op_name + ".output" + str(index):
-                                                        dequantize(value[index]).numpy()
-                                                        if value[index].is_quantized else
-                                                        value[index].numpy()
-                                                })
-                                        else:
-                                            summary[a] = {
-                                                op_name + ".output0":
-                                                    dequantize(value).numpy()
-                                                    if value.is_quantized else value.numpy()
-                                            }
-                            else:
-                                for a in fp32_int8_map:  # pragma: no cover
-                                    if op_name == fp32_int8_map[a]['activation']:
-                                        if type(value) is list:
-                                            summary[a] = {}
-                                            for index in range(len(value)):
-                                                summary[a].update({
-                                                    op_name + ".output" + str(index):
-                                                        dequantize(value[index]).numpy()
-                                                        if value[index].is_quantized else
-                                                        value[index].numpy()
-                                                })
-                                        else:
-                                            summary[a] = {
-                                                op_name + ".output0":
-                                                    dequantize(value).numpy()
-                                                    if value.is_quantized else value.numpy()
-                                            }
-
-                if save_to_disk:
-                    dump_dir = os.path.join(self.workspace_path, 'dump_tensor')
-                    os.makedirs(dump_dir, exist_ok=True)
-                    np.savez(os.path.join(dump_dir, 'activation_iter{}.npz'.format(i)), **summary)
-
-                ret['activation'].append(summary)
-
-        if inspect_type == 'weight' or inspect_type == 'all':
-            ret['weight'] = {}
-            state_dict = new_model._model.state_dict()
-
-            for key in state_dict:
-                if not isinstance(state_dict[key], torch.Tensor):
-                    continue
-                if 'weight' not in key and 'bias' not in key:
-                    continue
-
-                op = key[:key.rfind('.')]
-                op = op.replace('._packed_params', '')
-
-                if op in op_list:
-                    if op in ret['weight']:
-                        ret['weight'][op].update({
-                            key:
-                                dequantize(state_dict[key]).numpy()
-                                if state_dict[key].is_quantized else state_dict[key].detach().numpy()
-                        })
-                    else:
-                        ret['weight'][op] = {
-                            key:
-                                dequantize(state_dict[key]).numpy()
-                                if state_dict[key].is_quantized else state_dict[key].detach().numpy()
-                        }
-                else:
-                    if bool(self.fused_dict):
-                        if is_quantized:
-                            for a in fp32_int8_map:
-                                if op == fp32_int8_map[a]['weight']:
-                                    if a in ret['weight']:
-                                        ret['weight'][a].update({
-                                            key:
-                                                dequantize(state_dict[key]).numpy()
-                                                if state_dict[key].is_quantized else
-                                                state_dict[key].detach().numpy()
-                                        })
-                                    else:
-                                        ret['weight'][a] = \
-                                            {key: dequantize(state_dict[key]).numpy()
-                                                if state_dict[key].is_quantized else
-                                                state_dict[key].detach().numpy()}
-                                    break
-
-            if save_to_disk:
-                np.savez(os.path.join(dump_dir, 'weight.npz'), **ret['weight'])
-        else:
-            ret['weight'] = None
-
-        return ret
-
     def set_tensor(self, model, tensor_dict):
         state_dict = model._model.state_dict()
         tensor_name = None
@@ -3446,6 +3676,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
             q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
 
+        self.fused_dict = self.get_fused_list(q_model.model)
+        q_model.is_quantized = True
         q_model.q_config = copy.deepcopy(self.tune_cfg)
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(q_model._model, q_model.q_config)
@@ -4583,7 +4815,6 @@ def _dump_model_op_stats(self, model, tune_cfg):
                                       field_names=field_names).print_stat()
         self.optype_statistics = field_names, output_data
 
-
     def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
         """This is a helper function for `query_fw_capability`,
         and it will get all quantizable ops from model.
diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
index 71e411d44cf..65dc0a58d0f 100644
--- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
+++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
@@ -2,6 +2,7 @@
 import neural_compressor.adaptor.pytorch as nc_torch
 import numpy as np
 import os
+import pickle
 import shutil
 import torch
 import torch.nn as nn
@@ -1114,5 +1115,36 @@ def test_symbolic_trace(self):
         traced_model_qat = symbolic_trace(model_origin, is_qat=True)
         self.assertTrue(isinstance(traced_model_qat.sub, torch.fx.graph_module.GraphModule))
 
+    def test_tensor_dump(self):
+        model = resnet18()
+        model = MODELS['pytorch'](model)
+        quantizer = Quantization('fx_ptq_yaml.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 224, 224), label=True)
+        dataloader = common.DataLoader(dataset)
+        dataloader = common._generate_common_dataloader(dataloader, 'pytorch')
+        quantizer.eval_dataloader = dataloader
+        quantizer.calib_dataloader = dataloader
+        quantizer.model = model.model
+        q_model = quantizer.fit()
+        quantizer.strategy.adaptor.inspect_tensor(
+            model, dataloader, op_list=['conv1', 'layer1.0.conv1'],
+            iteration_list=[1, 2], inspect_type='all', save_to_disk=True)
+        with open('saved/inspect_result.pkl', 'rb') as fp:
+            tensor_dict = pickle.load(fp)
+        a = tensor_dict["activation"][0]
+        w = tensor_dict["weight"]
+        self.assertTrue(w['conv1']['conv1.weight'].shape[0] ==
+                        a['conv1']['conv1.output0'].shape[1])
+        quantizer.strategy.adaptor.inspect_tensor(
+            q_model, dataloader, op_list=['conv1', 'layer1.0.conv1.0'],
+            iteration_list=[1, 2], inspect_type='all', save_to_disk=True)
+        with open('saved/inspect_result.pkl', 'rb') as fp:
+            tensor_dict = pickle.load(fp)
+        a = tensor_dict["activation"][0]
+        w = tensor_dict["weight"]
+        self.assertTrue(w['conv1']['conv1.weight'].shape[0] ==
+                        a['conv1']['conv1.output0'].shape[1])
+
+
 if __name__ == "__main__":
     unittest.main()
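
For reference, a minimal sketch (not part of the patch) of how the payload dumped by the new inspect_tensor can be read back. The 'saved/' path and the 'conv1' op name are taken from the test above and stand in for whatever workspace path and op_list are actually used; treat them as illustrative assumptions rather than fixed keys.

import pickle

# Load the pickle written via dump_data_to_local(ret, save_path, 'inspect_result.pkl').
with open('saved/inspect_result.pkl', 'rb') as fp:
    result = pickle.load(fp)

# result['activation'] holds one OrderedDict per inspected iteration; each maps an
# op name to its dumped outputs ('<op>.output0', '<op>.output1', ...) as numpy
# arrays, dequantized when the recorded tensor was quantized.
first_iteration = result['activation'][0]
print(first_iteration['conv1']['conv1.output0'].shape)

# result['weight'] maps an op name to its weight/bias entries from state_dict,
# likewise converted to numpy (or it is None when inspect_type excludes weights).
print(result['weight']['conv1']['conv1.weight'].shape)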