Commit

add default config and support more ops (PaddlePaddle#180)
* add default config and support more ops

* remove debug code

* remove debug code

* fix details
slf12 authored Mar 23, 2020
1 parent 52502c0 commit 188ec0d
Showing 2 changed files with 144 additions and 63 deletions.
204 changes: 143 additions & 61 deletions paddleslim/quant/quant_embedding.py
@@ -18,24 +18,28 @@
 import logging
 import copy
 import numpy as np
+import math

 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core

 #_logger = logging.basicConfig(level=logging.DEBUG)
 from ..common import get_logger
 _logger = get_logger(__name__, level=logging.INFO)

 __all__ = ['quant_embedding']

-default_config = {
+_default_single_config = {
     "quantize_type": "abs_max",
     "quantize_bits": 8,
     "dtype": "int8"
 }
+SUPPORT_OP_TYPES = ['lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash']
+SUPPORT_QUANTIZE_TYPES = ['abs_max']
+SUPPORT_QUANTIZE_BITS = [8]
+SUPPORT_DTYPE = ['int8']

-support_quantize_types = ['abs_max']
-support_quantize_bits = [8]
-support_dtype = ['int8']
+_default_config = {"quantize_op_types": SUPPORT_OP_TYPES, }


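For reference, a user config under the new scheme lists the op types to quantize in 'quantize_op_types' and may override the per-op defaults under a key named after the op type. A minimal sketch (values mirror _default_single_config; 'threshold' is the only optional extra):

# Hypothetical user config: quantize only lookup_table weights, clipping first.
config = {
    "quantize_op_types": ["lookup_table"],
    "lookup_table": {
        "quantize_type": "abs_max",  # only 'abs_max' is supported
        "quantize_bits": 8,          # only 8 is supported
        "dtype": "int8",             # only 'int8' is supported
        "threshold": 1.0,            # optional: clip to [-1.0, 1.0] before quantizing
    },
}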
@@ -49,32 +53,47 @@ def _merge_config(old_config, new_config):
"""
old_config.update(new_config)
keys = old_config.keys()
assert 'params_name' in keys, "params_name must be set"

quantize_type = old_config['quantize_type']
assert isinstance(quantize_type, str), "quantize_type must be \
assert isinstance(old_config['quantize_op_types'], (str, list)), \
'quantize_op_types can only be str or list[str]'
if isinstance(old_config['quantize_op_types'], str):
old_config['quantize_op_types'] = [old_config['quantize_op_types']]
for op_type in old_config['quantize_op_types']:
assert op_type in SUPPORT_OP_TYPES, \
'{} is not supported, supported op types are {}'.format(
op_type, SUPPORT_OP_TYPES)
if op_type not in keys:
old_config[op_type] = _default_single_config
continue
else:
assert isinstance(old_config[op_type], dict), \
"op type {}'s config must be dict"
config_tmp = copy.deepcopy(_default_single_config)
config_tmp.update(old_config[op_type])
old_config[op_type] = config_tmp

quantize_type = old_config[op_type]['quantize_type']
assert isinstance(quantize_type, str), "quantize_type must be \
str"

assert quantize_type in support_quantize_types, " \
quantize_type {} is not supported, now supported quantize type \
are {}.".format(quantize_type, support_quantize_types)

quantize_bits = old_config['quantize_bits']
assert isinstance(quantize_bits, int), "quantize_bits must be int"
assert quantize_bits in support_quantize_bits, " quantize_bits {} \
is not supported, now supported quantize bits are \
{}. ".format(quantize_bits, support_quantize_bits)

dtype = old_config['dtype']
assert isinstance(dtype, str), "dtype must be str"
assert dtype in support_dtype, " dtype {} is not \
supported, now supported dtypes are {} \
".format(dtype, support_dtype)
if 'threshold' in keys:
assert isinstance(old_config['threshold'], (float, int)), "threshold \
must be number."

print("quant_embedding config {}".format(old_config))
assert quantize_type in SUPPORT_QUANTIZE_TYPES , "" \
"quantize_type {} is not supported, now supported quantize type" \
" are {}.".format(quantize_type, SUPPORT_QUANTIZE_TYPES)

quantize_bits = old_config[op_type]['quantize_bits']
assert isinstance(quantize_bits, int), "quantize_bits must be int"
assert quantize_bits in SUPPORT_QUANTIZE_BITS , " quantize_bits {}" \
" is not supported, now supported quantize bits are" \
" {}. ".format(quantize_bits, SUPPORT_QUANTIZE_BITS)

dtype = old_config[op_type]['dtype']
assert isinstance(dtype, str), "dtype must be str"
assert dtype in SUPPORT_DTYPE , " dtype {} is not "\
"supported, now supported dtypes are {} ".format(dtype, SUPPORT_DTYPE)
if 'threshold' in old_config[op_type].keys():
assert isinstance(old_config[op_type]['threshold'], (float, int)), \
"threshold must be number."

_logger.info("quant_embedding config {}".format(old_config))
return old_config
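A small illustration of the merge semantics, assuming this module's names are in scope: an op type listed without its own dict receives _default_single_config, and a partial dict is back-filled from those defaults.

import copy

user = {"quantize_op_types": "lookup_table",  # a bare str is promoted to [str]
        "lookup_table": {"threshold": 0.5}}
merged = _merge_config(copy.deepcopy(_default_config), user)
# merged["quantize_op_types"] == ["lookup_table"]
# merged["lookup_table"] == {"quantize_type": "abs_max", "quantize_bits": 8,
#                            "dtype": "int8", "threshold": 0.5}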


@@ -90,18 +109,6 @@ def _get_var_tensor(scope, var_name):
     return np.array(scope.find_var(var_name).get_tensor())


-def _clip_tensor(tensor_array, threshold):
-    """
-    when 'threshold' is set, clip tensor by 'threshold' and '-threshold'
-    Args:
-        tensor_array(np.array): array to clip
-        threshold(float): value to clip by
-    """
-    tensor_array[tensor_array > threshold] = threshold
-    tensor_array[tensor_array < -threshold] = -threshold
-    return tensor_array
-
-
 def _get_scale_var_name(var_name):
     """
     get scale var name
@@ -139,7 +146,8 @@ def _clear_var(var_name, scope):
     tensor._clear()


-def _quant_embedding_abs_max(graph, scope, place, config):
+def _quant_embedding_abs_max(graph, scope, place, config, var_name,
+                             embedding_node):
     """
     quantize embedding using abs_max
@@ -190,16 +198,20 @@ def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config):
     for node in output_ops:
         graph.update_input_link(var_node, dequant_var_node, node)

-    all_var_nodes = graph.all_var_nodes()
-    var_name = config['params_name']
-    # find embedding var node by 'params_name'
-    embedding_node = graph._find_node_by_name(all_var_nodes, var_name)
-    embedding_tensor = _get_var_tensor(scope, var_name)
-    if 'threshold' in config.keys():
-        embedding_tensor = _clip_tensor(embedding_tensor, config['threshold'])
+    def _clip_array(array, config):
+        if 'threshold' in config.keys():
+            threshold = config['threshold']
+        else:
+            abs_array = np.max(np.abs(array))
+            if abs_array < 1.0:
+                return array
+            threshold = np.percentile(np.abs(array), 99.99)
+        return np.clip(array, -threshold, threshold)

+    embedding_tensor = _get_var_tensor(scope, var_name)
+    embedding_array = _clip_array(embedding_tensor, config)
     # get scale and quantized tensor
-    scale, quanted_tensor = _quant_abs_max(embedding_tensor, config)
+    scale, quanted_tensor = _quant_abs_max(embedding_array, config)

     # params must be created with create_persistable_node
     scale_var = graph.create_persistable_node(
@@ -221,18 +233,70 @@ def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config):

     # insert dequantize_abs_max op
     for op_node in embedding_node.outputs:
-        if op_node.name() == 'lookup_table':
-            graph.update_input_link(embedding_node, quant_tensor_var, op_node)
-            var_node = op_node.outputs[0]
-            _insert_dequant_abs_max_op(graph, scope, var_node, scale_var,
-                                       config)
+        graph.update_input_link(embedding_node, quant_tensor_var, op_node)
+        out_name = op_node.output('Out')[0]
+        var_node = graph._find_node_by_name(op_node.outputs, out_name)
+        _insert_dequant_abs_max_op(graph, scope, var_node, scale_var, config)

     # free float embedding params memory
     _clear_var(embedding_node.name(), scope)
     graph.safe_remove_nodes(embedding_node)

-def quant_embedding(program, place, config, scope=None):
+def _remove_link(in_node, out_node):
+    in_node.remove_output(out_node)
+    out_node.remove_input(in_node)
+
+
+def _split_embedding_seq_pool(graph, op):
+    inputs = op.inputs
+    outputs = op.outputs
+    op_desc = op.node.op()
+    combiner = op_desc.attr("combiner")
+    padding_idx = op_desc.attr("padding_idx")
+    is_sparse = op_desc.attr("is_sparse")
+    ids = graph._find_node_by_name(inputs, op.input('Ids')[0])
+    weight = graph._find_node_by_name(inputs, op.input('W')[0])
+    out = outputs[0]
+    lookup_out = graph.create_var_node(
+        name=ids.name() + '.look_up_table.out',
+        var_type=core.VarDesc.VarType.LOD_TENSOR,
+        shape=[1],
+        var_dtype=weight.dtype())
+    lookup_table_op = graph.create_op_node(
+        op_type='lookup_table',
+        attrs={'is_sparse': is_sparse,
+               'padding_idx': padding_idx},
+        inputs={'W': weight,
+                'Ids': ids},
+        outputs={'Out': lookup_out})
+    _remove_link(ids, op)
+    _remove_link(weight, op)
+    _remove_link(op, out)
+    graph.link_to(ids, lookup_table_op)
+    graph.link_to(weight, lookup_table_op)
+    graph.link_to(lookup_table_op, lookup_out)
+    max_index = graph.create_var_node(
+        name=ids.name() + '.seq_pool_op.max_index',
+        var_type=core.VarDesc.VarType.LOD_TENSOR,
+        shape=[1],
+        var_dtype=weight.dtype())
+
+    seq_pool_op = graph.create_op_node(
+        op_type='sequence_pool',
+        inputs={'X': lookup_out},
+        outputs={'Out': out,
+                 'MaxIndex': max_index},
+        attrs={'pooltype': combiner.upper(),
+               'is_test': True})
+    if combiner == 'max':
+        max_index.stop_gradient = True
+    graph.link_to(lookup_out, seq_pool_op)
+    graph.link_to(seq_pool_op, out)
+    graph.link_to(seq_pool_op, max_index)
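Splitting the fused op means the quantization pass only ever rewrites plain lookup_table ops. Functionally, fused_embedding_seq_pool is a row lookup followed by a sequence pool; a rough NumPy equivalent for combiner='sum' (function and variable names here are illustrative):

import numpy as np

def embedding_seq_pool_sum(weight, ids_per_seq):
    # weight: (vocab, dim) embedding table; ids_per_seq: one id array per sequence
    return np.stack([weight[ids].sum(axis=0) for ids in ids_per_seq])

weight = np.random.randn(100, 8).astype('float32')
batch = [np.array([3, 5, 5]), np.array([7])]
pooled = embedding_seq_pool_sum(weight, batch)  # shape (2, 8)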


+def quant_embedding(program, place, config=None, scope=None):
     """quantize lookup_table op parameters
Args:
@@ -241,7 +305,6 @@ def quant_embedding(program, place, config, scope=None):
         place(fluid.CPUPlace or fluid.CUDAPlace): the device on which the executor runs.
         config(dict, optional): config to quantize. The keys are 'quantize_op_types' and, \
                 per op type, 'quantize_type', 'quantize_bits', 'dtype', 'threshold'. \
-                ``params_name`` is parameter name to quantize, must be set.
                 ``quantize_type`` is quantize type, supported types are ['abs_max'], default is "abs_max".
                 ``quantize_bits`` supported bits are [8] and default is 8.
                 ``dtype`` is quantize dtype, supported dtype are ['int8'], default is 'int8'.
@@ -251,12 +314,31 @@
     Returns:
         None
     """
-    assert isinstance(config, dict), "config must be dict"
-    config = _merge_config(copy.deepcopy(default_config), config)
+    config = config or {}
+    config = _merge_config(copy.deepcopy(_default_config), config)
     scope = fluid.global_scope() if scope is None else scope

     graph = IrGraph(core.Graph(program.desc), for_test=True)
-    if config['quantize_type'] == 'abs_max':
-        _quant_embedding_abs_max(graph, scope, place, config)
+    quantize_params_map = {}
+    all_op = graph.all_op_nodes()
+    for op in all_op:
+        if op.inputs == [] and op.outputs == []:
+            continue
+        op_type = op.name()
+        if op_type in config['quantize_op_types']:
+            weight_name = op.input('W')[0]
+            if weight_name in quantize_params_map:  # weight already quantized
+                continue
+            embedding_node = graph._find_node_by_name(op.inputs,
+                                                      op.input('W')[0])
+            for op_node in embedding_node.outputs:
+                if op_node.name() == 'fused_embedding_seq_pool':
+                    _split_embedding_seq_pool(graph, op_node)
+            _quant_embedding_abs_max(graph, scope, place, config[op_type],
+                                     weight_name, embedding_node)
+            quantize_params_map[weight_name] = _get_quant_var_name(weight_name)
+    for op in all_op:
+        if op.name() == 'fused_embedding_seq_pool':
+            graph.safe_remove_nodes(op)

     return graph.to_program()
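With config now optional, a call site no longer has to build a config dict. A sketch, assuming infer_program is an already-loaded inference Program containing embedding ops:

import paddle.fluid as fluid
from paddleslim import quant

place = fluid.CPUPlace()

# Default config: quantize all supported op types
# ('lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash').
quant_program = quant.quant_embedding(infer_program, place)

# Per-op settings can still be overridden explicitly:
config = {'quantize_op_types': ['lookup_table'],
          'lookup_table': {'threshold': 1.0}}
quant_program = quant.quant_embedding(infer_program, place, config)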
3 changes: 1 addition & 2 deletions tests/test_quant_embedding.py
@@ -24,8 +24,7 @@ def test_quant_embedding(self):
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())

-        config = {'params_name': 'emb', 'quantize_type': 'abs_max'}
-        quant_program = quant.quant_embedding(infer_program, place, config)
+        quant_program = quant.quant_embedding(infer_program, place)


if __name__ == '__main__':
