From 883486eca88ae2eee0aaeacf7c1d23592c739981 Mon Sep 17 00:00:00 2001
From: "Cheng, Penghui"
Date: Tue, 1 Aug 2023 19:04:14 +0800
Subject: [PATCH] Update code

Signed-off-by: Cheng, Penghui
---
 neural_compressor/adaptor/pytorch.py          | 340 +++---
 .../test_adaptor_pytorch_1.x.py               |  17 +-
 2 files changed, 64 insertions(+), 293 deletions(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 1650bb4c97c..f3d487eca32 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1248,15 +1248,28 @@ def get_fused_list(self, model):
         for op_name, child in model.named_modules():
             if type(child) in self.fused_op_list:
                 in_fused_loop = False
-                fp32_int8_ops = [op_name,]
+                is_fused_module = False
                 type_name = str(child).split("(")[0]
                 prefix_index = op_name.rfind(".")
+                fp32_int8_ops = []
                 for fp32_op_name, module in self.pre_optimized_model.model.named_modules():
                     fp32_type_name = str(module).split("(")[0]
                     prefix_fp32_index = fp32_op_name.rfind(".")
-                    if op_name == fp32_op_name:
+                    if not is_fused_module:
+                        is_fused_module = self.is_fused_module(module)
+                        if is_fused_module:
+                            in_fused_loop = True
+                            continue
+                    if is_fused_module and in_fused_loop:
+                        if op_name == fp32_op_name[: fp32_op_name.rfind(".")]:
+                            fp32_int8_ops.append(fp32_op_name)
+                            continue
+                        else:
+                            is_fused_module = False
+                            in_fused_loop = False
+                    elif op_name == fp32_op_name and not in_fused_loop:
                         in_fused_loop = True
-                        continue
+                        fp32_int8_ops.append(fp32_op_name)
                     elif in_fused_loop and \
                         op_name[: prefix_index if prefix_index > -1 else 0] == \
                         fp32_op_name[: prefix_fp32_index if prefix_fp32_index > -1 else 0]:
@@ -1273,7 +1286,8 @@ def get_fused_list(self, model):
                     elif in_fused_loop:
                         in_fused_loop = False
                         break
-                fused_dict.update({op_name: fp32_int8_ops})
+                if len(fp32_int8_ops) > 1:
+                    fused_dict.update({op_name: fp32_int8_ops})
         return fused_dict
 
     def inspect_tensor(self,
@@ -1300,8 +1314,8 @@ def inspect_tensor(self,
                             {'activation': self.fused_dict[key][-1], 'weight': self.fused_dict[key][0]}
                         if not is_quantized:
                             op_list_.append(self.fused_dict[key][-1])
-                        elif self.fused_dict[key][0] not in op_list_:
-                            op_list_.append(self.fused_dict[key][0])
+                        elif key not in op_list_:
+                            op_list_.append(key)
                         break
 
         assert min(iteration_list) > 0, \
@@ -1342,37 +1356,39 @@ def inspect_tensor(self,
                 if bool(self.fused_dict):
                     if is_quantized:
                         for a in fp32_int8_map:
-                            if op_name == fp32_int8_map[a]['weight']:
+                            if op_name == a:
+                                tensor_name = fp32_int8_map[a]['weight']
                                 if type(value) is list:
-                                    summary[a] = {}
+                                    summary[tensor_name] = {}
                                     for index in range(len(value)):
-                                        summary[a].update({
-                                            a + ".output" + str(index):
+                                        summary[tensor_name].update({
+                                            tensor_name + ".output" + str(index):
                                             dequantize(value[index]).numpy()
                                             if value[index].is_quantized else
                                             value[index].numpy()
                                         })
                                 else:
-                                    summary[a] = {
-                                        a + ".output0":
+                                    summary[tensor_name] = {
+                                        tensor_name + ".output0":
                                         dequantize(value).numpy()
                                         if value.is_quantized else value.numpy()
                                     }
                     else:
                         for a in fp32_int8_map:  # pragma: no cover
                             if op_name == fp32_int8_map[a]['activation']:
+                                tensor_name = fp32_int8_map[a]['weight']
                                 if type(value) is list:
-                                    summary[a] = {}
+                                    summary[tensor_name] = {}
                                     for index in range(len(value)):
-                                        summary[a].update({
-                                            a + ".output" + str(index):
+                                        summary[tensor_name].update({
+                                            tensor_name + ".output" + str(index):
                                             dequantize(value[index]).numpy()
                                             if value[index].is_quantized else
                                             value[index].numpy()
                                         })
                                 else:
-                                    summary[a] = {
-                                        a + ".output0":
+                                    summary[tensor_name] = {
+                                        tensor_name + ".output0":
                                         dequantize(value).numpy()
                                         if value.is_quantized else value.numpy()
                                     }
@@ -1409,16 +1425,17 @@ def inspect_tensor(self,
                 if bool(self.fused_dict):
                     if is_quantized:
                         for a in fp32_int8_map:
-                            if op == fp32_int8_map[a]['weight']:
-                                if a in ret['weight']:
-                                    ret['weight'][a].update({
+                            if op == a:
+                                tensor_name = fp32_int8_map[a]['weight']
+                                if tensor_name in ret['weight']:
+                                    ret['weight'][tensor_name].update({
                                         key:
                                         dequantize(state_dict[key]).numpy()
                                         if state_dict[key].is_quantized else
                                         state_dict[key].detach().numpy()
                                     })
                                 else:
-                                    ret['weight'][a] = \
+                                    ret['weight'][tensor_name] = \
                                         {key: dequantize(state_dict[key]).numpy()
                                          if state_dict[key].is_quantized else
                                          state_dict[key].detach().numpy()}
@@ -1610,22 +1627,11 @@ def _prepare(model, inplace=True, op_list=[], white_list=None):
             _add_observer_(model, op_list=op_list)
             return model
 
-        # create properties
-        if self.version.release < Version("1.7.0").release:  # pragma: no cover
-            white_list = self.white_list | \
-                (set(torch.quantization.default_mappings.DEFAULT_MODULE_MAPPING.values()) |
-                 set(torch.quantization.default_mappings.DEFAULT_QAT_MODULE_MAPPING.values()) |
-                 set(torch.quantization.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING.values()))
-        elif self.version.release < Version("1.8.0").release:  # pragma: no cover
-            white_list = torch.quantization.get_compare_output_module_list()
-        else:
-            white_list = torch.quantization.get_default_compare_output_module_list()
-
         model = model if model.is_quantized else copy.deepcopy(model)
         model._model.qconfig = torch.quantization.QConfig(
             weight=torch.quantization.default_debug_observer,
             activation=_RecordingObserver.with_args(iteration_list=iteration_list))
-        _prepare(model._model, op_list=op_list, white_list=white_list)
+        _prepare(model._model, op_list=op_list)
 
         return model
 
@@ -2292,211 +2298,6 @@ def _get_scale_zeropoint(self, model, tune_cfg):
                 if hasattr(modules[key[0]], 'zero_point'):
                     value['activation']['zero_point'] = int(modules[key[0]].zero_point)
 
-    def _pre_eval_hook(self, model, op_list=None, iteration_list=None):
-        """The function is used to do some preprocession before evaluation phase.
-           Here, it used to add hook for dump output tensor for quantizable ops.
-
-        Args:
-             model (object): input model
-
-        Returns:
-              model (object): model with hook
-        """
-        from abc import ABCMeta
-
-        def _with_args(cls_or_self, **kwargs):
-            r"""Wrapper that allows creation of class factories.
-
-            This can be useful when there is a need to create classes with the same
-            constructor arguments, but different instances.
-
-            Example::
-
-                >>> Foo.with_args = classmethod(_with_args)
-                >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
-                >>> foo_instance1 = foo_builder()
-                >>> foo_instance2 = foo_builder()
-                >>> id(foo_instance1) == id(foo_instance2)
-                False
-            """
-            class _PartialWrapper(object):
-                def __init__(self, p):
-                    self.p = p
-
-                def __call__(self, *args, **keywords):
-                    return self.p(*args, **keywords)
-
-                def __repr__(self):
-                    return self.p.__repr__()
-
-                with_args = _with_args
-
-            r = _PartialWrapper(partial(cls_or_self, **kwargs))
-            return r
-
-        ABC = ABCMeta(str("ABC"), (object, ), {})  # compatible with Python 2 *and* 3:
-
-        class _RecordingObserver(ABC, torch.nn.Module):
-            """The module is mainly for debug and records the tensor values during runtime.
-
-            Args:
-                iteration_list (list, optional): indexs of iteration which to dump tensor.
-            """
-            def __init__(self, iteration_list=None, **kwargs):
-                super(_RecordingObserver, self).__init__(**kwargs)
-                self.output_tensors_dict = OrderedDict()
-                self.current_iter = 1
-                self.iteration_list = iteration_list
-
-            def forward(self, x):
-                if (self.iteration_list is None and self.current_iter == 1) or \
-                    (self.iteration_list is not None and
-                     self.current_iter in self.iteration_list):
-                    if type(x) is tuple or type(x) is list:
-                        self.output_tensors_dict[self.current_iter] = \
-                            [i.to("cpu") if i.device != 'cpu' else i.clone() for i in x]
-                    else:
-                        self.output_tensors_dict[self.current_iter] = \
-                            x.to("cpu") if x.device != "cpu" else x.clone()
-                self.current_iter += 1
-                return x
-
-            @torch.jit.export
-            def get_tensor_value(self):
-                return self.output_tensors_dict
-
-            with_args = classmethod(_with_args)
-
-        def _observer_forward_hook(module, input, output):
-            """Forward hook that calls observer on the output
-
-            Args:
-                module (object): input module
-                input (object): module input
-                output (object): module output
-
-            Returns:
-                module output tensor (object)
-            """
-            return module.activation_post_process(output)
-
-        def _add_observer_(module, op_list=None, prefix=""):
-            """Add observer for the leaf child of the module.
-
-            This function insert observer module to all leaf child module that
-            has a valid qconfig attribute.
-
-            Args:
-                module (object): input module with qconfig attributes for all the leaf modules that
-                                 we want to dump tensor
-                op_list (list, optional): list of ops which to be dumped in module
-                prefix (string): name of module
-
-            Returns:
-                None, module is modified inplace with added observer modules and forward_hooks
-            """
-            for name, child in module.named_children():
-                op_name = name if prefix == "" else prefix + "." + name
-                if isinstance(child, torch.nn.quantized.FloatFunctional) and \
-                        (op_list is None or op_name in op_list):
-                    if hasattr(child, 'qconfig') and child.qconfig is not None and (
-                            op_list is None or op_name in op_list):
-                        child.activation_post_process = \
-                            child.qconfig.activation()
-                elif hasattr(child, 'qconfig') and child.qconfig is not None and \
-                        (op_list is None or op_name in op_list):
-                    # observer and hook will be gone after we swap the module
-                    child.add_module('activation_post_process', child.qconfig.activation())
-                    child.register_forward_hook(_observer_forward_hook)
-                else:
-                    _add_observer_(child, op_list, op_name)
-
-        def _propagate_qconfig_helper(module,
-                                      qconfig_dict,
-                                      white_list=None,
-                                      qconfig_parent=None,
-                                      prefix='',
-                                      fused=False):
-            """This is a helper function for `propagate_qconfig_`
-
-            Args:
-                module (object): input module
-                qconfig_dict (dictionary): dictionary that maps from name of submodule to
-                                           quantization configuration
-                white_list (list, optional): list of quantizable modules
-                qconfig_parent (object, optional): config of parent module, we will fallback to
-                                                   this config when there is no specified config
-                                                   for current module
-                prefix (string, optional): corresponding prefix of the current module,
-                                           used as key in qconfig_dict
-                fused (bool, optional): Indicates whether the module is fused or not
-
-            Return:
-                None, module is modified inplace with qconfig attached
-            """
-            if white_list is None:
-                white_list = \
-                    torch.quantization.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST \
-                    if self.version.release < Version("1.7.0").release else \
-                    torch.quantization.quantization_mappings.get_qconfig_propagation_list()
-
-            if type(module) in white_list and type(module) != torch.nn.Sequential:
-                module.qconfig = qconfig_parent
-            else:
-                module.qconfig = None
-            if hasattr(module, '_modules'):
-                for name, child in module.named_children():
-                    module_prefix = prefix + '.' + name if prefix else name
-                    _propagate_qconfig_helper(child, qconfig_dict, white_list, qconfig_parent,
-                                              module_prefix)
-
-        def _prepare(model, inplace=True, op_list=[], white_list=None):
-            """The model will be attached with observer or fake quant modules, and qconfig
-               will be propagated.
-
-            Args:
-                model (object): input model to be modified in-place
-                inplace (bool, optional): carry out model transformations in-place,
-                                          the original module is mutated
-                op_list (list, optional): list of ops which to be dumped in module
-                white_list (list, optional): list of quantizable modules
-
-            Returns:
-                model (object): model with qconfig
-            """
-            if not inplace:
-                model = copy.deepcopy(model)
-            _propagate_qconfig_helper(model,
-                                      qconfig_dict={},
-                                      white_list=white_list,
-                                      qconfig_parent=model.qconfig)
-            # sanity check common API misusage
-            if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):  # pragma: no cover
-                logger.warn("None of the submodule got qconfig applied. Make sure you "
-                            "passed correct configuration through `qconfig_dict` or "
-                            "by assigning the `.qconfig` attribute directly on submodules")
-            _add_observer_(model, op_list=op_list)
-            return model
-
-        # create properties
-        if self.version.release < Version("1.7.0").release:  # pragma: no cover
-            white_list = self.white_list | \
-                (set(torch.quantization.default_mappings.DEFAULT_MODULE_MAPPING.values()) |
-                 set(torch.quantization.default_mappings.DEFAULT_QAT_MODULE_MAPPING.values()) |
-                 set(torch.quantization.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING.values()))
-        elif self.version.release < Version("1.8.0").release:  # pragma: no cover
-            white_list = torch.quantization.get_compare_output_module_list()
-        else:
-            white_list = torch.quantization.get_default_compare_output_module_list()
-
-        model = model if model.is_quantized else copy.deepcopy(model)
-        model._model.qconfig = torch.quantization.QConfig(
-            weight=torch.quantization.default_debug_observer,
-            activation=_RecordingObserver.with_args(iteration_list=iteration_list))
-        _prepare(model._model, op_list=op_list, white_list=white_list)
-
-        return model
-
     def is_fused_child(self, op_name):
         """This is a helper function for `_post_eval_hook`
 
@@ -2507,43 +2308,11 @@ def is_fused_child(self, op_name):
             (bool): if this op is fused
 
         """
-        op = op_name[:op_name.rfind('.')]
-        if op in self.fused_dict and op_name[op_name.rfind('.') + 1:].isdigit():
-            return True
-        else:
-            return False
-
-    def is_fused_op(self, op_name):
-        """This is a helper function for `_post_eval_hook`
-
-        Args:
-            op_name (string): op name
-
-        Returns:
-            (bool): if this op is fused
-
-        """
-        op = op_name[:op_name.rfind('.')]
-        if op in self.fused_dict:
-            return True
-        else:
-            return False
-
-    def is_last_fused_child(self, op_name):
-        """This is a helper function for `_post_eval_hook`
-
-        Args:
-            op_name (string): op name
-
-        Returns:
-            (bool): if this op is last fused op
+        for key in self.fused_dict:
+            if op_name in self.fused_dict[key]:
+                return True
+        return False
-
-        """
-        op = op_name[:op_name.rfind('.')]
-        if op_name in self.fused_dict[op][-1]:
-            return True
-        else:
-            return False
+
 
     def _post_eval_hook(self, model, **args):
         """The function is used to do some post process after complete evaluation.
@@ -2595,20 +2364,17 @@ def _post_eval_hook(self, model, **args):
             for key in observer_dict:
                 if isinstance(observer_dict[key], torch.nn.modules.linear.Identity):
                     continue
-                op_name = key.strip(".activation_post_process")
+                op_name = key.replace(".activation_post_process", "")
                 summary[op_name + ".output"] = observer_dict[key].get_tensor_value()
                 for iter in summary[op_name + ".output"]:
                     # Only collect last fused child output
                     op = op_name
-                    if self.is_fused_child(op_name) == True and \
-                       self.is_last_fused_child(op_name) == True:
-                        op = op_name[:op_name.rfind('.')]
+                    if op_name in self.fused_dict:
+                        op = self.fused_dict[op_name][0]
                     else:
-                        if self.is_fused_child(op_name) == True and \
-                           self.is_last_fused_child(op_name) == False:
-                            continue
-                        else:
-                            op = op_name
+                        for key in self.fused_dict:
+                            if op_name in self.fused_dict[key]:
+                                op = op_name
 
                     if summary[op_name + ".output"][iter].is_quantized:
                         writer.add_histogram(op + "/Output/int8",
@@ -2620,7 +2386,6 @@ def _post_eval_hook(self, model, **args):
         for key in state_dict:
             if not isinstance(state_dict[key], torch.Tensor):
                 continue
 
-            op = key[:key.rfind('.')]
             if self.is_fused_child(op) is True:
                 # fused child tensorboard tag will be merge
@@ -2657,7 +2422,12 @@ def set_tensor(self, model, tensor_dict):
             weight_bias = key[end + 1:]
             for op in self.fused_dict:
                 if op_name in self.fused_dict[op]:
-                    state_op_name = op
+                    if model.is_quantized:
+                        state_op_name = op
+                    else:
+                        state_op_name = self.fused_dict[op][0]
+                # elif op_name in self.fused_dict[op]:
+                #     state_op_name = op
             if state_op_name is None:
                 state_op_name = op_name
             for state_dict_key in state_dict.keys():
diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
index 65dc0a58d0f..266d0fe828f 100644
--- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
+++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
@@ -709,16 +709,17 @@ def test_tensor_dump_and_set(self):
         quantizer.strategy.adaptor.inspect_tensor(
             model, dataloader, op_list=['conv1.0', 'layer1.0.conv1.0'],
             iteration_list=[1, 2], inspect_type='all', save_to_disk=True)
-        load_array = lambda *a, **k: np.load(*a, allow_pickle=True, **k)
-        a = load_array('saved/dump_tensor/activation_iter1.npz')
-        w = load_array('saved/dump_tensor/weight.npz')
+        with open('saved/inspect_result.pkl', 'rb') as fp:
+            tensor_dict = pickle.load(fp)
+        a = tensor_dict["activation"][0]
+        w = tensor_dict["weight"]
         if PT_VERSION >= Version("1.8.0").release:
-            self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
-                            a['conv1.0'].item()['conv1.0.output0'].shape[1])
+            self.assertTrue(w['conv1.0']['conv1.0.weight'].shape[0] ==
+                            a['conv1.0']['conv1.0.output0'].shape[1])
         else:
-            self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
-                            a['conv1.0'].item()['conv1.1.output0'].shape[1])
-        data = np.random.random(w['conv1.0'].item()['conv1.0.weight'].shape).astype(np.float32)
+            self.assertTrue(w['conv1.0']['conv1.0.weight'].shape[0] ==
+                            a['conv1.0']['conv1.1.output0'].shape[1])
+        data = np.random.random(w['conv1.0']['conv1.0.weight'].shape).astype(np.float32)
         quantizer.strategy.adaptor.set_tensor(q_model, {'conv1.0.weight': data})
         changed_tensor = q_model.get_weight('conv1.weight')
        scales = changed_tensor.q_per_channel_scales()
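
For reference, the updated test reads the dump written by inspect_tensor(..., save_to_disk=True)
from 'saved/inspect_result.pkl' instead of the old per-iteration .npz files. Below is a minimal
sketch of consuming that file, assuming only the pickle layout exercised by the test above
("activation" is a list with one dict per dumped iteration, "weight" is a dict keyed by op name,
and each entry maps tensor names to numpy arrays):

    import pickle

    # Load the result produced by inspect_tensor(..., save_to_disk=True).
    with open('saved/inspect_result.pkl', 'rb') as fp:
        tensor_dict = pickle.load(fp)

    activations = tensor_dict["activation"][0]  # first dumped iteration
    weights = tensor_dict["weight"]

    # Per the test, entries look like weights['conv1.0']['conv1.0.weight'] and
    # activations['conv1.0']['conv1.0.output0'].
    for op_name, outputs in activations.items():
        for tensor_name, array in outputs.items():
            print(op_name, tensor_name, array.shape)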