From f4f31f2396a7e6b168ece6527187da2be8911442 Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Mon, 2 Aug 2021 08:35:52 +0000 Subject: [PATCH 1/6] support quantization of conv2d_transpose --- .../slim/quantization/imperative/qat.py | 62 ++++++++----- .../slim/quantization/imperative/utils.py | 18 +++- python/paddle/nn/quant/quant_layers.py | 87 +++++++++++++++++++ 3 files changed, 140 insertions(+), 27 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index b8c0e47e9bbc26..32a3ebfe047030 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -42,17 +42,18 @@ class ImperativeQuantAware(object): Applying quantization aware training (QAT) to the dgraph model. """ - def __init__(self, - quantizable_layer_type=['Conv2D', 'Linear'], - weight_quantize_type='abs_max', - activation_quantize_type='moving_average_abs_max', - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_preprocess_layer=None, - act_preprocess_layer=None, - weight_quantize_layer=None, - act_quantize_layer=None): + def __init__( + self, + quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'], + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantAware. @@ -232,17 +233,18 @@ class ImperativeQuantizeInputs(object): logic both for activation inputs and weight inputs. """ - def __init__(self, - quantizable_layer_type=['Conv2D', 'Linear'], - weight_quantize_type='abs_max', - activation_quantize_type='moving_average_abs_max', - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_preprocess_layer=None, - act_preprocess_layer=None, - weight_quantize_layer=None, - act_quantize_layer=None): + def __init__( + self, + quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'], + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantizeInputs. @@ -303,6 +305,18 @@ def __init__(self, } def apply(self, model): + """ + Quantize the weights and activations to calculate for specific + layers in the dygraph model. + + Args: + model(fluid.dygraph.Layer): The target model which would + calculate the input quantization scale. + + Returns: + None + """ + assert isinstance(model, dygraph.Layer), \ "The model must be the instance of dygraph.Layer." @@ -544,7 +558,9 @@ def _is_skip_quant_op(self, block, in_op): 1. the type of input op should be conv2d, depthwise_conv2d or matmul 2. 
the previous ops of the input op are not fake_quantize_dequantize ops """ - target_op_types = ["conv2d", "depthwise_conv2d", "matmul"] + target_op_types = [ + "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose" + ] if in_op.type not in target_op_types: return False diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index a9d52c5a87ad36..5a98ac80549f18 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -24,6 +24,7 @@ from ..quantization_pass import _get_input_name_index layer_name_map = { + 'Conv2DTranspose': paddle.nn.Conv2DTranspose, 'Conv2D': paddle.nn.Conv2D, 'Linear': paddle.nn.Linear, 'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D, @@ -47,7 +48,9 @@ # Apply fake quant for the inputs of these layers # TODO (jc): support paddle.nn.Conv2DTranspose -fake_quant_input_layers = [paddle.nn.Conv2D, paddle.nn.Linear] +fake_quant_input_layers = [ + paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose +] # Apply fake quant for the output of these layers # TODO(jc): fix the problem of adding duplicate fake_quant ops @@ -65,7 +68,8 @@ ] fake_quant_wrap_layers = [ - quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear + quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear, + quant_layers.QuantizedConv2DTranspose ] # The weight format of these layers is Cin * Cout * H * W @@ -84,9 +88,9 @@ def load_variable_data(scope, var_name): - ''' + """ Load variable value from scope - ''' + """ var_node = scope.find_var(var_name) assert var_node is not None, \ "Can not find " + var_name + " in the scope." @@ -120,6 +124,12 @@ def find_parent_layer_and_sub_name(model, name): the sub_name of the layer. For example, if name is 'block_1/convbn_1/conv_1', the parent layer is 'block_1/convbn_1' and the sub_name is `conv_1`. + Args: + model(fluid.dygraph.Layer): the model to be quantized. + name(string): the name of a layer + + Returns: + parent_layer, subname """ assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index 5573683ebd0458..cc98f73174b4b3 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -31,6 +31,7 @@ 'FakeQuantMovingAverageAbsMax', 'FakeQuantChannelWiseAbsMax', 'QuantizedConv2D', + 'QuantizedConv2DTranspose', 'QuantizedLinear', 'MovingAverageAbsMaxScale', 'MAOutputScaleLayer', @@ -481,6 +482,92 @@ def forward(self, input): data_format=self._data_format) +class QuantizedConv2DTranspose(layers.Layer): + """ + The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose. + The only difference is that its inputs are all fake quantized. 
+ """ + + def __init__(self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): + super(QuantizedConv2DTranspose, self).__init__() + # For Conv2DTranspose + self._groups = getattr(layer, '_groups') + self._stride = getattr(layer, '_stride') + self._padding = getattr(layer, '_padding') + self._output_padding = getattr(layer, 'output_padding') + self._dilation = getattr(layer, '_dilation') + self._data_format = getattr(layer, '_data_format') + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + # For FakeQuant + self._conv2d_transpose_quant_axis = 1 + if weight_quant_layer is not None: + self._fake_quant_weight = weight_quant_layer() + else: + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[ + self._conv2d_transpose_quant_axis], + quant_axis=self._conv2d_transpose_quant_axis) + if act_quant_layer is not None: + self._fake_quant_input = act_quant_layer() + else: + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False) + + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None + + def forward(self, input, output_size=None): + if self._act_preprocess is not None: + input = self._act_preprocess(input) + quant_input = self._fake_quant_input(input) + + weight = self.weight + if self._weight_preprocess is not None: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) + + if output_size is None: + output_padding = self._output_padding + else: + output_padding = 0 + + return F.conv2d_transpose( + quant_input, + quant_weight, + bias=self.bias, + padding=self._padding, + output_padding=output_padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + output_size=output_size, + data_format=self._data_format) + + class QuantizedLinear(layers.Layer): """ The computational logic of QuantizedLinear is the same with Linear. 
From ac21a6041534a459c3dce4a0a75f375cb50e6e17 Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Thu, 5 Aug 2021 11:17:43 +0000 Subject: [PATCH 2/6] fix quantization bugs --- .../contrib/slim/quantization/post_training_quantization.py | 2 ++ .../fluid/contrib/slim/quantization/quantization_pass.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 5996e752c8c22d..5272d9f59903d7 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -578,6 +578,8 @@ def _sample_mse(self): var_tensor = _load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) + if abs_max_value == 0.0: + abs_max_value = 1e-8 s = 0.3 if var_name not in self._best_mse_loss: self._best_mse_loss[var_name] = float('inf') diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index b3b12a477e2a0a..857486b3fc46cc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1312,6 +1312,8 @@ def _insert_post_dequant_op(self, graph, op_node): assert self._is_float( scale_v), 'The scale of parameter %s is not a float.' % ( original_var_name) + if scale_v == 0.0: + scale_v = 1e-8 max_range *= param_range / scale_v else: max_range *= act_range @@ -1413,6 +1415,8 @@ def _clip(x, scale): x[:, i] = _clip(x[:, i], s) x[:, i] = np.round(x[:, i] / s * bnt) else: + if scale == 0.0: + scale = 1e-8 x = _clip(x, scale) x = np.round(x / scale * bnt) return x From b8dd69dceea451747a465c05930723d38880de8a Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Thu, 5 Aug 2021 11:40:33 +0000 Subject: [PATCH 3/6] fix a quantization bug --- python/paddle/nn/quant/quant_layers.py | 87 -------------------------- 1 file changed, 87 deletions(-) diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index cc98f73174b4b3..5573683ebd0458 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -31,7 +31,6 @@ 'FakeQuantMovingAverageAbsMax', 'FakeQuantChannelWiseAbsMax', 'QuantizedConv2D', - 'QuantizedConv2DTranspose', 'QuantizedLinear', 'MovingAverageAbsMaxScale', 'MAOutputScaleLayer', @@ -482,92 +481,6 @@ def forward(self, input): data_format=self._data_format) -class QuantizedConv2DTranspose(layers.Layer): - """ - The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose. - The only difference is that its inputs are all fake quantized. 
- """ - - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): - super(QuantizedConv2DTranspose, self).__init__() - # For Conv2DTranspose - self._groups = getattr(layer, '_groups') - self._stride = getattr(layer, '_stride') - self._padding = getattr(layer, '_padding') - self._output_padding = getattr(layer, 'output_padding') - self._dilation = getattr(layer, '_dilation') - self._data_format = getattr(layer, '_data_format') - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') - # For FakeQuant - self._conv2d_transpose_quant_axis = 1 - if weight_quant_layer is not None: - self._fake_quant_weight = weight_quant_layer() - else: - self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, - name=self.weight.name, - moving_rate=moving_rate, - quant_bits=weight_bits, - dtype=self._dtype, - quant_on_weight=True, - channel_num=self.weight.shape[ - self._conv2d_transpose_quant_axis], - quant_axis=self._conv2d_transpose_quant_axis) - if act_quant_layer is not None: - self._fake_quant_input = act_quant_layer() - else: - self._fake_quant_input = _get_fake_quant_type( - activation_quantize_type, - name=layer.full_name(), - moving_rate=moving_rate, - quant_bits=activation_bits, - dtype=self._dtype, - quant_on_weight=False) - - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None - - def forward(self, input, output_size=None): - if self._act_preprocess is not None: - input = self._act_preprocess(input) - quant_input = self._fake_quant_input(input) - - weight = self.weight - if self._weight_preprocess is not None: - weight = self._weight_preprocess(self.weight) - quant_weight = self._fake_quant_weight(weight) - - if output_size is None: - output_padding = self._output_padding - else: - output_padding = 0 - - return F.conv2d_transpose( - quant_input, - quant_weight, - bias=self.bias, - padding=self._padding, - output_padding=output_padding, - stride=self._stride, - dilation=self._dilation, - groups=self._groups, - output_size=output_size, - data_format=self._data_format) - - class QuantizedLinear(layers.Layer): """ The computational logic of QuantizedLinear is the same with Linear. From 41754480ed362f67b7df2865d8f50953a46ae83a Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Thu, 5 Aug 2021 11:50:30 +0000 Subject: [PATCH 4/6] fix quantization bugs --- .../slim/quantization/imperative/qat.py | 62 +++++++------------ .../slim/quantization/imperative/utils.py | 18 ++---- 2 files changed, 27 insertions(+), 53 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 32a3ebfe047030..b8c0e47e9bbc26 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -42,18 +42,17 @@ class ImperativeQuantAware(object): Applying quantization aware training (QAT) to the dgraph model. 
""" - def __init__( - self, - quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'], - weight_quantize_type='abs_max', - activation_quantize_type='moving_average_abs_max', - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_preprocess_layer=None, - act_preprocess_layer=None, - weight_quantize_layer=None, - act_quantize_layer=None): + def __init__(self, + quantizable_layer_type=['Conv2D', 'Linear'], + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantAware. @@ -233,18 +232,17 @@ class ImperativeQuantizeInputs(object): logic both for activation inputs and weight inputs. """ - def __init__( - self, - quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'], - weight_quantize_type='abs_max', - activation_quantize_type='moving_average_abs_max', - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_preprocess_layer=None, - act_preprocess_layer=None, - weight_quantize_layer=None, - act_quantize_layer=None): + def __init__(self, + quantizable_layer_type=['Conv2D', 'Linear'], + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantizeInputs. @@ -305,18 +303,6 @@ def __init__( } def apply(self, model): - """ - Quantize the weights and activations to calculate for specific - layers in the dygraph model. - - Args: - model(fluid.dygraph.Layer): The target model which would - calculate the input quantization scale. - - Returns: - None - """ - assert isinstance(model, dygraph.Layer), \ "The model must be the instance of dygraph.Layer." @@ -558,9 +544,7 @@ def _is_skip_quant_op(self, block, in_op): 1. the type of input op should be conv2d, depthwise_conv2d or matmul 2. 
the previous ops of the input op are not fake_quantize_dequantize ops """ - target_op_types = [ - "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose" - ] + target_op_types = ["conv2d", "depthwise_conv2d", "matmul"] if in_op.type not in target_op_types: return False diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 5a98ac80549f18..a9d52c5a87ad36 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -24,7 +24,6 @@ from ..quantization_pass import _get_input_name_index layer_name_map = { - 'Conv2DTranspose': paddle.nn.Conv2DTranspose, 'Conv2D': paddle.nn.Conv2D, 'Linear': paddle.nn.Linear, 'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D, @@ -48,9 +47,7 @@ # Apply fake quant for the inputs of these layers # TODO (jc): support paddle.nn.Conv2DTranspose -fake_quant_input_layers = [ - paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose -] +fake_quant_input_layers = [paddle.nn.Conv2D, paddle.nn.Linear] # Apply fake quant for the output of these layers # TODO(jc): fix the problem of adding duplicate fake_quant ops @@ -68,8 +65,7 @@ ] fake_quant_wrap_layers = [ - quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear, - quant_layers.QuantizedConv2DTranspose + quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear ] # The weight format of these layers is Cin * Cout * H * W @@ -88,9 +84,9 @@ def load_variable_data(scope, var_name): - """ + ''' Load variable value from scope - """ + ''' var_node = scope.find_var(var_name) assert var_node is not None, \ "Can not find " + var_name + " in the scope." @@ -124,12 +120,6 @@ def find_parent_layer_and_sub_name(model, name): the sub_name of the layer. For example, if name is 'block_1/convbn_1/conv_1', the parent layer is 'block_1/convbn_1' and the sub_name is `conv_1`. - Args: - model(fluid.dygraph.Layer): the model to be quantized. - name(string): the name of a layer - - Returns: - parent_layer, subname """ assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." 
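Patches 3 and 4 back out the dygraph-side Conv2DTranspose support from patch 1, leaving the numeric guards from patch 2 as the surviving change. Those guards all target one failure mode: an all-zero tensor yields `abs_max_value == 0.0` (or a stored scale of 0.0), and the subsequent `param_range / scale_v` and `x / scale` divisions then produce inf/NaN. Flooring the scale at 1e-8 keeps the clip-and-round path finite, as the minimal sketch below shows (the `fake_quant` helper is illustrative, not a Paddle API):

```python
import numpy as np

def fake_quant(x, scale, bits=8):
    # Clip-and-round fake quantization, mirroring the guarded branch
    # in quantization_pass.py's clip/round helper.
    bnt = (1 << (bits - 1)) - 1
    scale = 1e-8 if scale == 0.0 else scale  # the guard added in patch 2
    x = np.clip(x, -scale, scale)
    return np.round(x / scale * bnt)

# An all-zero tensor (e.g. a pruned or zero-initialized weight) has
# abs_max == 0; without the floor, 0/0 would propagate NaNs downstream.
zeros = np.zeros(4, dtype='float32')
print(fake_quant(zeros, scale=float(np.abs(zeros).max())))  # [0. 0. 0. 0.]
```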
From 12bca65fd71176703f259f0f03402261f55ce415 Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Sun, 8 Aug 2021 16:05:04 +0800 Subject: [PATCH 5/6] Update post_training_quantization.py --- .../contrib/slim/quantization/post_training_quantization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 5272d9f59903d7..06f3f5f3afa750 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -578,8 +578,7 @@ def _sample_mse(self): var_tensor = _load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) - if abs_max_value == 0.0: - abs_max_value = 1e-8 + abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value s = 0.3 if var_name not in self._best_mse_loss: self._best_mse_loss[var_name] = float('inf') From 2a486e6db68260c7058ee55fa25413388981368a Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Sun, 8 Aug 2021 16:07:18 +0800 Subject: [PATCH 6/6] Update quantization_pass.py --- .../fluid/contrib/slim/quantization/quantization_pass.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 857486b3fc46cc..9917730daa543f 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1312,8 +1312,7 @@ def _insert_post_dequant_op(self, graph, op_node): assert self._is_float( scale_v), 'The scale of parameter %s is not a float.' % ( original_var_name) - if scale_v == 0.0: - scale_v = 1e-8 + scale_v = 1e-8 if scale_v == 0.0 else scale_v max_range *= param_range / scale_v else: max_range *= act_range @@ -1415,8 +1414,7 @@ def _clip(x, scale): x[:, i] = _clip(x[:, i], s) x[:, i] = np.round(x[:, i] / s * bnt) else: - if scale == 0.0: - scale = 1e-8 + scale = 1e-8 if scale == 0.0 else scale x = _clip(x, scale) x = np.round(x / scale * bnt) return x
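Patches 5 and 6 are pure style refactors: the `if`-blocks from patch 2 collapse into equivalent one-line conditional expressions. For context, the guard in `_sample_mse` sits at the top of an MSE-driven scale search: candidate scales are swept as fractions of `abs_max_value`, each is scored by the mean squared error of a quant-dequant round trip, and the best one wins. The sketch below shows the general shape of such a search; only the `s = 0.3` starting point and the zero guard are visible in the diffs, so the sweep step and upper bound here are assumptions.

```python
import numpy as np

def mse_search_scale(x, bits=8, start=0.3, stop=1.0, step=0.05):
    # Sweep candidate scales s * abs_max and keep the one whose
    # quant-dequant round trip minimizes MSE against the raw tensor.
    bnt = (1 << (bits - 1)) - 1
    abs_max = float(np.max(np.abs(x)))
    abs_max = 1e-8 if abs_max == 0.0 else abs_max  # guard from patch 5
    best_scale, best_loss = abs_max, float('inf')
    for s in np.arange(start, stop + 1e-9, step):
        scale = s * abs_max
        qdq = np.round(np.clip(x, -scale, scale) / scale * bnt) * scale / bnt
        loss = float(((x - qdq) ** 2).mean())
        if loss < best_loss:
            best_scale, best_loss = scale, loss
    return best_scale

x = np.random.randn(1024).astype('float32')
print(mse_search_scale(x))  # usually well below abs(x).max()
```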