Dygraph post training quantization (#33445)
* dygraph post training quantization
* refine the ptq config
* refine ptq quantizer
1 parent 1b0c5ef, commit 2b6fc10
Showing 9 changed files with 952 additions and 4 deletions.
112 changes: 112 additions & 0 deletions
python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy
import numpy as np

import paddle
from paddle.fluid.log_helper import get_logger

from . import utils
from . import ptq_hooks
from . import ptq_config
from .ptq_registry import PTQRegistry

__all__ = ['ImperativePTQ']

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')

class ImperativePTQ(object):
    """
    Apply static post-training quantization to a dygraph model.
    """

    def __init__(self, quant_config=ptq_config.default_ptq_config):
        """
        Constructor.
        Args:
            quant_config(PTQConfig): the PTQ config, which specifies the
                quantizers used for activations and weights.
                Default: default_ptq_config.
        """
        super(ImperativePTQ, self).__init__()

        assert isinstance(quant_config, ptq_config.PTQConfig)

        self._quant_config = quant_config

    def quantize(self, model, inplace=False):
        """
        Register forward-post hooks on the supported leaf layers to collect
        the statistics of their inputs and outputs.
        Args:
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether to modify the model in place.
                Default: False.
        Returns:
            The model with calibration hooks attached.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The model must be the instance of paddle.nn.Layer."

        if not inplace:
            model = copy.deepcopy(model)

        for name, layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(layer) \
                    and utils.is_leaf_layer(layer):
                quant_config = copy.deepcopy(self._quant_config)
                layer._quant_config = quant_config

                hook = ptq_hooks.quant_forward_post_hook
                hook_handle = layer.register_forward_post_hook(hook)
                quant_config.hook_handle = hook_handle
                layer._forward_post_hooks.move_to_end(
                    hook_handle._hook_id, last=False)

        return model

    def convert(self, model):
        """
        Compute the quantization thresholds from the collected statistics
        and remove the calibration hooks.
        Args:
            model(paddle.nn.Layer): The model that has been calibrated.
        Returns:
            The converted model.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The input model must be the instance of paddle.nn.Layer."

        for name, sub_layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(sub_layer) \
                    and utils.is_leaf_layer(sub_layer):

                assert hasattr(sub_layer, "_quant_config")
                quant_config = sub_layer._quant_config
                quant_config.hook_handle.remove()

                quant_config.in_act_quantizer.cal_thresholds()
                quant_config.out_act_quantizer.cal_thresholds()

                # get weight thresholds
                if isinstance(sub_layer,
                              tuple(utils.fake_quant_input_layers)):
                    weights = (sub_layer.weight, )
                    quant_config.wt_quantizer.sample_data(sub_layer, weights)

        # TODO (jc):
        # save input activation threshold and quant bits

        return model
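
The calibration workflow implied by quantize() and convert() is: attach the hooks, feed a few batches of representative data, then convert. Below is a minimal usage sketch; the deep import paths are inferred from the file locations in this commit, and the LeNet model plus random calibration batches are placeholders, not part of the change.

import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
from paddle.fluid.contrib.slim.quantization.imperative.ptq_config import default_ptq_config

# Placeholder fp32 model; any dygraph model with supported leaf layers works.
model = paddle.vision.models.LeNet()
ptq = ImperativePTQ(default_ptq_config)

# 1. Attach forward-post hooks to the supported leaf layers (deep copy by default).
quant_model = ptq.quantize(model)

# 2. Run a few calibration batches so the hooks can sample activations.
quant_model.eval()
for _ in range(8):
    x = paddle.to_tensor(np.random.rand(4, 1, 28, 28).astype('float32'))
    quant_model(x)

# 3. Compute the thresholds and remove the hooks.
quant_model = ptq.convert(quant_model)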
44 changes: 44 additions & 0 deletions
python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import abc
import copy

import paddle

from .ptq_quantizer import *

__all__ = ['PTQConfig', 'default_ptq_config']


class PTQConfig(object):
    """
    The PTQ config specifies the quantizers used for the input activations,
    the output activations, and the weights of a layer.
    """

    def __init__(self, activation_quantizer, weight_quantizer):
        super(PTQConfig, self).__init__()

        assert isinstance(activation_quantizer, BaseQuantizer)
        assert isinstance(weight_quantizer, BaseQuantizer)

        self.in_act_quantizer = copy.deepcopy(activation_quantizer)
        self.out_act_quantizer = copy.deepcopy(activation_quantizer)
        self.wt_quantizer = copy.deepcopy(weight_quantizer)

        self.hook_handle = None


default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
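
Since PTQConfig deep-copies the quantizers it receives, the same instance can safely be passed for both arguments. A sketch of building a custom config and handing it to ImperativePTQ follows; the import paths are assumed from the file locations in this commit, and only AbsmaxQuantizer (the one quantizer visible in this diff) is used.

from paddle.fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
from paddle.fluid.contrib.slim.quantization.imperative.ptq_config import PTQConfig
from paddle.fluid.contrib.slim.quantization.imperative.ptq_quantizer import AbsmaxQuantizer

# Both arguments must be BaseQuantizer instances; they are deep-copied into
# in_act_quantizer / out_act_quantizer and wt_quantizer respectively.
my_config = PTQConfig(
    activation_quantizer=AbsmaxQuantizer(),
    weight_quantizer=AbsmaxQuantizer())

ptq = ImperativePTQ(my_config)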
28 changes: 28 additions & 0 deletions
python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import math
import numpy as np
from . import ptq_config


def quant_forward_post_hook(layer, inputs, outputs):
    """
    The forward_post_hook for PTQ.
    """
    assert hasattr(layer, '_quant_config'), \
        "The layer should have _quant_config attr"
    layer._quant_config.in_act_quantizer.sample_data(layer, inputs)
    layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, ))
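
For readers unfamiliar with Paddle's hook mechanism: a forward-post hook registered on a paddle.nn.Layer is called with (layer, inputs, outputs) after every forward pass, which is how quant_forward_post_hook gets a chance to sample activations. A small standalone sketch follows; the toy hook that records output shapes is illustrative only and not part of this commit.

import paddle

def record_shapes_hook(layer, inputs, outputs):
    # Collect the output shape seen on each forward call.
    shapes = getattr(layer, '_seen_shapes', [])
    shapes.append(tuple(outputs.shape))
    layer._seen_shapes = shapes

linear = paddle.nn.Linear(4, 2)
handle = linear.register_forward_post_hook(record_shapes_hook)
linear(paddle.rand([3, 4]))
print(linear._seen_shapes)  # [(3, 2)]
handle.remove()  # convert() removes the PTQ hooks the same way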