Dygraph post training quantization (#33445)
* dygraph post training quantization

* refine the ptq config

* refine ptq quantizer
juncaipeng authored Jun 22, 2021
1 parent 1b0c5ef commit 2b6fc10
Showing 9 changed files with 952 additions and 4 deletions.
python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py
@@ -20,6 +20,22 @@
from . import qat
from .qat import *

from . import ptq
from .ptq import *

from . import ptq_config
from .ptq_config import *

from . import ptq_quantizer
from .ptq_quantizer import *

from . import ptq_registry
from .ptq_registry import *

__all__ = []
__all__ += quant_nn.__all__
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += ptq_config.__all__
__all__ += ptq_quantizer.__all__
__all__ += ptq_registry.__all__
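
With these re-exports, the PTQ classes become part of the package's public API. A minimal import sketch (the package path follows the file locations in this commit; AbsmaxQuantizer is assumed to be listed in ptq_quantizer.__all__, since default_ptq_config below constructs it):

from paddle.fluid.contrib.slim.quantization.imperative import (
    ImperativePTQ, PTQConfig, AbsmaxQuantizer)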
112 changes: 112 additions & 0 deletions python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
@@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy
import numpy as np

import paddle
from paddle.fluid.log_helper import get_logger

from . import utils
from . import ptq_hooks
from . import ptq_config
from .ptq_registry import PTQRegistry

__all__ = ['ImperativePTQ']

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class ImperativePTQ(object):
    """
    Apply static post-training quantization to a dygraph model.
    """

    def __init__(self, quant_config=ptq_config.default_ptq_config):
        """
        Constructor.

        Args:
            quant_config(PTQConfig): The config that sets the activation
                and weight quantizers used for post-training quantization.
                Default: ptq_config.default_ptq_config.
        """
        super(ImperativePTQ, self).__init__()

        assert isinstance(quant_config, ptq_config.PTQConfig)

        self._quant_config = quant_config

    def quantize(self, model, inplace=False):
        """
        Add hooks to the supported leaf layers to sample their inputs
        and outputs, so that the thresholds can be calculated later.

        Args:
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether to modify the input model in place.
                If False, the model is deep-copied first. Default: False.
        Returns:
            quantized_model(paddle.nn.Layer): The model with hooks registered.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The model must be the instance of paddle.nn.Layer."

        if not inplace:
            model = copy.deepcopy(model)

        for name, layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(layer) \
                    and utils.is_leaf_layer(layer):
                quant_config = copy.deepcopy(self._quant_config)
                layer._quant_config = quant_config

                hook = ptq_hooks.quant_forward_post_hook
                hook_handle = layer.register_forward_post_hook(hook)
                quant_config.hook_handle = hook_handle
                # Run the quant hook before any other forward post hooks.
                layer._forward_post_hooks.move_to_end(
                    hook_handle._hook_id, last=False)

        return model

    def convert(self, model):
        """
        Calculate the thresholds from the sampled data and remove the hooks.

        Args:
            model(paddle.nn.Layer): The model returned by quantize(), after
                it has been run on the calibration data.
        Returns:
            converted_model(paddle.nn.Layer): The model with thresholds set.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The input model must be the instance of paddle.nn.Layer."

        for name, sub_layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(sub_layer) \
                    and utils.is_leaf_layer(sub_layer):

                assert hasattr(sub_layer, "_quant_config")
                quant_config = sub_layer._quant_config
                quant_config.hook_handle.remove()

                quant_config.in_act_quantizer.cal_thresholds()
                quant_config.out_act_quantizer.cal_thresholds()

                # Get the weight thresholds.
                if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)):
                    weights = (sub_layer.weight, )
                    quant_config.wt_quantizer.sample_data(sub_layer, weights)

                # TODO (jc):
                # save input activation threshold and quant bits

        return model
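
Taken together, quantize() attaches sampling hooks to every supported leaf layer, and convert() turns the sampled data into thresholds and removes the hooks. A minimal calibration sketch, assuming the import path above and using random data as a stand-in for a real calibration set:

import paddle
from paddle.vision.models import resnet18
from paddle.fluid.contrib.slim.quantization.imperative import ImperativePTQ

ptq = ImperativePTQ()  # uses default_ptq_config: absmax for acts and weights
model = resnet18()
quant_model = ptq.quantize(model)  # deep copy; the original model is untouched
quant_model.eval()

# Feed calibration batches so the hooks can sample activations.
for _ in range(8):
    quant_model(paddle.rand([1, 3, 224, 224]))

ptq.convert(quant_model)  # calculate thresholds and remove the hooks

Note that quantize() returns the hooked copy because inplace defaults to False; running calibration data through the original model afterwards would not collect any statistics.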
44 changes: 44 additions & 0 deletions python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py
@@ -0,0 +1,44 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import abc
import copy

import paddle

from .ptq_quantizer import *

__all__ = ['PTQConfig', 'default_ptq_config']


class PTQConfig(object):
    """
    The PTQ config describes how to quantize the input and output
    activations and the weights of a layer.
    """

    def __init__(self, activation_quantizer, weight_quantizer):
        """
        Constructor.

        Args:
            activation_quantizer(BaseQuantizer): The quantizer for the input
                and output activations. It is deep-copied twice, so the input
                and output quantizers hold separate state.
            weight_quantizer(BaseQuantizer): The quantizer for the weights.
        """
        super(PTQConfig, self).__init__()

        assert isinstance(activation_quantizer, BaseQuantizer)
        assert isinstance(weight_quantizer, BaseQuantizer)

        self.in_act_quantizer = copy.deepcopy(activation_quantizer)
        self.out_act_quantizer = copy.deepcopy(activation_quantizer)
        self.wt_quantizer = copy.deepcopy(weight_quantizer)

        self.hook_handle = None


default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
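
default_ptq_config pairs absmax quantizers for both activations and weights, but a custom pairing is just another PTQConfig. A sketch (AbsmaxQuantizer is the only quantizer visible in this diff; any other BaseQuantizer subclass from ptq_quantizer would work the same way):

from paddle.fluid.contrib.slim.quantization.imperative import (
    ImperativePTQ, PTQConfig, AbsmaxQuantizer)

my_config = PTQConfig(
    activation_quantizer=AbsmaxQuantizer(),
    weight_quantizer=AbsmaxQuantizer())
ptq = ImperativePTQ(my_config)

Because PTQConfig deep-copies the quantizers into in_act_quantizer, out_act_quantizer, and wt_quantizer, and quantize() deep-copies the whole config per layer, no sampling state is ever shared between layers.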
28 changes: 28 additions & 0 deletions python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py
@@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import math
import numpy as np
from . import ptq_config


def quant_forward_post_hook(layer, inputs, outputs):
    """
    The forward_post_hook for PTQ: sample the input and output tensors
    of the layer.
    """
    assert hasattr(layer, '_quant_config'), \
        "The layer should have _quant_config attr"
    layer._quant_config.in_act_quantizer.sample_data(layer, inputs)
    layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, ))
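
For reference, Paddle delivers a forward post hook as hook(layer, inputs, outputs) after every forward pass, which is how the quantizers above get to observe live tensors. A standalone sketch of the mechanism, independent of this commit:

import paddle

def print_shapes_hook(layer, inputs, outputs):
    # inputs is a tuple of input tensors; outputs is the forward return value.
    print(layer.full_name(), [t.shape for t in inputs], outputs.shape)

linear = paddle.nn.Linear(4, 2)
handle = linear.register_forward_post_hook(print_shapes_hook)
linear(paddle.rand([3, 4]))
handle.remove()  # detach the hook, just as convert() does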